diff --git a/pyproject.toml b/pyproject.toml index fcfd6f3e..56862a49 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,6 +83,7 @@ "--showlocals", "--strict-markers", "--strict-config", + "--doctest-glob='*.rst | *.py'", ] filterwarnings = [ "error", @@ -208,6 +209,7 @@ "missing-function-docstring", # TODO: resolve "missing-module-docstring", "no-member", # handled by mypy + "no-value-for-parameter", # pylint doesn't understand multiple dispatch "not-a-mapping", # pylint doesn't understand dataclass fields "protected-access", # handled by ruff "redefined-builtin", # handled by ruff diff --git a/src/coordinax/_base_dif.py b/src/coordinax/_base_dif.py index 78547341..7a1ef099 100644 --- a/src/coordinax/_base_dif.py +++ b/src/coordinax/_base_dif.py @@ -45,11 +45,11 @@ def integral_cls(cls) -> type["AbstractVectorDifferential"]: -------- >>> from coordinax import RadialDifferential, SphericalDifferential - >>> RadialDifferential.integral_cls - + >>> RadialDifferential.integral_cls.__name__ + 'RadialVector' - >>> SphericalDifferential.integral_cls - + >>> SphericalDifferential.integral_cls.__name__ + 'SphericalVector' """ raise NotImplementedError diff --git a/src/coordinax/_base_vec.py b/src/coordinax/_base_vec.py index d650d7cc..7f8bb10b 100644 --- a/src/coordinax/_base_vec.py +++ b/src/coordinax/_base_vec.py @@ -51,11 +51,11 @@ def differential_cls(cls) -> type["AbstractVectorDifferential"]: -------- >>> from coordinax import RadialVector, SphericalVector - >>> RadialVector.differential_cls - + >>> RadialVector.differential_cls.__name__ + 'RadialDifferential' - >>> SphericalVector.differential_cls - + >>> SphericalVector.differential_cls.__name__ + 'SphericalDifferential' """ raise NotImplementedError diff --git a/src/coordinax/_d2/transform.py b/src/coordinax/_d2/transform.py index 414efd89..faa2e19a 100644 --- a/src/coordinax/_d2/transform.py +++ b/src/coordinax/_d2/transform.py @@ -32,8 +32,6 @@ def represent_as( @dispatch.multi( (Cartesian2DVector, type[Cartesian2DVector]), (PolarVector, type[PolarVector]), - # (LnPolarVector, type[LnPolarVector]), - # (Log10PolarVector, type[Log10PolarVector]), ) def represent_as( current: Abstract2DVector, target: type[Abstract2DVector], /, **kwargs: Any @@ -57,20 +55,6 @@ def represent_as( return current -# @dispatch.multi( -# (Cartesian2DVector, type[LnPolarVector]), -# (Cartesian2DVector, type[Log10PolarVector]), -# (LnPolarVector, type[Cartesian2DVector]), -# (Log10PolarVector, type[Cartesian2DVector]), -# ) -# def represent_as( -# current: Abstract2DVector, target: type[Abstract2DVector], /, **kwargs: Any -# ) -> Abstract2DVector: -# """Abstract2DVector -> PolarVector -> Abstract2DVector.""" -# polar = represent_as(current, PolarVector) -# return represent_as(polar, target) - - # ============================================================================= # Cartesian2DVector @@ -107,65 +91,3 @@ def represent_as( x = current.r * xp.cos(current.phi) y = current.r * xp.sin(current.phi) return target(x=x, y=y) - - -# @dispatch -# def represent_as( -# current: PolarVector, target: type[LnPolarVector], /, **kwargs: Any -# ) -> LnPolarVector: -# """PolarVector -> LnPolarVector.""" -# return target(lnr=xp.log(current.r), phi=current.phi) - - -# @dispatch -# def represent_as( -# current: PolarVector, target: type[Log10PolarVector], /, **kwargs: Any -# ) -> Log10PolarVector: -# """PolarVector -> Log10PolarVector.""" -# return target(log10r=xp.log10(current.r), phi=current.phi) - - -# # ============================================================================= -# # LnPolarVector - -# # ----------------------------------------------- -# # 2D - - -# @dispatch -# def represent_as( -# current: LnPolarVector, target: type[PolarVector], /, **kwargs: Any -# ) -> PolarVector: -# """LnPolarVector -> PolarVector.""" -# return target(r=xp.exp(current.lnr), phi=current.phi) - - -# @dispatch -# def represent_as( -# current: LnPolarVector, target: type[Log10PolarVector], /, **kwargs: Any -# ) -> Log10PolarVector: -# """LnPolarVector -> Log10PolarVector.""" -# return target(log10r=current.lnr / xp.log(10), phi=current.phi) - - -# # ============================================================================= -# # Log10PolarVector - -# # ----------------------------------------------- -# # 2D - - -# @dispatch -# def represent_as( -# current: Log10PolarVector, target: type[PolarVector], /, **kwargs: Any -# ) -> PolarVector: -# """Log10PolarVector -> PolarVector.""" -# return target(r=xp.pow(10, current.log10r), phi=current.phi) - - -# @dispatch -# def represent_as( -# current: Log10PolarVector, target: type[LnPolarVector], /, **kwargs: Any -# ) -> LnPolarVector: -# """Log10PolarVector -> LnPolarVector.""" -# return target(lnr=current.log10r * xp.log(10), phi=current.phi) diff --git a/src/coordinax/_d3/__init__.py b/src/coordinax/_d3/__init__.py index f90673bc..438db47d 100644 --- a/src/coordinax/_d3/__init__.py +++ b/src/coordinax/_d3/__init__.py @@ -1,16 +1,18 @@ # pylint: disable=duplicate-code """3-dimensional representations.""" -from . import base, builtin, compat, operate, transform +from . import base, builtin, compat, operate, sphere, transform from .base import * from .builtin import * from .compat import * from .operate import * +from .sphere import * from .transform import * __all__: list[str] = [] __all__ += base.__all__ __all__ += builtin.__all__ +__all__ += sphere.__all__ __all__ += transform.__all__ __all__ += operate.__all__ __all__ += compat.__all__ diff --git a/src/coordinax/_d3/builtin.py b/src/coordinax/_d3/builtin.py index 52df7094..a39fe6c8 100644 --- a/src/coordinax/_d3/builtin.py +++ b/src/coordinax/_d3/builtin.py @@ -3,11 +3,9 @@ __all__ = [ # Position "Cartesian3DVector", - "SphericalVector", "CylindricalVector", # Differential "CartesianDifferential3D", - "SphericalDifferential", "CylindricalDifferential", ] @@ -19,12 +17,12 @@ import jax import quaxed.array_api as xp -from unxt import Distance, Quantity +from unxt import Quantity import coordinax._typing as ct from .base import Abstract3DVector, Abstract3DVectorDifferential from coordinax._base_vec import AbstractVector -from coordinax._checks import check_phi_range, check_r_non_negative, check_theta_range +from coordinax._checks import check_phi_range, check_r_non_negative from coordinax._converters import converter_phi_to_range from coordinax._utils import classproperty @@ -134,55 +132,6 @@ def norm(self) -> ct.BatchableLength: return xp.sqrt(self.x**2 + self.y**2 + self.z**2) -@final -class SphericalVector(Abstract3DVector): - """Spherical vector representation.""" - - r: ct.BatchableDistance = eqx.field( - converter=partial(Distance.constructor, dtype=float) - ) - r"""Radial distance :math:`r \in [0,+\infty)`.""" - - phi: ct.BatchableAngle = eqx.field( - converter=lambda x: converter_phi_to_range( - Quantity["angle"].constructor(x, dtype=float) # pylint: disable=E1120 - ) - ) - r"""Azimuthal angle :math:`\phi \in [0,360)`.""" - - theta: ct.BatchableAngle = eqx.field( - converter=partial(Quantity["angle"].constructor, dtype=float) - ) - r"""Inclination angle :math:`\phi \in [0,180]`.""" - - def __check_init__(self) -> None: - """Check the validity of the initialisation.""" - check_r_non_negative(self.r) - check_theta_range(self.theta) - check_phi_range(self.phi) - - @classproperty - @classmethod - def differential_cls(cls) -> type["SphericalDifferential"]: - return SphericalDifferential - - @partial(jax.jit) - def norm(self) -> ct.BatchableLength: - """Return the norm of the vector. - - Examples - -------- - >>> from unxt import Quantity - >>> from coordinax import SphericalVector - >>> s = SphericalVector(r=Quantity(3, "kpc"), theta=Quantity(90, "deg"), - ... phi=Quantity(0, "deg")) - >>> s.norm() - Distance(Array(3., dtype=float32), unit='kpc') - - """ - return self.r - - @final class CylindricalVector(Abstract3DVector): """Cylindrical vector representation.""" @@ -277,31 +226,6 @@ def norm(self, _: Abstract3DVector | None = None, /) -> ct.BatchableSpeed: return xp.sqrt(self.d_x**2 + self.d_y**2 + self.d_z**2) -@final -class SphericalDifferential(Abstract3DVectorDifferential): - """Spherical differential representation.""" - - d_r: ct.BatchableSpeed = eqx.field( - converter=partial(Quantity["speed"].constructor, dtype=float) - ) - r"""Radial speed :math:`dr/dt \in [-\infty, \infty].""" - - d_theta: ct.BatchableAngularSpeed = eqx.field( - converter=partial(Quantity["angular speed"].constructor, dtype=float) - ) - r"""Inclination speed :math:`d\theta/dt \in [-\infty, \infty].""" - - d_phi: ct.BatchableAngularSpeed = eqx.field( - converter=partial(Quantity["angular speed"].constructor, dtype=float) - ) - r"""Azimuthal speed :math:`d\phi/dt \in [-\infty, \infty].""" - - @classproperty - @classmethod - def integral_cls(cls) -> type[SphericalVector]: - return SphericalVector - - @final class CylindricalDifferential(Abstract3DVectorDifferential): """Cylindrical differential representation.""" diff --git a/src/coordinax/_d3/compat.py b/src/coordinax/_d3/compat.py index 16ae923c..af92c5b3 100644 --- a/src/coordinax/_d3/compat.py +++ b/src/coordinax/_d3/compat.py @@ -17,6 +17,11 @@ CartesianDifferential3D, CylindricalDifferential, CylindricalVector, +) +from .sphere import ( + LonCosLatSphericalDifferential, + LonLatSphericalDifferential, + LonLatSphericalVector, SphericalDifferential, SphericalVector, ) @@ -34,11 +39,11 @@ def constructor( Examples -------- + >>> import coordinax as cx >>> from astropy.coordinates import CartesianRepresentation - >>> from coordinax import Cartesian3DVector >>> cart = CartesianRepresentation(1, 2, 3, unit="kpc") - >>> vec = Cartesian3DVector.constructor(cart) + >>> vec = cx.Cartesian3DVector.constructor(cart) >>> vec.x Quantity['length'](Array(1., dtype=float32), unit='kpc') @@ -47,6 +52,29 @@ def constructor( return cls(x=obj.x, y=obj.y, z=obj.z) +@CylindricalVector.constructor._f.register # noqa: SLF001 +def constructor( + cls: type[CylindricalVector], obj: apyc.BaseRepresentation +) -> CylindricalVector: + """Construct from a :class:`astropy.coordinates.BaseRepresentation`. + + Examples + -------- + >>> import astropy.units as u + >>> import coordinax as cx + >>> from astropy.coordinates import CylindricalRepresentation + + >>> cyl = CylindricalRepresentation(rho=1 * u.kpc, phi=2 * u.deg, + ... z=30 * u.pc) + >>> vec = cx.CylindricalVector.constructor(cyl) + >>> vec.rho + Quantity['length'](Array(1., dtype=float32), unit='kpc') + + """ + obj = obj.represent_as(apyc.CylindricalRepresentation) + return cls(rho=obj.rho, phi=obj.phi, z=obj.z) + + @SphericalVector.constructor._f.register # noqa: SLF001 def constructor( cls: type[SphericalVector], obj: apyc.BaseRepresentation @@ -56,12 +84,12 @@ def constructor( Examples -------- >>> import astropy.units as u + >>> import coordinax as cx >>> from astropy.coordinates import PhysicsSphericalRepresentation - >>> from coordinax import SphericalVector >>> sph = PhysicsSphericalRepresentation(r=1 * u.kpc, theta=2 * u.deg, ... phi=3 * u.deg) - >>> vec = SphericalVector.constructor(sph) + >>> vec = cx.SphericalVector.constructor(sph) >>> vec.r Distance(Array(1., dtype=float32), unit='kpc') @@ -70,27 +98,27 @@ def constructor( return cls(r=obj.r, phi=obj.phi, theta=obj.theta) -@CylindricalVector.constructor._f.register # noqa: SLF001 +@LonLatSphericalVector.constructor._f.register # noqa: SLF001 def constructor( - cls: type[CylindricalVector], obj: apyc.BaseRepresentation -) -> CylindricalVector: + cls: type[LonLatSphericalVector], obj: apyc.BaseRepresentation +) -> LonLatSphericalVector: """Construct from a :class:`astropy.coordinates.BaseRepresentation`. Examples -------- >>> import astropy.units as u - >>> from astropy.coordinates import CylindricalRepresentation - >>> from coordinax import CylindricalVector + >>> import coordinax as cx + >>> from astropy.coordinates import SphericalRepresentation - >>> cyl = CylindricalRepresentation(rho=1 * u.kpc, phi=2 * u.deg, - ... z=30 * u.pc) - >>> vec = CylindricalVector.constructor(cyl) - >>> vec.rho - Quantity['length'](Array(1., dtype=float32), unit='kpc') + >>> sph = SphericalRepresentation(lon=3 * u.deg, lat=2 * u.deg, + ... distance=1 * u.kpc) + >>> vec = cx.LonLatSphericalVector.constructor(sph) + >>> vec.distance + Distance(Array(1., dtype=float32), unit='kpc') """ - obj = obj.represent_as(apyc.CylindricalRepresentation) - return cls(rho=obj.rho, phi=obj.phi, z=obj.z) + obj = obj.represent_as(apyc.SphericalRepresentation) + return cls(distance=obj.distance, lon=obj.lon, lat=obj.lat) @CartesianDifferential3D.constructor._f.register # noqa: SLF001 @@ -102,11 +130,11 @@ def constructor( Examples -------- >>> import astropy.units as u + >>> import coordinax as cx >>> from astropy.coordinates import CartesianDifferential - >>> from coordinax import CartesianDifferential3D >>> dcart = CartesianDifferential(1, 2, 3, unit="km/s") - >>> dif = CartesianDifferential3D.constructor(dcart) + >>> dif = cx.CartesianDifferential3D.constructor(dcart) >>> dif.d_x Quantity['speed'](Array(1., dtype=float32), unit='km / s') @@ -114,6 +142,28 @@ def constructor( return cls(d_x=obj.d_x, d_y=obj.d_y, d_z=obj.d_z) +@CylindricalDifferential.constructor._f.register # noqa: SLF001 +def constructor( + cls: type[CylindricalDifferential], obj: apyc.CylindricalDifferential +) -> CylindricalDifferential: + """Construct from a :class:`astropy.coordinates.CylindricalDifferential`. + + Examples + -------- + >>> import astropy.units as u + >>> import astropy.coordinates as apyc + >>> import coordinax as cx + + >>> dcyl = apyc.CylindricalDifferential(d_rho=1 * u.km / u.s, d_phi=2 * u.mas/u.yr, + ... d_z=2 * u.km / u.s) + >>> dif = cx.CylindricalDifferential.constructor(dcyl) + >>> dif.d_rho + Quantity['speed'](Array(1., dtype=float32), unit='km / s') + + """ + return cls(d_rho=obj.d_rho, d_phi=obj.d_phi, d_z=obj.d_z) + + @SphericalDifferential.constructor._f.register # noqa: SLF001 def constructor( cls: type[SphericalDifferential], obj: apyc.PhysicsSphericalDifferential @@ -123,12 +173,12 @@ def constructor( Examples -------- >>> import astropy.units as u + >>> import coordinax as cx >>> from astropy.coordinates import PhysicsSphericalDifferential - >>> from coordinax import SphericalDifferential >>> dsph = PhysicsSphericalDifferential(d_r=1 * u.km / u.s, d_theta=2 * u.mas/u.yr, ... d_phi=3 * u.mas/u.yr) - >>> dif = SphericalDifferential.constructor(dsph) + >>> dif = cx.SphericalDifferential.constructor(dsph) >>> dif.d_r Quantity['speed'](Array(1., dtype=float32), unit='km / s') @@ -136,26 +186,58 @@ def constructor( return cls(d_r=obj.d_r, d_phi=obj.d_phi, d_theta=obj.d_theta) -@CylindricalDifferential.constructor._f.register # noqa: SLF001 +@LonLatSphericalDifferential.constructor._f.register # noqa: SLF001 def constructor( - cls: type[CylindricalDifferential], obj: apyc.CylindricalDifferential -) -> CylindricalDifferential: - """Construct from a :class:`astropy.coordinates.CylindricalDifferential`. + cls: type[LonLatSphericalDifferential], obj: apyc.SphericalDifferential +) -> LonLatSphericalDifferential: + """Construct from a :class:`astropy.coordinates.SphericalDifferential`. Examples -------- >>> import astropy.units as u - >>> import astropy.coordinates as apyc - >>> from coordinax import CylindricalDifferential + >>> import coordinax as cx + >>> from astropy.coordinates import SphericalDifferential - >>> dcyl = apyc.CylindricalDifferential(d_rho=1 * u.km / u.s, d_phi=2 * u.mas/u.yr, - ... d_z=2 * u.km / u.s) - >>> dif = CylindricalDifferential.constructor(dcyl) - >>> dif.d_rho + >>> dsph = SphericalDifferential(d_distance=1 * u.km / u.s, + ... d_lon=2 * u.mas/u.yr, + ... d_lat=3 * u.mas/u.yr) + >>> dif = cx.LonLatSphericalDifferential.constructor(dsph) + >>> dif.d_distance Quantity['speed'](Array(1., dtype=float32), unit='km / s') """ - return cls(d_rho=obj.d_rho, d_phi=obj.d_phi, d_z=obj.d_z) + return cls(d_distance=obj.d_distance, d_lon=obj.d_lon, d_lat=obj.d_lat) + + +@LonCosLatSphericalDifferential.constructor._f.register # noqa: SLF001 +def constructor( + cls: type[LonCosLatSphericalDifferential], obj: apyc.SphericalCosLatDifferential +) -> LonCosLatSphericalDifferential: + """Construct from a :class:`astropy.coordinates.SphericalCosLatDifferential`. + + Examples + -------- + >>> import astropy.units as u + >>> import coordinax as cx + >>> from astropy.coordinates import SphericalCosLatDifferential + + >>> dsph = SphericalCosLatDifferential(d_distance=1 * u.km / u.s, + ... d_lon_coslat=2 * u.mas/u.yr, + ... d_lat=3 * u.mas/u.yr) + >>> dif = cx.LonCosLatSphericalDifferential.constructor(dsph) + >>> dif + LonCosLatSphericalDifferential( + d_distance=Quantity[...]( value=f32[], unit=Unit("km / s") ), + d_lon_coslat=Quantity[...]( value=f32[], unit=Unit("mas / yr") ), + d_lat=Quantity[...]( value=f32[], unit=Unit("mas / yr") ) + ) + >>> dif.d_distance + Quantity['speed'](Array(1., dtype=float32), unit='km / s') + + """ + return cls( + d_distance=obj.d_distance, d_lon_coslat=obj.d_lon_coslat, d_lat=obj.d_lat + ) ##################################################################### @@ -277,6 +359,66 @@ def apycart3_to_cart3(obj: apyc.CartesianRepresentation, /) -> Cartesian3DVector return Cartesian3DVector.constructor(obj) +# ===================================== +# CylindricalVector + + +# @conversion_method(CylindricalVector, apyc.BaseRepresentation) +# @conversion_method(CylindricalVector, apyc.CylindricalRepresentation) +def cyl_to_apycyl(obj: CylindricalVector, /) -> apyc.CylindricalRepresentation: + """`coordinax.CylindricalVector` -> `astropy.CylindricalRepresentation`. + + Examples + -------- + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> vec = cx.CylindricalVector(rho=Quantity(1, unit="kpc"), + ... phi=Quantity(2, unit="deg"), + ... z=Quantity(3, unit="pc")) + >>> convert(vec, apyc.CylindricalRepresentation) + + + >>> convert(vec, apyc.BaseRepresentation) + + + """ + return apyc.CylindricalRepresentation( + rho=convert(obj.rho, apyu.Quantity), + phi=convert(obj.phi, apyu.Quantity), + z=convert(obj.z, apyu.Quantity), + ) + + +# TODO: use decorator when https://github.com/beartype/plum/pull/135 +add_conversion_method(CylindricalVector, apyc.BaseRepresentation, cyl_to_apycyl) +add_conversion_method(CylindricalVector, apyc.CylindricalRepresentation, cyl_to_apycyl) + + +@conversion_method(apyc.CylindricalRepresentation, CylindricalVector) # type: ignore[misc] +def apycyl_to_cyl(obj: apyc.CylindricalRepresentation, /) -> CylindricalVector: + """`astropy.CylindricalRepresentation` -> `coordinax.CylindricalVector`. + + Examples + -------- + >>> import astropy.units as u + >>> import coordinax as cx + >>> from astropy.coordinates import CylindricalRepresentation + + >>> cyl = CylindricalRepresentation(rho=1 * u.kpc, phi=2 * u.deg, z=30 * u.pc) + >>> convert(cyl, cx.CylindricalVector) + CylindricalVector( + rho=Quantity[...](value=f32[], unit=Unit("kpc")), + phi=Quantity[...](value=f32[], unit=Unit("deg")), + z=Quantity[...](value=f32[], unit=Unit("pc")) + ) + + """ + return CylindricalVector.constructor(obj) + + # ===================================== # SphericalVector @@ -339,63 +481,66 @@ def apysph_to_sph(obj: apyc.PhysicsSphericalRepresentation, /) -> SphericalVecto # ===================================== -# CylindricalVector +# LonLatSphericalVector -# @conversion_method(CylindricalVector, apyc.BaseRepresentation) -# @conversion_method(CylindricalVector, apyc.CylindricalRepresentation) -def cyl_to_apycyl(obj: CylindricalVector, /) -> apyc.CylindricalRepresentation: - """`coordinax.CylindricalVector` -> `astropy.CylindricalRepresentation`. +# @conversion_method(LonLatSphericalVector, apyc.BaseRepresentation) +# @conversion_method( +# LonLatSphericalVector, apyc.PhysicsSphericalRepresentation +# ) +def lonlatsph_to_apysph(obj: LonLatSphericalVector, /) -> apyc.SphericalRepresentation: + """`coordinax.LonLatSphericalVector` -> `astropy.SphericalRepresentation`. Examples -------- >>> from unxt import Quantity >>> import coordinax as cx - >>> vec = cx.CylindricalVector(rho=Quantity(1, unit="kpc"), - ... phi=Quantity(2, unit="deg"), - ... z=Quantity(3, unit="pc")) - >>> convert(vec, apyc.CylindricalRepresentation) - - - >>> convert(vec, apyc.BaseRepresentation) - + >>> vec = cx.LonLatSphericalVector(lon=Quantity(2, unit="deg"), + ... lat=Quantity(3, unit="deg"), + ... distance=Quantity(1, unit="kpc")) + >>> convert(vec, apyc.SphericalRepresentation) + """ - return apyc.CylindricalRepresentation( - rho=convert(obj.rho, apyu.Quantity), - phi=convert(obj.phi, apyu.Quantity), - z=convert(obj.z, apyu.Quantity), + return apyc.SphericalRepresentation( + lon=convert(obj.lon, apyu.Quantity), + lat=convert(obj.lat, apyu.Quantity), + distance=convert(obj.distance, apyu.Quantity), ) # TODO: use decorator when https://github.com/beartype/plum/pull/135 -add_conversion_method(CylindricalVector, apyc.BaseRepresentation, cyl_to_apycyl) -add_conversion_method(CylindricalVector, apyc.CylindricalRepresentation, cyl_to_apycyl) +add_conversion_method( + LonLatSphericalVector, apyc.BaseRepresentation, lonlatsph_to_apysph +) +add_conversion_method( + LonLatSphericalVector, apyc.SphericalRepresentation, lonlatsph_to_apysph +) -@conversion_method(apyc.CylindricalRepresentation, CylindricalVector) # type: ignore[misc] -def apycyl_to_cyl(obj: apyc.CylindricalRepresentation, /) -> CylindricalVector: - """`astropy.CylindricalRepresentation` -> `coordinax.CylindricalVector`. +@conversion_method(apyc.SphericalRepresentation, LonLatSphericalVector) # type: ignore[misc] +def apysph_to_lonlatsph(obj: apyc.SphericalRepresentation, /) -> LonLatSphericalVector: + """`astropy.SphericalRepresentation` -> `coordinax.LonLatSphericalVector`. Examples -------- >>> import astropy.units as u >>> import coordinax as cx - >>> from astropy.coordinates import CylindricalRepresentation - - >>> cyl = CylindricalRepresentation(rho=1 * u.kpc, phi=2 * u.deg, z=30 * u.pc) - >>> convert(cyl, cx.CylindricalVector) - CylindricalVector( - rho=Quantity[...](value=f32[], unit=Unit("kpc")), - phi=Quantity[...](value=f32[], unit=Unit("deg")), - z=Quantity[...](value=f32[], unit=Unit("pc")) + >>> from astropy.coordinates import SphericalRepresentation + + >>> sph = SphericalRepresentation(lon=2 * u.deg, lat=3 * u.deg, + ... distance=1 * u.kpc) + >>> convert(sph, cx.LonLatSphericalVector) + LonLatSphericalVector( + distance=Distance(value=f32[], unit=Unit("kpc")), + lon=Quantity[...](value=f32[], unit=Unit("deg")), + lat=Quantity[...](value=f32[], unit=Unit("deg")) ) """ - return CylindricalVector.constructor(obj) + return LonLatSphericalVector.constructor(obj) # ===================================== @@ -464,6 +609,74 @@ def apycart3_to_diffcart3( return CartesianDifferential3D.constructor(obj) +# ===================================== +# CylindricalDifferential + + +# @conversion_method(CylindricalDifferential, apyc.BaseDifferential) +# @conversion_method( +# CylindricalDifferential, apyc.CylindricalDifferential +# ) +def diffcyl_to_apycyl(obj: CylindricalDifferential, /) -> apyc.CylindricalDifferential: + """`coordinax.CylindricalDifferential` -> `astropy.CylindricalDifferential`. + + Examples + -------- + >>> from unxt import Quantity + >>> import coordinax as cx + >>> import astropy.coordinates as apyc + + >>> dif = cx.CylindricalDifferential(d_rho=Quantity(1, unit="km/s"), + ... d_phi=Quantity(2, unit="mas/yr"), + ... d_z=Quantity(3, unit="km/s")) + >>> convert(dif, apyc.CylindricalDifferential) + + + >>> convert(dif, apyc.BaseDifferential) + + + """ + return apyc.CylindricalDifferential( + d_rho=convert(obj.d_rho, apyu.Quantity), + d_phi=convert(obj.d_phi, apyu.Quantity), + d_z=convert(obj.d_z, apyu.Quantity), + ) + + +# TODO: use decorator when https://github.com/beartype/plum/pull/135 +add_conversion_method(CylindricalDifferential, apyc.BaseDifferential, diffcyl_to_apycyl) +add_conversion_method( + CylindricalDifferential, apyc.CylindricalDifferential, diffcyl_to_apycyl +) + + +@conversion_method( # type: ignore[misc] + apyc.CylindricalDifferential, CylindricalDifferential +) +def apycyl_to_diffcyl(obj: apyc.CylindricalDifferential, /) -> CylindricalDifferential: + """`astropy.CylindricalDifferential` -> `coordinax.CylindricalDifferential`. + + Examples + -------- + >>> import astropy.units as u + >>> import astropy.coordinates as apyc + >>> import coordinax as cx + + >>> dcyl = apyc.CylindricalDifferential(d_rho=1 * u.km / u.s, d_phi=2 * u.mas/u.yr, + ... d_z=2 * u.km / u.s) + >>> convert(dcyl, cx.CylindricalDifferential) + CylindricalDifferential( + d_rho=Quantity[...]( value=f32[], unit=Unit("km / s") ), + d_phi=Quantity[...]( value=f32[], unit=Unit("mas / yr") ), + d_z=Quantity[...]( value=f32[], unit=Unit("km / s") ) + ) + + """ + return CylindricalDifferential.constructor(obj) + + # ===================================== # SphericalDifferential @@ -527,8 +740,8 @@ def apysph_to_diffsph( >>> convert(dif, cx.SphericalDifferential) SphericalDifferential( d_r=Quantity[...]( value=f32[], unit=Unit("km / s") ), - d_theta=Quantity[...]( value=f32[], unit=Unit("mas / yr") ), - d_phi=Quantity[...]( value=f32[], unit=Unit("mas / yr") ) + d_phi=Quantity[...]( value=f32[], unit=Unit("mas / yr") ), + d_theta=Quantity[...]( value=f32[], unit=Unit("mas / yr") ) ) """ @@ -536,71 +749,149 @@ def apysph_to_diffsph( # ===================================== -# CylindricalDifferential +# LonLatSphericalDifferential -# @conversion_method(CylindricalDifferential, apyc.BaseDifferential) +# @conversion_method(LonLatSphericalDifferential, apyc.BaseDifferential) # @conversion_method( -# CylindricalDifferential, apyc.CylindricalDifferential +# LonLatSphericalDifferential, apyc.SphericalDifferential # ) -def diffcyl_to_apycyl(obj: CylindricalDifferential, /) -> apyc.CylindricalDifferential: - """`coordinax.CylindricalDifferential` -> `astropy.CylindricalDifferential`. +def difflonlatsph_to_apysph( + obj: LonLatSphericalDifferential, / +) -> apyc.SphericalDifferential: + """LonLatSphericalDifferential -> `astropy.SphericalDifferential`. Examples -------- >>> from unxt import Quantity >>> import coordinax as cx - >>> import astropy.coordinates as apyc - >>> dif = cx.CylindricalDifferential(d_rho=Quantity(1, unit="km/s"), - ... d_phi=Quantity(2, unit="mas/yr"), - ... d_z=Quantity(3, unit="km/s")) - >>> convert(dif, apyc.CylindricalDifferential) - + >>> dif = cx.LonLatSphericalDifferential(d_distance=Quantity(1, unit="km/s"), + ... d_lat=Quantity(2, unit="mas/yr"), + ... d_lon=Quantity(3, unit="mas/yr")) + >>> convert(dif, apyc.SphericalDifferential) + >>> convert(dif, apyc.BaseDifferential) - + """ - return apyc.CylindricalDifferential( - d_rho=convert(obj.d_rho, apyu.Quantity), - d_phi=convert(obj.d_phi, apyu.Quantity), - d_z=convert(obj.d_z, apyu.Quantity), + return apyc.SphericalDifferential( + d_distance=convert(obj.d_distance, apyu.Quantity), + d_lon=convert(obj.d_lon, apyu.Quantity), + d_lat=convert(obj.d_lat, apyu.Quantity), ) # TODO: use decorator when https://github.com/beartype/plum/pull/135 -add_conversion_method(CylindricalDifferential, apyc.BaseDifferential, diffcyl_to_apycyl) add_conversion_method( - CylindricalDifferential, apyc.CylindricalDifferential, diffcyl_to_apycyl + LonLatSphericalDifferential, apyc.BaseDifferential, difflonlatsph_to_apysph +) +add_conversion_method( + LonLatSphericalDifferential, apyc.SphericalDifferential, difflonlatsph_to_apysph ) @conversion_method( # type: ignore[misc] - apyc.CylindricalDifferential, CylindricalDifferential + apyc.SphericalDifferential, LonLatSphericalDifferential ) -def apycyl_to_diffcyl(obj: apyc.CylindricalDifferential, /) -> CylindricalDifferential: - """`astropy.CylindricalDifferential` -> `coordinax.CylindricalDifferential`. +def apysph_to_difflonlatsph( + obj: apyc.SphericalDifferential, / +) -> LonLatSphericalDifferential: + """`astropy.SphericalDifferential` -> LonLatSphericalDifferential. Examples -------- >>> import astropy.units as u - >>> import astropy.coordinates as apyc >>> import coordinax as cx - - >>> dcyl = apyc.CylindricalDifferential(d_rho=1 * u.km / u.s, d_phi=2 * u.mas/u.yr, - ... d_z=2 * u.km / u.s) - >>> convert(dcyl, cx.CylindricalDifferential) - CylindricalDifferential( - d_rho=Quantity[...]( value=f32[], unit=Unit("km / s") ), - d_phi=Quantity[...]( value=f32[], unit=Unit("mas / yr") ), - d_z=Quantity[...]( value=f32[], unit=Unit("km / s") ) + >>> from astropy.coordinates import SphericalDifferential + + >>> dif = SphericalDifferential(d_distance=1 * u.km / u.s, d_lat=2 * u.mas/u.yr, + ... d_lon=3 * u.mas/u.yr) + >>> convert(dif, cx.LonLatSphericalDifferential) + LonLatSphericalDifferential( + d_distance=Quantity[...]( value=f32[], unit=Unit("km / s") ), + d_lon=Quantity[...]( value=f32[], unit=Unit("mas / yr") ), + d_lat=Quantity[...]( value=f32[], unit=Unit("mas / yr") ) ) """ - return CylindricalDifferential.constructor(obj) + return LonLatSphericalDifferential.constructor(obj) -##################################################################### +# ===================================== +# LonCosLatSphericalDifferential + + +# @conversion_method(LonCosLatSphericalDifferential, apyc.BaseDifferential) +# @conversion_method( +# LonCosLatSphericalDifferential, apyc.SphericalCosLatDifferential +# ) +def diffloncoslatsph_to_apysph( + obj: LonCosLatSphericalDifferential, / +) -> apyc.SphericalCosLatDifferential: + """LonCosLatSphericalDifferential -> `astropy.SphericalCosLatDifferential`. + + Examples + -------- + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> dif = cx.LonCosLatSphericalDifferential(d_distance=Quantity(1, unit="km/s"), + ... d_lat=Quantity(2, unit="mas/yr"), + ... d_lon_coslat=Quantity(3, unit="mas/yr")) + >>> convert(dif, apyc.SphericalCosLatDifferential) + + + >>> convert(dif, apyc.BaseDifferential) + + + """ # noqa: E501 + return apyc.SphericalCosLatDifferential( + d_distance=convert(obj.d_distance, apyu.Quantity), + d_lon_coslat=convert(obj.d_lon_coslat, apyu.Quantity), + d_lat=convert(obj.d_lat, apyu.Quantity), + ) + + +# TODO: use decorator when https://github.com/beartype/plum/pull/135 +add_conversion_method( + LonCosLatSphericalDifferential, apyc.BaseDifferential, diffloncoslatsph_to_apysph +) +add_conversion_method( + LonCosLatSphericalDifferential, + apyc.SphericalCosLatDifferential, + diffloncoslatsph_to_apysph, +) + + +@conversion_method( # type: ignore[misc] + apyc.SphericalCosLatDifferential, LonCosLatSphericalDifferential +) +def apysph_to_diffloncoslatsph( + obj: apyc.SphericalCosLatDifferential, / +) -> LonCosLatSphericalDifferential: + """`astropy.SphericalCosLatDifferential` -> LonCosLatSphericalDifferential. + + Examples + -------- + >>> import astropy.units as u + >>> import coordinax as cx + >>> from astropy.coordinates import SphericalCosLatDifferential + + >>> dif = SphericalCosLatDifferential(d_distance=1 * u.km / u.s, + ... d_lat=2 * u.mas/u.yr, + ... d_lon_coslat=3 * u.mas/u.yr) + >>> convert(dif, cx.LonCosLatSphericalDifferential) + LonCosLatSphericalDifferential( + d_distance=Quantity[...]( value=f32[], unit=Unit("km / s") ), + d_lon_coslat=Quantity[...]( value=f32[], unit=Unit("mas / yr") ), + d_lat=Quantity[...]( value=f32[], unit=Unit("mas / yr") ) + ) + + """ + return LonCosLatSphericalDifferential.constructor(obj) diff --git a/src/coordinax/_d3/sphere.py b/src/coordinax/_d3/sphere.py new file mode 100644 index 00000000..37d1f7d8 --- /dev/null +++ b/src/coordinax/_d3/sphere.py @@ -0,0 +1,347 @@ +"""Built-in vector classes.""" + +__all__ = [ + "AbstractSphericalVector", + "AbstractSphericalDifferential", + # Physics conventions + "SphericalVector", + "SphericalDifferential", + # Mathematics conventions + "MathSphericalVector", + "MathSphericalDifferential", + # Geographic / Astronomical conventions + "LonLatSphericalVector", + "LonLatSphericalDifferential", + "LonCosLatSphericalDifferential", +] + +from abc import abstractmethod +from functools import partial +from typing import final + +import equinox as eqx +import jax + +from unxt import Distance, Quantity + +import coordinax._typing as ct +from .base import Abstract3DVector, Abstract3DVectorDifferential +from coordinax._checks import check_phi_range, check_r_non_negative, check_theta_range +from coordinax._converters import converter_phi_to_range +from coordinax._utils import classproperty + +############################################################################## +# Position + + +class AbstractSphericalVector(Abstract3DVector): + """Abstract spherical vector representation.""" + + @classproperty + @classmethod + @abstractmethod + def differential_cls(cls) -> type["AbstractSphericalDifferential"]: ... + + +@final +class SphericalVector(AbstractSphericalVector): + """Spherical vector representation. + + .. note:: + + This class follows the Physics conventions (ISO 80000-2:2019). + + Parameters + ---------- + r : Distance + Radial distance r (slant distance to origin), + phi : Quantity['angle'] + Azimuthal angle [0, 360) [deg] where 0 is the x-axis. + theta : Quantity['angle'] + Polar angle [0, 180] [deg] where 0 is the z-axis. + + """ + + r: ct.BatchableDistance = eqx.field( + converter=partial(Distance.constructor, dtype=float) + ) + r"""Radial distance :math:`r \in [0,+\infty)`.""" + + phi: ct.BatchableAngle = eqx.field( + converter=lambda x: converter_phi_to_range( + Quantity["angle"].constructor(x, dtype=float) # pylint: disable=E1120 + ) + ) + r"""Azimuthal angle :math:`\phi \in [0,360)`.""" + + theta: ct.BatchableAngle = eqx.field( + converter=partial(Quantity["angle"].constructor, dtype=float) + ) + r"""Inclination angle :math:`\phi \in [0,180]`.""" + + def __check_init__(self) -> None: + """Check the validity of the initialization.""" + check_r_non_negative(self.r) + check_theta_range(self.theta) + check_phi_range(self.phi) + + @classproperty + @classmethod + def differential_cls(cls) -> type["SphericalDifferential"]: + return SphericalDifferential + + @partial(jax.jit) + def norm(self) -> ct.BatchableDistance: + """Return the norm of the vector. + + Examples + -------- + >>> from unxt import Quantity + >>> from coordinax import SphericalVector + >>> s = SphericalVector(r=Quantity(3, "kpc"), phi=Quantity(0, "deg"), + ... theta=Quantity(90, "deg")) + >>> s.norm() + Distance(Array(3., dtype=float32), unit='kpc') + + """ + return self.r + + +@final +class MathSphericalVector(AbstractSphericalVector): + """Spherical vector representation. + + .. note:: + + This class follows the Mathematics conventions. + + Parameters + ---------- + r : Distance + Radial distance r (slant distance to origin), + theta : Quantity['angle'] + Azimuthal angle [0, 360) [deg] where 0 is the x-axis. + phi : Quantity['angle'] + Polar angle [0, 180] [deg] where 0 is the z-axis. + + """ + + r: ct.BatchableDistance = eqx.field( + converter=partial(Distance.constructor, dtype=float) + ) + r"""Radial distance :math:`r \in [0,+\infty)`.""" + + theta: ct.BatchableAngle = eqx.field( + converter=lambda x: converter_phi_to_range( + Quantity["angle"].constructor(x, dtype=float) # pylint: disable=E1120 + ) + ) + r"""Azimuthal angle :math:`\phi \in [0,360)`.""" + + phi: ct.BatchableAngle = eqx.field( + converter=partial(Quantity["angle"].constructor, dtype=float) + ) + r"""Inclination angle :math:`\phi \in [0,180]`.""" + + def __check_init__(self) -> None: + """Check the validity of the initialization.""" + check_r_non_negative(self.r) + check_theta_range(self.phi) + check_phi_range(self.theta) + + @classproperty + @classmethod + def differential_cls(cls) -> type["MathSphericalDifferential"]: + return MathSphericalDifferential + + @partial(jax.jit) + def norm(self) -> ct.BatchableDistance: + """Return the norm of the vector. + + Examples + -------- + >>> from unxt import Quantity + >>> from coordinax import MathSphericalVector + >>> s = MathSphericalVector(r=Quantity(3, "kpc"), theta=Quantity(90, "deg"), + ... phi=Quantity(0, "deg")) + >>> s.norm() + Distance(Array(3., dtype=float32), unit='kpc') + + """ + return self.r + + +@final +class LonLatSphericalVector(AbstractSphericalVector): + """Spherical vector representation. + + .. note:: + + This class follows the Geographic / Astronomical convention. + + Parameters + ---------- + distance : Distance + Radial distance r (slant distance to origin), + lon : Quantity['angle'] + The longitude (azimuthal) angle [0, 360) [deg] where 0 is the x-axis. + lat : Quantity['angle'] + The latitude (polar angle) [-90, 90] [deg] where 90 is the z-axis. + + """ + + distance: ct.BatchableDistance = eqx.field( + converter=partial(Distance.constructor, dtype=float) + ) + r"""Radial distance :math:`r \in [0,+\infty)`.""" + + lon: ct.BatchableAngle = eqx.field( + converter=lambda x: converter_phi_to_range( + Quantity["angle"].constructor(x, dtype=float) # pylint: disable=E1120 + ) + ) + r"""Longitude angle :math:`\phi \in [0,360)`.""" + + lat: ct.BatchableAngle = eqx.field( + converter=lambda x: Quantity["angle"].constructor(x, dtype=float) # pylint: disable=E1120 + ) + r"""Latitude angle :math:`\phi \in [-90,90]`.""" + + def __check_init__(self) -> None: + """Check the validity of the initialization.""" + check_r_non_negative(self.distance) + check_phi_range(self.lon) + check_theta_range(self.lat) + + @classproperty + @classmethod + def differential_cls(cls) -> type["LonLatSphericalDifferential"]: + return LonLatSphericalDifferential + + @partial(jax.jit) + def norm(self) -> ct.BatchableDistance: + """Return the norm of the vector. + + Examples + -------- + >>> from unxt import Quantity + >>> from coordinax import LonLatSphericalVector + >>> s = LonLatSphericalVector(lon=Quantity(0, "deg"), lat=Quantity(90, "deg"), + ... distance=Quantity(3, "kpc")) + >>> s.norm() + Distance(Array(3., dtype=float32), unit='kpc') + + """ + return self.distance + + +############################################################################## + + +class AbstractSphericalDifferential(Abstract3DVectorDifferential): + """Spherical differential representation.""" + + @classproperty + @classmethod + @abstractmethod + def integral_cls(cls) -> type[SphericalVector]: ... + + +@final +class SphericalDifferential(Abstract3DVectorDifferential): + """Spherical differential representation.""" + + d_r: ct.BatchableSpeed = eqx.field( + converter=partial(Quantity["speed"].constructor, dtype=float) + ) + r"""Radial speed :math:`dr/dt \in [-\infty, \infty].""" + + d_phi: ct.BatchableAngularSpeed = eqx.field( + converter=partial(Quantity["angular speed"].constructor, dtype=float) + ) + r"""Azimuthal speed :math:`d\phi/dt \in [-\infty, \infty].""" + + d_theta: ct.BatchableAngularSpeed = eqx.field( + converter=partial(Quantity["angular speed"].constructor, dtype=float) + ) + r"""Inclination speed :math:`d\theta/dt \in [-\infty, \infty].""" + + @classproperty + @classmethod + def integral_cls(cls) -> type[SphericalVector]: + return SphericalVector + + +@final +class MathSphericalDifferential(Abstract3DVectorDifferential): + """Spherical differential representation.""" + + d_r: ct.BatchableSpeed = eqx.field( + converter=partial(Quantity["speed"].constructor, dtype=float) + ) + r"""Radial speed :math:`dr/dt \in [-\infty, \infty].""" + + d_theta: ct.BatchableAngularSpeed = eqx.field( + converter=partial(Quantity["angular speed"].constructor, dtype=float) + ) + r"""Azimuthal speed :math:`d\theta/dt \in [-\infty, \infty].""" + + d_phi: ct.BatchableAngularSpeed = eqx.field( + converter=partial(Quantity["angular speed"].constructor, dtype=float) + ) + r"""Inclination speed :math:`d\phi/dt \in [-\infty, \infty].""" + + @classproperty + @classmethod + def integral_cls(cls) -> type[MathSphericalVector]: + return MathSphericalVector + + +@final +class LonLatSphericalDifferential(Abstract3DVectorDifferential): + """Spherical differential representation.""" + + d_distance: ct.BatchableSpeed = eqx.field( + converter=partial(Quantity["speed"].constructor, dtype=float) + ) + r"""Radial speed :math:`dr/dt \in [-\infty, \infty].""" + + d_lon: ct.BatchableAngularSpeed = eqx.field( + converter=partial(Quantity["angular speed"].constructor, dtype=float) + ) + r"""Longitude speed :math:`d\theta/dt \in [-\infty, \infty].""" + + d_lat: ct.BatchableAngularSpeed = eqx.field( + converter=partial(Quantity["angular speed"].constructor, dtype=float) + ) + r"""Latitude speed :math:`d\phi/dt \in [-\infty, \infty].""" + + @classproperty + @classmethod + def integral_cls(cls) -> type[LonLatSphericalVector]: + return LonLatSphericalVector + + +@final +class LonCosLatSphericalDifferential(Abstract3DVectorDifferential): + """Spherical differential representation.""" + + d_distance: ct.BatchableSpeed = eqx.field( + converter=partial(Quantity["speed"].constructor, dtype=float) + ) + r"""Radial speed :math:`dr/dt \in [-\infty, \infty].""" + + d_lon_coslat: ct.BatchableAngularSpeed = eqx.field( + converter=partial(Quantity["angular speed"].constructor, dtype=float) + ) + r"""Longitude * cos(Latitude) speed :math:`d\theta/dt \in [-\infty, \infty].""" + + d_lat: ct.BatchableAngularSpeed = eqx.field( + converter=partial(Quantity["angular speed"].constructor, dtype=float) + ) + r"""Latitude speed :math:`d\phi/dt \in [-\infty, \infty].""" + + @classproperty + @classmethod + def integral_cls(cls) -> type[LonLatSphericalVector]: + return LonLatSphericalVector diff --git a/src/coordinax/_d3/transform.py b/src/coordinax/_d3/transform.py index 86f17674..5e8255bb 100644 --- a/src/coordinax/_d3/transform.py +++ b/src/coordinax/_d3/transform.py @@ -7,6 +7,7 @@ from plum import dispatch import quaxed.array_api as xp +from unxt import Quantity from .base import Abstract3DVector, Abstract3DVectorDifferential from .builtin import ( @@ -14,6 +15,14 @@ CartesianDifferential3D, CylindricalDifferential, CylindricalVector, +) +from .sphere import ( + AbstractSphericalVector, + LonCosLatSphericalDifferential, + LonLatSphericalDifferential, + LonLatSphericalVector, + MathSphericalDifferential, + MathSphericalVector, SphericalDifferential, SphericalVector, ) @@ -33,8 +42,10 @@ def represent_as( @dispatch.multi( (Cartesian3DVector, type[Cartesian3DVector]), - (SphericalVector, type[SphericalVector]), (CylindricalVector, type[CylindricalVector]), + (SphericalVector, type[SphericalVector]), + (LonLatSphericalVector, type[LonLatSphericalVector]), + (MathSphericalVector, type[MathSphericalVector]), ) def represent_as( current: Abstract3DVector, target: type[Abstract3DVector], /, **kwargs: Any @@ -52,6 +63,13 @@ def represent_as( >>> cx.represent_as(vec, cx.Cartesian3DVector) is vec True + Cylindrical to Cylindrical: + + >>> vec = cx.CylindricalVector(rho=Quantity(1, "kpc"), phi=Quantity(2, "deg"), + ... z=Quantity(3, "kpc")) + >>> cx.represent_as(vec, cx.CylindricalVector) is vec + True + Spherical to Spherical: >>> vec = cx.SphericalVector(r=Quantity(1, "kpc"), theta=Quantity(2, "deg"), @@ -59,11 +77,18 @@ def represent_as( >>> cx.represent_as(vec, cx.SphericalVector) is vec True - Cylindrical to Cylindrical: + LonLatSpherical to LonLatSpherical: - >>> vec = cx.CylindricalVector(rho=Quantity(1, "kpc"), phi=Quantity(2, "deg"), - ... z=Quantity(3, "kpc")) - >>> cx.represent_as(vec, cx.CylindricalVector) is vec + >>> vec = cx.LonLatSphericalVector(lon=Quantity(1, "deg"), lat=Quantity(2, "deg"), + ... distance=Quantity(3, "kpc")) + >>> cx.represent_as(vec, cx.LonLatSphericalVector) is vec + True + + MathSpherical to MathSpherical: + + >>> vec = cx.MathSphericalVector(r=Quantity(1, "kpc"), phi=Quantity(2, "deg"), + ... theta=Quantity(3, "deg")) + >>> cx.represent_as(vec, cx.MathSphericalVector) is vec True """ @@ -72,8 +97,15 @@ def represent_as( @dispatch.multi( (CartesianDifferential3D, type[CartesianDifferential3D], AbstractVector), - (SphericalDifferential, type[SphericalDifferential], AbstractVector), (CylindricalDifferential, type[CylindricalDifferential], AbstractVector), + (SphericalDifferential, type[SphericalDifferential], AbstractVector), + (LonLatSphericalDifferential, type[LonLatSphericalDifferential], AbstractVector), + ( + LonCosLatSphericalDifferential, + type[LonCosLatSphericalDifferential], + AbstractVector, + ), + (MathSphericalDifferential, type[MathSphericalDifferential], AbstractVector), ) def represent_as( current: Abstract3DVectorDifferential, @@ -94,17 +126,13 @@ def represent_as( >>> vec = cx.Cartesian3DVector.constructor(Quantity([1, 2, 3], "kpc")) - Cartesian to Cartesian Differential: + Cartesian to Cartesian differential: >>> dif = cx.CartesianDifferential3D.constructor(Quantity([1, 2, 3], "km/s")) >>> cx.represent_as(dif, cx.CartesianDifferential3D, vec) is dif True - >>> dif = cx.SphericalDifferential(d_r=Quantity(1, "km/s"), - ... d_theta=Quantity(2, "mas/yr"), - ... d_phi=Quantity(3, "mas/yr")) - >>> cx.represent_as(dif, cx.SphericalDifferential, vec) is dif - True + Cylindrical to Cylindrical differential: >>> dif = cx.CylindricalDifferential(d_rho=Quantity(1, "km/s"), ... d_phi=Quantity(2, "mas/yr"), @@ -112,6 +140,38 @@ def represent_as( >>> cx.represent_as(dif, cx.CylindricalDifferential, vec) is dif True + Spherical to Spherical differential: + + >>> dif = cx.SphericalDifferential(d_r=Quantity(1, "km/s"), + ... d_phi=Quantity(2, "mas/yr"), + ... d_theta=Quantity(3, "mas/yr")) + >>> cx.represent_as(dif, cx.SphericalDifferential, vec) is dif + True + + LonLatSpherical to LonLatSpherical differential: + + >>> dif = cx.LonLatSphericalDifferential(d_lon=Quantity(1, "mas/yr"), + ... d_lat=Quantity(2, "mas/yr"), + ... d_distance=Quantity(3, "km/s")) + >>> cx.represent_as(dif, cx.LonLatSphericalDifferential, vec) is dif + True + + LonCosLatSpherical to LonCosLatSpherical differential: + + >>> dif = cx.LonCosLatSphericalDifferential(d_lon_coslat=Quantity(1, "mas/yr"), + ... d_lat=Quantity(2, "mas/yr"), + ... d_distance=Quantity(3, "km/s")) + >>> cx.represent_as(dif, cx.LonCosLatSphericalDifferential, vec) is dif + True + + MathSpherical to MathSpherical differential: + + >>> dif = cx.MathSphericalDifferential(d_r=Quantity(1, "km/s"), + ... d_theta=Quantity(2, "mas/yr"), + ... d_phi=Quantity(3, "mas/yr")) + >>> cx.represent_as(dif, cx.MathSphericalDifferential, vec) is dif + True + """ return current @@ -120,6 +180,28 @@ def represent_as( # Cartesian3DVector +@dispatch +def represent_as( + current: Cartesian3DVector, target: type[CylindricalVector], /, **kwargs: Any +) -> CylindricalVector: + """Cartesian3DVector -> CylindricalVector. + + Examples + -------- + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> vec = cx.Cartesian3DVector.constructor(Quantity([1, 2, 3], "km")) + >>> print(cx.represent_as(vec, cx.CylindricalVector)) + + + """ + rho = xp.sqrt(current.x**2 + current.y**2) + phi = xp.atan2(current.y, current.x) + return target(rho=rho, phi=phi, z=current.z) + + @dispatch def represent_as( current: Cartesian3DVector, target: type[SphericalVector], /, **kwargs: Any @@ -138,16 +220,19 @@ def represent_as( """ r = xp.sqrt(current.x**2 + current.y**2 + current.z**2) - theta = xp.acos(current.z / r) phi = xp.atan2(current.y, current.x) - return target(r=r, theta=theta, phi=phi) + theta = xp.acos(current.z / r) + return target(r=r, phi=phi, theta=theta) -@dispatch +@dispatch.multi( + (Cartesian3DVector, type[LonLatSphericalVector]), + (Cartesian3DVector, type[MathSphericalVector]), +) def represent_as( - current: Cartesian3DVector, target: type[CylindricalVector], /, **kwargs: Any -) -> CylindricalVector: - """Cartesian3DVector -> CylindricalVector. + current: Cartesian3DVector, target: type[AbstractSphericalVector], /, **kwargs: Any +) -> AbstractSphericalVector: + """Cartesian3DVector -> AbstractSphericalVector. Examples -------- @@ -155,14 +240,98 @@ def represent_as( >>> import coordinax as cx >>> vec = cx.Cartesian3DVector.constructor(Quantity([1, 2, 3], "km")) - >>> print(cx.represent_as(vec, cx.CylindricalVector)) - + + >>> print(cx.represent_as(vec, cx.LonLatSphericalVector)) + + + >>> print(cx.represent_as(vec, cx.MathSphericalVector)) + """ - rho = xp.sqrt(current.x**2 + current.y**2) - phi = xp.atan2(current.y, current.x) - return target(rho=rho, phi=phi, z=current.z) + return represent_as(represent_as(current, SphericalVector), target) + + +# ============================================================================= +# CylindricalVector + + +@dispatch +def represent_as( + current: CylindricalVector, target: type[Cartesian3DVector], /, **kwargs: Any +) -> Cartesian3DVector: + """CylindricalVector -> Cartesian3DVector. + + Examples + -------- + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> vec = cx.CylindricalVector(rho=Quantity(1., "kpc"), phi=Quantity(90, "deg"), + ... z=Quantity(1, "kpc")) + >>> print(cx.represent_as(vec, cx.Cartesian3DVector)) + + + """ + x = current.rho * xp.cos(current.phi) + y = current.rho * xp.sin(current.phi) + z = current.z + return target(x=x, y=y, z=z) + + +@dispatch +def represent_as( + current: CylindricalVector, target: type[SphericalVector], /, **kwargs: Any +) -> SphericalVector: + """CylindricalVector -> SphericalVector. + + Examples + -------- + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> vec = cx.CylindricalVector(rho=Quantity(1., "kpc"), phi=Quantity(90, "deg"), + ... z=Quantity(1, "kpc")) + >>> print(cx.represent_as(vec, cx.SphericalVector)) + + + """ + r = xp.sqrt(current.rho**2 + current.z**2) + theta = xp.acos(current.z / r) + phi = current.phi + return target(r=r, phi=phi, theta=theta) + + +@dispatch.multi( + (CylindricalVector, type[LonLatSphericalVector]), + (CylindricalVector, type[MathSphericalVector]), +) +def represent_as( + current: CylindricalVector, target: type[AbstractSphericalVector], /, **kwargs: Any +) -> AbstractSphericalVector: + """CylindricalVector -> AbstractSphericalVector. + + Examples + -------- + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> vec = cx.CylindricalVector(rho=Quantity(1., "kpc"), phi=Quantity(90, "deg"), + ... z=Quantity(1, "kpc")) + + >>> print(cx.represent_as(vec, cx.LonLatSphericalVector)) + + + >>> print(cx.represent_as(vec, cx.MathSphericalVector)) + + + """ + return represent_as(represent_as(current, SphericalVector), target) # ============================================================================= @@ -217,27 +386,275 @@ def represent_as( return target(rho=rho, phi=phi, z=z) +@dispatch +def represent_as( + current: SphericalVector, target: type[LonLatSphericalVector], /, **kwargs: Any +) -> LonLatSphericalVector: + """SphericalVector -> LonLatSphericalVector. + + Examples + -------- + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> vec = cx.SphericalVector(r=Quantity(1., "kpc"), phi=Quantity(90, "deg"), + ... theta=Quantity(90, "deg")) + >>> print(cx.represent_as(vec, cx.LonLatSphericalVector)) + + + """ + return target( + distance=current.r, + lon=current.phi, + lat=Quantity(90, "deg") - current.theta, + ) + + +@dispatch +def represent_as( + current: SphericalVector, target: type[MathSphericalVector], /, **kwargs: Any +) -> MathSphericalVector: + """SphericalVector -> MathSphericalVector. + + Examples + -------- + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> vec = cx.SphericalVector(r=Quantity(1., "kpc"), phi=Quantity(90, "deg"), + ... theta=Quantity(90, "deg")) + >>> print(cx.represent_as(vec, cx.MathSphericalVector)) + + + """ + return target(r=current.r, theta=current.phi, phi=current.theta) + + # ============================================================================= -# CylindricalVector +# LonLatSphericalVector @dispatch def represent_as( - current: CylindricalVector, target: type[Cartesian3DVector], /, **kwargs: Any + current: LonLatSphericalVector, target: type[Cartesian3DVector], /, **kwargs: Any ) -> Cartesian3DVector: - """CylindricalVector -> Cartesian3DVector.""" - x = current.rho * xp.cos(current.phi) - y = current.rho * xp.sin(current.phi) - z = current.z + """LonLatSphericalVector -> Cartesian3DVector. + + Examples + -------- + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> vec = cx.LonLatSphericalVector(lon=Quantity(90, "deg"), lat=Quantity(0, "deg"), + ... distance=Quantity(1., "kpc")) + >>> print(cx.represent_as(vec, cx.Cartesian3DVector)) + + + """ + return represent_as(represent_as(current, SphericalVector), Cartesian3DVector) + + +@dispatch +def represent_as( + current: LonLatSphericalVector, target: type[CylindricalVector], /, **kwargs: Any +) -> CylindricalVector: + """LonLatSphericalVector -> CylindricalVector. + + Examples + -------- + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> vec = cx.LonLatSphericalVector(lon=Quantity(90, "deg"), lat=Quantity(0, "deg"), + ... distance=Quantity(1., "kpc")) + >>> print(cx.represent_as(vec, cx.CylindricalVector)) + + + """ + return represent_as(represent_as(current, SphericalVector), target) + + +@dispatch +def represent_as( + current: LonLatSphericalVector, target: type[SphericalVector], /, **kwargs: Any +) -> SphericalVector: + """LonLatSphericalVector -> SphericalVector. + + Examples + -------- + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> vec = cx.LonLatSphericalVector(lon=Quantity(90, "deg"), lat=Quantity(0, "deg"), + ... distance=Quantity(1., "kpc")) + >>> print(cx.represent_as(vec, cx.SphericalVector)) + + + """ + return target( + r=current.distance, phi=current.lon, theta=Quantity(90, "deg") - current.lat + ) + + +# ============================================================================= +# MathSphericalVector + + +@dispatch +def represent_as( + current: MathSphericalVector, target: type[Cartesian3DVector], /, **kwargs: Any +) -> Cartesian3DVector: + """MathSphericalVector -> Cartesian3DVector. + + Examples + -------- + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> vec = cx.MathSphericalVector(r=Quantity(1., "kpc"), theta=Quantity(90, "deg"), + ... phi=Quantity(90, "deg")) + >>> print(cx.represent_as(vec, cx.Cartesian3DVector)) + + + """ + x = current.r * xp.sin(current.phi) * xp.cos(current.theta) + y = current.r * xp.sin(current.phi) * xp.sin(current.theta) + z = current.r * xp.cos(current.phi) return target(x=x, y=y, z=z) @dispatch def represent_as( - current: CylindricalVector, target: type[SphericalVector], /, **kwargs: Any + current: MathSphericalVector, target: type[CylindricalVector], /, **kwargs: Any +) -> CylindricalVector: + """MathSphericalVector -> CylindricalVector. + + Examples + -------- + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> vec = cx.MathSphericalVector(r=Quantity(1., "kpc"), theta=Quantity(90, "deg"), + ... phi=Quantity(90, "deg")) + >>> print(cx.represent_as(vec, cx.CylindricalVector)) + + + """ + rho = xp.abs(current.r * xp.sin(current.phi)) + phi = current.theta + z = current.r * xp.cos(current.phi) + return target(rho=rho, phi=phi, z=z) + + +@dispatch +def represent_as( + current: MathSphericalVector, target: type[SphericalVector], /, **kwargs: Any ) -> SphericalVector: - """CylindricalVector -> SphericalVector.""" - r = xp.sqrt(current.rho**2 + current.z**2) - theta = xp.acos(current.z / r) - phi = current.phi - return target(r=r, theta=theta, phi=phi) + """MathSphericalVector -> SphericalVector. + + Examples + -------- + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> vec = cx.MathSphericalVector(r=Quantity(1., "kpc"), theta=Quantity(90, "deg"), + ... phi=Quantity(90, "deg")) + >>> print(cx.represent_as(vec, cx.SphericalVector)) + + + """ + return target(r=current.r, theta=current.phi, phi=current.theta) + + +# ============================================================================= +# LonLatSphericalDifferential + + +@dispatch +def represent_as( + current: Abstract3DVectorDifferential, + target: type[LonCosLatSphericalDifferential], + position: AbstractVector | Quantity["length"], + /, + **kwargs: Any, +) -> LonCosLatSphericalDifferential: + """Abstract3DVectorDifferential -> LonCosLatSphericalDifferential. + + Examples + -------- + >>> import quaxed.array_api as xp + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> vec = cx.LonLatSphericalVector(lon=Quantity(15, "deg"), lat=Quantity(10, "deg"), + ... distance=Quantity(1.5, "kpc")) + >>> dif = cx.LonLatSphericalDifferential(d_lon=Quantity(7, "mas/yr"), + ... d_lat=Quantity(0, "deg/Gyr"), + ... d_distance=Quantity(-5, "km/s")) + >>> newdif = cx.represent_as(dif, cx.LonCosLatSphericalDifferential, vec) + >>> newdif + LonCosLatSphericalDifferential( + d_distance=Quantity[...]( value=f32[], unit=Unit("km / s") ), + d_lon_coslat=Quantity[...]( value=f32[], unit=Unit("mas / yr") ), + d_lat=Quantity[...]( value=f32[], unit=Unit("deg / Gyr") ) ) + + >>> newdif.d_lon_coslat / xp.cos(vec.lat) # float32 imprecision + Quantity['angular frequency'](Array(6.9999995, dtype=float32), unit='mas / yr') + + """ + # Parse the position to an AbstractVector + if isinstance(position, AbstractVector): + posvec = position + else: # Q -> CartD + posvec = current.integral_cls._cartesian_cls.constructor( # noqa: SLF001 + position + ) + + # Transform the differential to LonLatSphericalDifferential + current = represent_as(current, LonLatSphericalDifferential, posvec) + + # Transform the position to the required type + posvec = represent_as(posvec, current.integral_cls) + + # Calculate the differential in the new system + return target( + d_lon_coslat=current.d_lon * xp.cos(posvec.lat), + d_lat=current.d_lat, + d_distance=current.d_distance, + ) + + +@dispatch +def represent_as( + current: LonCosLatSphericalDifferential, + target: type[LonLatSphericalDifferential], + position: AbstractVector | Quantity["length"], + /, + **kwargs: Any, +) -> LonLatSphericalDifferential: + """LonCosLatSphericalDifferential -> LonLatSphericalDifferential.""" + # Parse the position to an AbstractVector + if isinstance(position, AbstractVector): + posvec = position + else: # Q -> CartD + posvec = current.integral_cls._cartesian_cls.constructor( # noqa: SLF001 + position + ) + + # Transform the position to the required type + posvec = represent_as(posvec, current.integral_cls) + + # Calculate the differential in the new system + return target( + d_lon=current.d_lon_coslat / xp.cos(posvec.lat), + d_lat=current.d_lat, + d_distance=current.d_distance, + ) diff --git a/src/coordinax/_transform.py b/src/coordinax/_transform.py deleted file mode 100644 index 9d8fcbd4..00000000 --- a/src/coordinax/_transform.py +++ /dev/null @@ -1,808 +0,0 @@ -"""Transformations between representations.""" - -__all__ = ["represent_as"] - -from math import prod -from typing import Any -from warnings import warn - -import astropy.units as u -import jax -from plum import dispatch - -import quaxed.array_api as xp -from unxt import Quantity - -from ._base_dif import AbstractVectorDifferential -from ._base_vec import AbstractVector -from ._d1.base import Abstract1DVectorDifferential -from ._d1.builtin import Cartesian1DVector, RadialVector -from ._d2.base import Abstract2DVector, Abstract2DVectorDifferential -from ._d2.builtin import Cartesian2DVector, PolarVector -from ._d3.base import Abstract3DVector, Abstract3DVectorDifferential -from ._d3.builtin import Cartesian3DVector, CylindricalVector, SphericalVector -from ._exceptions import IrreversibleDimensionChange -from ._utils import dataclass_items - - -# TODO: implement for cross-representations -@dispatch.multi( # type: ignore[misc] - # N-D -> N-D - ( - Abstract1DVectorDifferential, - type[Abstract1DVectorDifferential], # type: ignore[misc] - AbstractVector | Quantity["length"], - ), - ( - Abstract2DVectorDifferential, - type[Abstract2DVectorDifferential], # type: ignore[misc] - AbstractVector | Quantity["length"], - ), - ( - Abstract3DVectorDifferential, - type[Abstract3DVectorDifferential], # type: ignore[misc] - AbstractVector | Quantity["length"], - ), -) -def represent_as( - current: AbstractVectorDifferential, - target: type[AbstractVectorDifferential], - position: AbstractVector | Quantity["length"], - /, - **kwargs: Any, -) -> AbstractVectorDifferential: - """AbstractVectorDifferential -> Cartesian -> AbstractVectorDifferential. - - This is the base case for the transformation of vector differentials. - - Parameters - ---------- - current : AbstractVectorDifferential - The vector differential to transform. - target : type[AbstractVectorDifferential] - The target type of the vector differential. - position : AbstractVector - The position vector used to transform the differential. - **kwargs : Any - Additional keyword arguments. - - Examples - -------- - >>> import coordinax as cx - >>> from unxt import Quantity - - Let's start in 1D: - - >>> q = cx.Cartesian1DVector(x=Quantity(1.0, "km")) - >>> p = cx.CartesianDifferential1D(d_x=Quantity(1.0, "km/s")) - >>> cx.represent_as(p, cx.RadialDifferential, q) - RadialDifferential( d_r=Quantity[...]( value=f32[], unit=Unit("km / s") ) ) - - Now in 2D: - - >>> q = cx.Cartesian2DVector.constructor(Quantity([1.0, 2.0], "km")) - >>> p = cx.CartesianDifferential2D.constructor(Quantity([1.0, 2.0], "km/s")) - >>> cx.represent_as(p, cx.PolarDifferential, q) - PolarDifferential( - d_r=Quantity[...]( value=f32[], unit=Unit("km / s") ), - d_phi=Quantity[...]( value=f32[], unit=Unit("rad / s") ) - ) - - And in 3D: - - >>> q = cx.Cartesian3DVector.constructor(Quantity([1.0, 2.0, 3.0], "km")) - >>> p = cx.CartesianDifferential3D.constructor(Quantity([1.0, 2.0, 3.0], "km/s")) - >>> cx.represent_as(p, cx.SphericalDifferential, q) - SphericalDifferential( - d_r=Quantity[...]( value=f32[], unit=Unit("km / s") ), - d_theta=Quantity[...]( value=f32[], unit=Unit("rad / s") ), - d_phi=Quantity[...]( value=f32[], unit=Unit("rad / s") ) - ) - - If given a position as a Quantity, it will be converted to the appropriate - Cartesian vector: - - >>> p = cx.CartesianDifferential3D.constructor(Quantity([1.0, 2.0, 3.0], "km/s")) - >>> cx.represent_as(p, cx.SphericalDifferential, Quantity([1.0, 2.0, 3.0], "km")) - SphericalDifferential( - d_r=Quantity[...]( value=f32[], unit=Unit("km / s") ), - d_theta=Quantity[...]( value=f32[], unit=Unit("rad / s") ), - d_phi=Quantity[...]( value=f32[], unit=Unit("rad / s") ) - ) - - """ - # TODO: not require the shape munging / support more shapes - shape = current.shape - flat_shape = prod(shape) - - # Parse the position to an AbstractVector - if isinstance(position, AbstractVector): - posvec = position - else: # Q -> CartD - posvec = current.integral_cls._cartesian_cls.constructor( # noqa: SLF001 - position - ) - - posvec = posvec.reshape(flat_shape) # flattened - - # Start by transforming the position to the type required by the - # differential to construct the Jacobian. - current_pos = represent_as(posvec, current.integral_cls, **kwargs) - - # Takes the Jacobian through the representation transformation function. This - # returns a representation of the target type, where the value of each field the - # corresponding row of the Jacobian. The value of the field is a Quantity with - # the correct numerator unit (of the Jacobian row). The value is a Vector of the - # original type, with fields that are the columns of that row, but with only the - # denomicator's units. - jac_nested_vecs = jac_rep_as(current_pos, target.integral_cls) - - # This changes the Jacobian to be a dictionary of each row, with the value - # being that row's column as a dictionary, now with the correct units for - # each element: {row_i: {col_j: Quantity(value, row.unit / column.unit)}} - jac_rows = { - f"d_{k}": { - kk: Quantity(vv.value, unit=v.unit / vv.unit) - for kk, vv in dataclass_items(v.value) - } - for k, v in dataclass_items(jac_nested_vecs) - } - - # Now we can use the Jacobian to transform the differential. - flat_current = current.reshape(flat_shape) - return target( - **{ # Each field is the dot product of the row of the J and the diff column. - k: xp.sum( # Doing the dot product. - xp.stack( - tuple( - j_c * getattr(flat_current, f"d_{kk}") - for kk, j_c in j_r.items() - ) - ), - axis=0, - ) - for k, j_r in jac_rows.items() - } - ).reshape(shape) - - -# TODO: situate this better to show how represent_as is used -jac_rep_as = jax.jit( - jax.vmap(jax.jacfwd(represent_as), in_axes=(0, None)), static_argnums=(1,) -) - - -############################################################################### -# 1D - - -# @dispatch.multi( -# (RadialVector, type[LnPolarVector]), -# (RadialVector, type[Log10PolarVector]), -# ) -# def represent_as( -# current: Abstract1DVector, -# target: type[Abstract2DVector], -# /, -# phi: Quantity = Quantity(0.0, u.radian), -# **kwargs: Any, -# ) -> Abstract2DVector: -# """Abstract1DVector -> PolarVector -> Abstract2DVector.""" -# polar = represent_as(current, PolarVector, phi=phi) -# return represent_as(polar, target) - - -# ============================================================================= -# Cartesian1DVector - - -# ----------------------------------------------- -# 2D - - -@dispatch -def represent_as( - current: Cartesian1DVector, - target: type[Cartesian2DVector], - /, - *, - y: Quantity = Quantity(0.0, u.m), - **kwargs: Any, -) -> Cartesian2DVector: - """Cartesian1DVector -> Cartesian2DVector. - - The `x` coordinate is converted to the `x` coordinate of the 2D system. - The `y` coordinate is a keyword argument and defaults to 0. - """ - return target(x=current.x, y=y) - - -@dispatch -def represent_as( - current: Cartesian1DVector, - target: type[PolarVector], - /, - *, - phi: Quantity = Quantity(0.0, u.radian), - **kwargs: Any, -) -> PolarVector: - """Cartesian1DVector -> PolarVector. - - The `x` coordinate is converted to the radial coordinate `r`. - The `phi` coordinate is a keyword argument and defaults to 0. - """ - return target(r=current.x, phi=phi) - - -# @dispatch -# def represent_as( -# current: Cartesian1DVector, -# target: type[LnPolarVector], -# /, -# *, -# phi: Quantity = Quantity(0.0, u.radian), -# **kwargs: Any, -# ) -> LnPolarVector: -# """Cartesian1DVector -> LnPolarVector. - -# The `x` coordinate is converted to the radial coordinate `lnr`. -# The `phi` coordinate is a keyword argument and defaults to 0. -# """ -# return target(lnr=xp.log(current.x), phi=phi) - - -# @dispatch -# def represent_as( -# current: Cartesian1DVector, -# target: type[Log10PolarVector], -# /, -# *, -# phi: Quantity = Quantity(0.0, u.radian), -# **kwargs: Any, -# ) -> Log10PolarVector: -# """Cartesian1DVector -> Log10PolarVector. - -# The `x` coordinate is converted to the radial coordinate `log10r`. -# The `phi` coordinate is a keyword argument and defaults to 0. -# """ -# return target(log10r=xp.log10(current.x), phi=phi) - - -# ----------------------------------------------- -# 3D - - -@dispatch -def represent_as( - current: Cartesian1DVector, - target: type[Cartesian3DVector], - /, - *, - y: Quantity = Quantity(0.0, u.m), - z: Quantity = Quantity(0.0, u.m), - **kwargs: Any, -) -> Cartesian3DVector: - """Cartesian1DVector -> Cartesian3DVector. - - The `x` coordinate is converted to the `x` coordinate of the 3D system. - The `y` and `z` coordinates are keyword arguments and default to 0. - """ - return target(x=current.x, y=y, z=z) - - -@dispatch -def represent_as( - current: Cartesian1DVector, - target: type[SphericalVector], - /, - *, - theta: Quantity = Quantity(0.0, u.radian), - phi: Quantity = Quantity(0.0, u.radian), - **kwargs: Any, -) -> SphericalVector: - """Cartesian1DVector -> SphericalVector. - - The `x` coordinate is converted to the radial coordinate `r`. - The `theta` and `phi` coordinates are keyword arguments and default to 0. - """ - return target(r=current.x, theta=theta, phi=phi) - - -@dispatch -def represent_as( - current: Cartesian1DVector, - target: type[CylindricalVector], - /, - *, - phi: Quantity = Quantity(0.0, u.radian), - z: Quantity = Quantity(0.0, u.m), - **kwargs: Any, -) -> CylindricalVector: - """Cartesian1DVector -> CylindricalVector. - - The `x` coordinate is converted to the radial coordinate `rho`. - The `phi` and `z` coordinates are keyword arguments and default to 0. - """ - return target(rho=current.x, phi=phi, z=z) - - -# ============================================================================= -# RadialVector - -# ----------------------------------------------- -# 2D - - -@dispatch -def represent_as( - current: RadialVector, - target: type[Cartesian2DVector], - /, - *, - y: Quantity = Quantity(0.0, u.m), - **kwargs: Any, -) -> Cartesian2DVector: - """RadialVector -> Cartesian2DVector. - - The `r` coordinate is converted to the cartesian coordinate `x`. - The `y` coordinate is a keyword argument and defaults to 0. - """ - return target(x=current.r, y=y) - - -@dispatch -def represent_as( - current: RadialVector, - target: type[PolarVector], - /, - *, - phi: Quantity = Quantity(0.0, u.radian), - **kwargs: Any, -) -> PolarVector: - """RadialVector -> PolarVector. - - The `r` coordinate is converted to the radial coordinate `r`. - The `phi` coordinate is a keyword argument and defaults to 0. - """ - return target(r=current.r, phi=phi) - - -# ----------------------------------------------- -# 3D - - -@dispatch -def represent_as( - current: RadialVector, - target: type[Cartesian3DVector], - /, - *, - y: Quantity = Quantity(0.0, u.m), - z: Quantity = Quantity(0.0, u.m), - **kwargs: Any, -) -> Cartesian3DVector: - """RadialVector -> Cartesian3DVector. - - The `r` coordinate is converted to the `x` coordinate of the 3D system. - The `y` and `z` coordinates are keyword arguments and default to 0. - """ - return target(x=current.r, y=y, z=z) - - -@dispatch -def represent_as( - current: RadialVector, - target: type[SphericalVector], - /, - *, - theta: Quantity = Quantity(0.0, u.radian), - phi: Quantity = Quantity(0.0, u.radian), - **kwargs: Any, -) -> SphericalVector: - """RadialVector -> SphericalVector. - - The `r` coordinate is converted to the radial coordinate `r`. - The `theta` and `phi` coordinates are keyword arguments and default to 0. - """ - return target(r=current.r, theta=theta, phi=phi) - - -@dispatch -def represent_as( - current: RadialVector, - target: type[CylindricalVector], - /, - *, - phi: Quantity = Quantity(0.0, u.radian), - z: Quantity = Quantity(0.0, u.m), - **kwargs: Any, -) -> CylindricalVector: - """RadialVector -> CylindricalVector. - - The `r` coordinate is converted to the radial coordinate `rho`. - The `phi` and `z` coordinates are keyword arguments and default to 0. - """ - return target(rho=current.r, phi=phi, z=z) - - -############################################################################### -# 2D - - -@dispatch.multi( - (Cartesian2DVector, type[SphericalVector]), - (Cartesian2DVector, type[CylindricalVector]), -) -def represent_as( - current: Abstract2DVector, - target: type[Abstract3DVector], - /, - z: Quantity = Quantity(0.0, u.m), - **kwargs: Any, -) -> Abstract3DVector: - """Abstract2DVector -> Cartesian2D -> Cartesian3D -> Abstract3DVector. - - The 2D vector is in the xy plane. The `z` coordinate is a keyword argument and - defaults to 0. - """ - cart2 = represent_as(current, Cartesian2DVector) - cart3 = represent_as(cart2, Cartesian3DVector, z=z) - return represent_as(cart3, target) - - -@dispatch.multi( - (PolarVector, type[Cartesian3DVector]), - # (LnPolarVector, type[Cartesian3DVector]), - # (LnPolarVector, type[CylindricalVector]), - # (LnPolarVector, type[SphericalVector]), - # (Log10PolarVector, type[Cartesian3DVector]), - # (Log10PolarVector, type[CylindricalVector]), - # (Log10PolarVector, type[SphericalVector]), -) -def represent_as( - current: Abstract2DVector, - target: type[Abstract3DVector], - /, - z: Quantity = Quantity(0.0, u.m), - **kwargs: Any, -) -> Abstract3DVector: - """Abstract2DVector -> PolarVector -> Cylindrical -> Abstract3DVector. - - The 2D vector is in the xy plane. The `z` coordinate is a keyword argument and - defaults to 0. - """ - polar = represent_as(current, PolarVector) - cyl = represent_as(polar, CylindricalVector, z=z) - return represent_as(cyl, target) - - -# ============================================================================= -# Cartesian2DVector - - -# ----------------------------------------------- -# 1D - - -@dispatch -def represent_as( - current: Cartesian2DVector, target: type[Cartesian1DVector], /, **kwargs: Any -) -> Cartesian1DVector: - """Cartesian2DVector -> Cartesian1DVector. - - The `y` coordinate is dropped. - """ - warn("irreversible dimension change", IrreversibleDimensionChange, stacklevel=2) - return target(x=current.x) - - -@dispatch -def represent_as( - current: Cartesian2DVector, target: type[RadialVector], /, **kwargs: Any -) -> RadialVector: - """Cartesian2DVector -> RadialVector. - - The `x` and `y` coordinates are converted to the radial coordinate `r`. - """ - warn("irreversible dimension change", IrreversibleDimensionChange, stacklevel=2) - return target(r=xp.sqrt(current.x**2 + current.y**2)) - - -# ----------------------------------------------- -# 3D - - -@dispatch -def represent_as( - current: Cartesian2DVector, - target: type[Cartesian3DVector], - /, - *, - z: Quantity = Quantity(0.0, u.m), - **kwargs: Any, -) -> Cartesian3DVector: - """Cartesian2DVector -> Cartesian3DVector. - - The `x` and `y` coordinates are converted to the `x` and `y` coordinates of - the 3D system. The `z` coordinate is a keyword argument and defaults to 0. - """ - return target(x=current.x, y=current.y, z=z) - - -# ============================================================================= -# PolarVector - -# ----------------------------------------------- -# 1D - - -@dispatch -def represent_as( - current: PolarVector, target: type[Cartesian1DVector], /, **kwargs: Any -) -> Cartesian1DVector: - """PolarVector -> Cartesian1DVector.""" - warn("irreversible dimension change", IrreversibleDimensionChange, stacklevel=2) - return target(x=current.r * xp.cos(current.phi)) - - -@dispatch -def represent_as( - current: PolarVector, target: type[RadialVector], /, **kwargs: Any -) -> RadialVector: - """PolarVector -> RadialVector.""" - warn("irreversible dimension change", IrreversibleDimensionChange, stacklevel=2) - return target(r=current.r) - - -# ----------------------------------------------- -# 3D - - -@dispatch -def represent_as( - current: PolarVector, - target: type[SphericalVector], - /, - theta: Quantity["angle"] = Quantity(0.0, u.radian), # type: ignore[name-defined] - **kwargs: Any, -) -> SphericalVector: - """PolarVector -> SphericalVector.""" - return target(r=current.r, theta=theta, phi=current.phi) - - -@dispatch -def represent_as( - current: PolarVector, - target: type[CylindricalVector], - /, - *, - z: Quantity["length"] = Quantity(0.0, u.m), # type: ignore[name-defined] - **kwargs: Any, -) -> CylindricalVector: - """PolarVector -> CylindricalVector.""" - return target(rho=current.r, phi=current.phi, z=z) - - -# # ============================================================================= -# # LnPolarVector - -# # ----------------------------------------------- -# # 1D - - -# @dispatch -# def represent_as( -# current: LnPolarVector, target: type[Cartesian1DVector], /, **kwargs: Any -# ) -> Cartesian1DVector: -# """LnPolarVector -> Cartesian1DVector.""" -# polar = represent_as(current, PolarVector) -# return represent_as(polar, target) - - -# @dispatch -# def represent_as( -# current: LnPolarVector, target: type[RadialVector], /, **kwargs: Any -# ) -> RadialVector: -# """LnPolarVector -> RadialVector.""" -# warn("irreversible dimension change", IrreversibleDimensionChange, stacklevel=2) -# return target(r=xp.exp(current.lnr)) - - -# # ============================================================================= -# # Log10PolarVector - -# # ----------------------------------------------- -# # 1D - - -# @dispatch -# def represent_as( -# current: Log10PolarVector, target: type[Cartesian1DVector], /, **kwargs: Any -# ) -> Cartesian1DVector: -# """Log10PolarVector -> Cartesian1DVector.""" -# # warn("irreversible dimension change", IrreversibleDimensionChange, -# # stacklevel=2) -# polar = represent_as(current, PolarVector) -# return represent_as(polar, target) - - -# @dispatch -# def represent_as( -# current: Log10PolarVector, target: type[RadialVector], /, **kwargs: Any -# ) -> RadialVector: -# """Log10PolarVector -> RadialVector.""" -# warn("irreversible dimension change", IrreversibleDimensionChange, stacklevel=2) -# return target(r=xp.pow(10, current.log10r)) - - -############################################################################### -# 3D - - -# @dispatch.multi( -# (CylindricalVector, type[LnPolarVector]), -# (CylindricalVector, type[Log10PolarVector]), -# (SphericalVector, type[LnPolarVector]), -# (SphericalVector, type[Log10PolarVector]), -# ) -# def represent_as( -# current: Abstract3DVector, target: type[Abstract2DVector], **kwargs: Any -# ) -> Abstract2DVector: -# """Abstract3DVector -> Cylindrical -> PolarVector -> Abstract2DVector.""" -# warn("irreversible dimension change", IrreversibleDimensionChange, stacklevel=2) -# cyl = represent_as(current, CylindricalVector) -# polar = represent_as(cyl, PolarVector) -# return represent_as(polar, target) - - -# ============================================================================= -# Cartesian3DVector - - -# ----------------------------------------------- -# 1D - - -@dispatch -def represent_as( - current: Cartesian3DVector, target: type[Cartesian1DVector], /, **kwargs: Any -) -> Cartesian1DVector: - """Cartesian3DVector -> Cartesian1DVector.""" - warn("irreversible dimension change", IrreversibleDimensionChange, stacklevel=2) - return target(x=current.x) - - -@dispatch -def represent_as( - current: Cartesian3DVector, target: type[RadialVector], /, **kwargs: Any -) -> RadialVector: - """Cartesian3DVector -> RadialVector.""" - warn("irreversible dimension change", IrreversibleDimensionChange, stacklevel=2) - return target(r=xp.sqrt(current.x**2 + current.y**2 + current.z**2)) - - -# ----------------------------------------------- -# 2D - - -@dispatch -def represent_as( - current: Cartesian3DVector, target: type[Cartesian2DVector], /, **kwargs: Any -) -> Cartesian2DVector: - """Cartesian3DVector -> Cartesian2DVector.""" - warn("irreversible dimension change", IrreversibleDimensionChange, stacklevel=2) - return target(x=current.x, y=current.y) - - -@dispatch.multi( - (Cartesian3DVector, type[PolarVector]), - # (Cartesian3DVector, type[LnPolarVector]), - # (Cartesian3DVector, type[Log10PolarVector]), -) -def represent_as( - current: Cartesian3DVector, target: type[Abstract2DVector], /, **kwargs: Any -) -> Abstract2DVector: - """Cartesian3DVector -> Cartesian2D -> Abstract2DVector.""" - warn("irreversible dimension change", IrreversibleDimensionChange, stacklevel=2) - cart2 = represent_as(current, Cartesian2DVector) - return represent_as(cart2, target) - - -# ============================================================================= -# SphericalVector - - -# ----------------------------------------------- -# 1D - - -@dispatch -def represent_as( - current: SphericalVector, target: type[Cartesian1DVector], /, **kwargs: Any -) -> Cartesian1DVector: - """SphericalVector -> Cartesian1DVector.""" - warn("irreversible dimension change", IrreversibleDimensionChange, stacklevel=2) - return target(x=current.r * xp.sin(current.theta) * xp.cos(current.phi)) - - -@dispatch -def represent_as( - current: SphericalVector, target: type[RadialVector], /, **kwargs: Any -) -> RadialVector: - """SphericalVector -> RadialVector.""" - warn("irreversible dimension change", IrreversibleDimensionChange, stacklevel=2) - return target(r=current.r) - - -# ----------------------------------------------- -# 2D - - -@dispatch -def represent_as( - current: SphericalVector, target: type[Cartesian2DVector], /, **kwargs: Any -) -> Cartesian2DVector: - """SphericalVector -> Cartesian2DVector.""" - warn("irreversible dimension change", IrreversibleDimensionChange, stacklevel=2) - x = current.r * xp.sin(current.theta) * xp.cos(current.phi) - y = current.r * xp.sin(current.theta) * xp.sin(current.phi) - return target(x=x, y=y) - - -@dispatch -def represent_as( - current: SphericalVector, target: type[PolarVector], /, **kwargs: Any -) -> PolarVector: - """SphericalVector -> PolarVector.""" - warn("irreversible dimension change", IrreversibleDimensionChange, stacklevel=2) - return target(r=current.r * xp.sin(current.theta), phi=current.phi) - - -# ============================================================================= -# CylindricalVector - - -# ----------------------------------------------- -# 1D - - -@dispatch -def represent_as( - current: CylindricalVector, target: type[Cartesian1DVector], /, **kwargs: Any -) -> Cartesian1DVector: - """CylindricalVector -> Cartesian1DVector.""" - warn("irreversible dimension change", IrreversibleDimensionChange, stacklevel=2) - return target(x=current.rho * xp.cos(current.phi)) - - -@dispatch -def represent_as( - current: CylindricalVector, target: type[RadialVector], /, **kwargs: Any -) -> RadialVector: - """CylindricalVector -> RadialVector.""" - warn("irreversible dimension change", IrreversibleDimensionChange, stacklevel=2) - return target(r=current.rho) - - -# ----------------------------------------------- -# 2D - - -@dispatch -def represent_as( - current: CylindricalVector, target: type[Cartesian2DVector], /, **kwargs: Any -) -> Cartesian2DVector: - """CylindricalVector -> Cartesian2DVector.""" - warn("irreversible dimension change", IrreversibleDimensionChange, stacklevel=2) - x = current.rho * xp.cos(current.phi) - y = current.rho * xp.sin(current.phi) - return target(x=x, y=y) - - -@dispatch -def represent_as( - current: CylindricalVector, target: type[PolarVector], /, **kwargs: Any -) -> PolarVector: - """CylindricalVector -> PolarVector.""" - warn("irreversible dimension change", IrreversibleDimensionChange, stacklevel=2) - return target(r=current.rho, phi=current.phi) diff --git a/src/coordinax/_transform/__init__.py b/src/coordinax/_transform/__init__.py new file mode 100644 index 00000000..bab1c198 --- /dev/null +++ b/src/coordinax/_transform/__init__.py @@ -0,0 +1,8 @@ +"""Transformations for Vectors.""" + +__all__ = ["represent_as"] # noqa: F405 + +from .d1 import * +from .d2 import * +from .d3 import * +from .differentials import * diff --git a/src/coordinax/_transform/d1.py b/src/coordinax/_transform/d1.py new file mode 100644 index 00000000..d936d3ca --- /dev/null +++ b/src/coordinax/_transform/d1.py @@ -0,0 +1,464 @@ +"""Transformations from 1D.""" + +__all__ = ["represent_as"] + +from typing import Any + +import astropy.units as u +from plum import dispatch + +from unxt import Quantity + +from coordinax._d1.builtin import Cartesian1DVector, RadialVector +from coordinax._d2.builtin import Cartesian2DVector, PolarVector +from coordinax._d3.builtin import Cartesian3DVector, CylindricalVector +from coordinax._d3.sphere import MathSphericalVector, SphericalVector + +# ============================================================================= +# Cartesian1DVector + + +# ----------------------------------------------- +# to 2D + + +@dispatch +def represent_as( + current: Cartesian1DVector, + target: type[Cartesian2DVector], + /, + *, + y: Quantity = Quantity(0.0, u.m), + **kwargs: Any, +) -> Cartesian2DVector: + """Cartesian1DVector -> Cartesian2DVector. + + The `x` coordinate is converted to the `x` coordinate of the 2D system. + The `y` coordinate is a keyword argument and defaults to 0. + + Examples + -------- + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> x = cx.Cartesian1DVector(x=Quantity(1.0, "km")) + >>> x2 = cx.represent_as(x, cx.Cartesian2DVector) + >>> x2 + Cartesian2DVector( x=Quantity[...](value=f32[], unit=Unit("km")), + y=Quantity[...](value=f32[], unit=Unit("m")) ) + >>> x2.y + Quantity['length'](Array(0., dtype=float32), unit='m') + + >>> x3 = cx.represent_as(x, cx.Cartesian3DVector, y=Quantity(14, "km")) + >>> x3.y + Quantity['length'](Array(14., dtype=float32), unit='km') + + """ + return target(x=current.x, y=y) + + +@dispatch +def represent_as( + current: Cartesian1DVector, + target: type[PolarVector], + /, + *, + phi: Quantity = Quantity(0.0, u.radian), + **kwargs: Any, +) -> PolarVector: + """Cartesian1DVector -> PolarVector. + + The `x` coordinate is converted to the radial coordinate `r`. + The `phi` coordinate is a keyword argument and defaults to 0. + + Examples + -------- + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> x = cx.Cartesian1DVector(x=Quantity(1.0, "km")) + >>> x2 = cx.represent_as(x, cx.PolarVector) + >>> x2 + PolarVector( r=Distance(value=f32[], unit=Unit("km")), + phi=Quantity[...](value=f32[], unit=Unit("rad")) ) + >>> x2.phi + Quantity['angle'](Array(0., dtype=float32), unit='rad') + + >>> x3 = cx.represent_as(x, cx.PolarVector, phi=Quantity(14, "deg")) + >>> x3.phi + Quantity['angle'](Array(14., dtype=float32), unit='deg') + + """ + return target(r=current.x, phi=phi) + + +# ----------------------------------------------- +# to 3D + + +@dispatch +def represent_as( + current: Cartesian1DVector, + target: type[Cartesian3DVector], + /, + *, + y: Quantity = Quantity(0.0, u.m), + z: Quantity = Quantity(0.0, u.m), + **kwargs: Any, +) -> Cartesian3DVector: + """Cartesian1DVector -> Cartesian3DVector. + + The `x` coordinate is converted to the `x` coordinate of the 3D system. + The `y` and `z` coordinates are keyword arguments and default to 0. + + Examples + -------- + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> x = cx.Cartesian1DVector(x=Quantity(1.0, "km")) + >>> x2 = cx.represent_as(x, cx.Cartesian3DVector) + >>> x2 + Cartesian3DVector( x=Quantity[...](value=f32[], unit=Unit("km")), + y=Quantity[...](value=f32[], unit=Unit("m")), + z=Quantity[...](value=f32[], unit=Unit("m")) ) + >>> x2.y + Quantity['length'](Array(0., dtype=float32), unit='m') + >>> x2.z + Quantity['length'](Array(0., dtype=float32), unit='m') + + >>> x3 = cx.represent_as(x, cx.Cartesian3DVector, y=Quantity(14, "km")) + >>> x3.y + Quantity['length'](Array(14., dtype=float32), unit='km') + >>> x3.z + Quantity['length'](Array(0., dtype=float32), unit='m') + + """ + return target(x=current.x, y=y, z=z) + + +@dispatch +def represent_as( + current: Cartesian1DVector, + target: type[SphericalVector] | type[MathSphericalVector], + /, + *, + theta: Quantity = Quantity(0.0, u.radian), + phi: Quantity = Quantity(0.0, u.radian), + **kwargs: Any, +) -> SphericalVector | MathSphericalVector: + """Cartesian1DVector -> SphericalVector | MathSphericalVector. + + The `x` coordinate is converted to the radial coordinate `r`. + The `theta` and `phi` coordinates are keyword arguments and default to 0. + + Examples + -------- + >>> from unxt import Quantity + >>> import coordinax as cx + + SphericalVector: + + >>> x = cx.Cartesian1DVector(x=Quantity(1.0, "km")) + >>> x2 = cx.represent_as(x, cx.SphericalVector) + >>> x2 + SphericalVector( r=Distance(value=f32[], unit=Unit("km")), + phi=Quantity[...](value=f32[], unit=Unit("rad")), + theta=Quantity[...](value=f32[], unit=Unit("rad")) ) + >>> x2.phi + Quantity['angle'](Array(0., dtype=float32), unit='rad') + >>> x2.theta + Quantity['angle'](Array(0., dtype=float32), unit='rad') + + >>> x3 = cx.represent_as(x, cx.SphericalVector, phi=Quantity(14, "deg")) + >>> x3.phi + Quantity['angle'](Array(14., dtype=float32), unit='deg') + >>> x2.theta + Quantity['angle'](Array(0., dtype=float32), unit='rad') + + MathSphericalVector: + + >>> x2 = cx.represent_as(x, cx.MathSphericalVector) + >>> x2 + MathSphericalVector( r=Distance(value=f32[], unit=Unit("km")), + theta=Quantity[...](value=f32[], unit=Unit("rad")), + phi=Quantity[...](value=f32[], unit=Unit("rad")) ) + >>> x2.theta + Quantity['angle'](Array(0., dtype=float32), unit='rad') + >>> x2.phi + Quantity['angle'](Array(0., dtype=float32), unit='rad') + + >>> x3 = cx.represent_as(x, cx.MathSphericalVector, phi=Quantity(14, "deg")) + >>> x3.theta + Quantity['angle'](Array(0., dtype=float32), unit='rad') + >>> x3.phi + Quantity['angle'](Array(14., dtype=float32), unit='deg') + + """ + return target(r=current.x, theta=theta, phi=phi) + + +@dispatch +def represent_as( + current: Cartesian1DVector, + target: type[CylindricalVector], + /, + *, + phi: Quantity = Quantity(0.0, u.radian), + z: Quantity = Quantity(0.0, u.m), + **kwargs: Any, +) -> CylindricalVector: + """Cartesian1DVector -> CylindricalVector. + + The `x` coordinate is converted to the radial coordinate `rho`. + The `phi` and `z` coordinates are keyword arguments and default to 0. + + Examples + -------- + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> x = cx.Cartesian1DVector(x=Quantity(1.0, "km")) + >>> x2 = cx.represent_as(x, cx.CylindricalVector) + >>> x2 + CylindricalVector( rho=Quantity[...](value=f32[], unit=Unit("km")), + phi=Quantity[...](value=f32[], unit=Unit("rad")), + z=Quantity[...](value=f32[], unit=Unit("m")) ) + >>> x2.phi + Quantity['angle'](Array(0., dtype=float32), unit='rad') + >>> x2.z + Quantity['length'](Array(0., dtype=float32), unit='m') + + >>> x3 = cx.represent_as(x, cx.CylindricalVector, phi=Quantity(14, "deg")) + >>> x3.phi + Quantity['angle'](Array(14., dtype=float32), unit='deg') + >>> x3.z + Quantity['length'](Array(0., dtype=float32), unit='m') + + """ + return target(rho=current.x, phi=phi, z=z) + + +# ============================================================================= +# RadialVector + +# ----------------------------------------------- +# 2D + + +@dispatch +def represent_as( + current: RadialVector, + target: type[Cartesian2DVector], + /, + *, + y: Quantity = Quantity(0.0, u.m), + **kwargs: Any, +) -> Cartesian2DVector: + """RadialVector -> Cartesian2DVector. + + The `r` coordinate is converted to the cartesian coordinate `x`. + The `y` coordinate is a keyword argument and defaults to 0. + + Examples + -------- + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> x = cx.RadialVector(r=Quantity(1.0, "km")) + >>> x2 = cx.represent_as(x, cx.Cartesian2DVector) + >>> x2 + Cartesian2DVector( x=Quantity[...](value=f32[], unit=Unit("km")), + y=Quantity[...](value=f32[], unit=Unit("m")) ) + >>> x2.y + Quantity['length'](Array(0., dtype=float32), unit='m') + + >>> x3 = cx.represent_as(x, cx.Cartesian2DVector, y=Quantity(14, "km")) + >>> x3.y + Quantity['length'](Array(14., dtype=float32), unit='km') + + """ + return target(x=current.r, y=y) + + +@dispatch +def represent_as( + current: RadialVector, + target: type[PolarVector], + /, + *, + phi: Quantity = Quantity(0.0, u.radian), + **kwargs: Any, +) -> PolarVector: + """RadialVector -> PolarVector. + + The `r` coordinate is converted to the radial coordinate `r`. + The `phi` coordinate is a keyword argument and defaults to 0. + + Examples + -------- + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> x = cx.RadialVector(r=Quantity(1.0, "km")) + >>> x2 = cx.represent_as(x, cx.PolarVector) + >>> x2 + PolarVector( r=Distance(value=f32[], unit=Unit("km")), + phi=Quantity[...](value=f32[], unit=Unit("rad")) ) + >>> x2.phi + Quantity['angle'](Array(0., dtype=float32), unit='rad') + + >>> x3 = cx.represent_as(x, cx.PolarVector, phi=Quantity(14, "deg")) + >>> x3.phi + Quantity['angle'](Array(14., dtype=float32), unit='deg') + + """ + return target(r=current.r, phi=phi) + + +# ----------------------------------------------- +# 3D + + +@dispatch +def represent_as( + current: RadialVector, + target: type[Cartesian3DVector], + /, + *, + y: Quantity = Quantity(0.0, u.m), + z: Quantity = Quantity(0.0, u.m), + **kwargs: Any, +) -> Cartesian3DVector: + """RadialVector -> Cartesian3DVector. + + The `r` coordinate is converted to the `x` coordinate of the 3D system. + The `y` and `z` coordinates are keyword arguments and default to 0. + + Examples + -------- + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> x = cx.RadialVector(r=Quantity(1.0, "km")) + >>> x2 = cx.represent_as(x, cx.Cartesian3DVector) + >>> x2 + Cartesian3DVector( x=Quantity[...](value=f32[], unit=Unit("km")), + y=Quantity[...](value=f32[], unit=Unit("m")), + z=Quantity[...](value=f32[], unit=Unit("m")) ) + >>> x2.y + Quantity['length'](Array(0., dtype=float32), unit='m') + >>> x2.z + Quantity['length'](Array(0., dtype=float32), unit='m') + + >>> x3 = cx.represent_as(x, cx.Cartesian3DVector, y=Quantity(14, "km")) + >>> x3.y + Quantity['length'](Array(14., dtype=float32), unit='km') + >>> x3.z + Quantity['length'](Array(0., dtype=float32), unit='m') + + """ + return target(x=current.r, y=y, z=z) + + +@dispatch +def represent_as( + current: RadialVector, + target: type[SphericalVector] | type[MathSphericalVector], + /, + *, + theta: Quantity = Quantity(0.0, u.radian), + phi: Quantity = Quantity(0.0, u.radian), + **kwargs: Any, +) -> SphericalVector | MathSphericalVector: + """RadialVector -> SphericalVector | MathSphericalVector. + + The `r` coordinate is converted to the radial coordinate `r`. + The `theta` and `phi` coordinates are keyword arguments and default to 0. + + Examples + -------- + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> x = cx.RadialVector(r=Quantity(1.0, "km")) + + SphericalVector: + + >>> x2 = cx.represent_as(x, cx.SphericalVector) + >>> x2 + SphericalVector( r=Distance(value=f32[], unit=Unit("km")), + phi=Quantity[...](value=f32[], unit=Unit("rad")), + theta=Quantity[...](value=f32[], unit=Unit("rad")) ) + >>> x2.phi + Quantity['angle'](Array(0., dtype=float32), unit='rad') + >>> x2.theta + Quantity['angle'](Array(0., dtype=float32), unit='rad') + + >>> x3 = cx.represent_as(x, cx.SphericalVector, phi=Quantity(14, "deg")) + >>> x3.phi + Quantity['angle'](Array(14., dtype=float32), unit='deg') + >>> x3.theta + Quantity['angle'](Array(0., dtype=float32), unit='rad') + + MathSphericalVector: + + >>> x2 = cx.represent_as(x, cx.MathSphericalVector) + >>> x2 + MathSphericalVector( r=Distance(value=f32[], unit=Unit("km")), + theta=Quantity[...](value=f32[], unit=Unit("rad")), + phi=Quantity[...](value=f32[], unit=Unit("rad")) ) + >>> x2.theta + Quantity['angle'](Array(0., dtype=float32), unit='rad') + >>> x2.phi + Quantity['angle'](Array(0., dtype=float32), unit='rad') + + >>> x3 = cx.represent_as(x, cx.MathSphericalVector, phi=Quantity(14, "deg")) + >>> x3.theta + Quantity['angle'](Array(0., dtype=float32), unit='rad') + >>> x3.phi + Quantity['angle'](Array(14., dtype=float32), unit='deg') + + """ + return target(r=current.r, theta=theta, phi=phi) + + +@dispatch +def represent_as( + current: RadialVector, + target: type[CylindricalVector], + /, + *, + phi: Quantity = Quantity(0.0, u.radian), + z: Quantity = Quantity(0.0, u.m), + **kwargs: Any, +) -> CylindricalVector: + """RadialVector -> CylindricalVector. + + The `r` coordinate is converted to the radial coordinate `rho`. + The `phi` and `z` coordinates are keyword arguments and default to 0. + + Examples + -------- + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> x = cx.RadialVector(r=Quantity(1.0, "km")) + >>> x2 = cx.represent_as(x, cx.CylindricalVector) + >>> x2 + CylindricalVector( rho=Quantity[...](value=f32[], unit=Unit("km")), + phi=Quantity[...](value=f32[], unit=Unit("rad")), + z=Quantity[...](value=f32[], unit=Unit("m")) ) + >>> x2.phi + Quantity['angle'](Array(0., dtype=float32), unit='rad') + >>> x2.z + Quantity['length'](Array(0., dtype=float32), unit='m') + + >>> x3 = cx.represent_as(x, cx.CylindricalVector, phi=Quantity(14, "deg")) + >>> x3.phi + Quantity['angle'](Array(14., dtype=float32), unit='deg') + >>> x3.z + Quantity['length'](Array(0., dtype=float32), unit='m') + + """ + return target(rho=current.r, phi=phi, z=z) diff --git a/src/coordinax/_transform/d2.py b/src/coordinax/_transform/d2.py new file mode 100644 index 00000000..7bd7817d --- /dev/null +++ b/src/coordinax/_transform/d2.py @@ -0,0 +1,365 @@ +"""Transformations between representations.""" + +__all__ = ["represent_as"] + +from typing import Any +from warnings import warn + +import astropy.units as u +from plum import dispatch + +import quaxed.array_api as xp +from unxt import Quantity + +from coordinax._d1.builtin import Cartesian1DVector, RadialVector +from coordinax._d2.base import Abstract2DVector +from coordinax._d2.builtin import Cartesian2DVector, PolarVector +from coordinax._d3.base import Abstract3DVector +from coordinax._d3.builtin import Cartesian3DVector, CylindricalVector +from coordinax._d3.sphere import MathSphericalVector, SphericalVector +from coordinax._exceptions import IrreversibleDimensionChange + + +@dispatch.multi( + (Cartesian2DVector, type[CylindricalVector]), + (Cartesian2DVector, type[SphericalVector]), + (Cartesian2DVector, type[MathSphericalVector]), +) +def represent_as( + current: Abstract2DVector, + target: type[Abstract3DVector], + /, + z: Quantity = Quantity(0.0, u.m), + **kwargs: Any, +) -> Abstract3DVector: + """Abstract2DVector -> Cartesian2D -> Cartesian3D -> Abstract3DVector. + + The 2D vector is in the xy plane. The `z` coordinate is a keyword argument and + defaults to 0. + + Examples + -------- + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> x = cx.Cartesian2DVector.constructor(Quantity([1.0, 2.0], "km")) + + >>> x2 = cx.represent_as(x, cx.CylindricalVector, z=Quantity(14, "km")) + >>> x2 + CylindricalVector( rho=Quantity[...](value=f32[], unit=Unit("km")), + phi=Quantity[...](value=f32[], unit=Unit("rad")), + z=Quantity[...](value=f32[], unit=Unit("km")) ) + >>> x2.z + Quantity['length'](Array(14., dtype=float32), unit='km') + + >>> x3 = cx.represent_as(x, cx.SphericalVector, z=Quantity(14, "km")) + >>> x3 + SphericalVector( r=Distance(value=f32[], unit=Unit("km")), + phi=Quantity[...](value=f32[], unit=Unit("rad")), + theta=Quantity[...](value=f32[], unit=Unit("rad")) ) + >>> x3.r + Distance(Array(14.177447, dtype=float32), unit='km') + + >>> x3 = cx.represent_as(x, cx.MathSphericalVector, z=Quantity(14, "km")) + >>> x3 + MathSphericalVector( r=Distance(value=f32[], unit=Unit("km")), + theta=Quantity[...](value=f32[], unit=Unit("rad")), + phi=Quantity[...](value=f32[], unit=Unit("rad")) ) + >>> x3.r + Distance(Array(14.177447, dtype=float32), unit='km') + + """ + cart2 = represent_as(current, Cartesian2DVector) + cart3 = represent_as(cart2, Cartesian3DVector, z=z) + return represent_as(cart3, target) + + +@dispatch.multi( + (PolarVector, type[Cartesian3DVector]), +) +def represent_as( + current: Abstract2DVector, + target: type[Abstract3DVector], + /, + z: Quantity = Quantity(0.0, u.m), + **kwargs: Any, +) -> Abstract3DVector: + """Abstract2DVector -> PolarVector -> Cylindrical -> Abstract3DVector. + + The 2D vector is in the xy plane. The `z` coordinate is a keyword argument and + defaults to 0. + + Examples + -------- + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> x = cx.PolarVector(r=Quantity(1.0, "km"), phi=Quantity(10.0, "deg")) + + >>> x2 = cx.represent_as(x, cx.Cartesian3DVector, z=Quantity(14, "km")) + >>> x2 + Cartesian3DVector( x=Quantity[...](value=f32[], unit=Unit("km")), + y=Quantity[...](value=f32[], unit=Unit("km")), + z=Quantity[...](value=f32[], unit=Unit("km")) ) + >>> x2.z + Quantity['length'](Array(14., dtype=float32), unit='km') + + """ + polar = represent_as(current, PolarVector) + cyl = represent_as(polar, CylindricalVector, z=z) + return represent_as(cyl, target) + + +# ============================================================================= +# Cartesian2DVector + + +# ----------------------------------------------- +# 1D + + +@dispatch +def represent_as( + current: Cartesian2DVector, target: type[Cartesian1DVector], /, **kwargs: Any +) -> Cartesian1DVector: + """Cartesian2DVector -> Cartesian1DVector. + + The `y` coordinate is dropped. + + Examples + -------- + >>> import warnings + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> x = cx.Cartesian2DVector.constructor(Quantity([1.0, 2.0], "km")) + + >>> with warnings.catch_warnings(): + ... warnings.simplefilter("ignore") + ... x2 = cx.represent_as(x, cx.Cartesian1DVector, z=Quantity(14, "km")) + >>> x2 + Cartesian1DVector( x=Quantity[...](value=f32[], unit=Unit("km")) ) + >>> x2.x + Quantity['length'](Array(1., dtype=float32), unit='km') + + """ + warn("irreversible dimension change", IrreversibleDimensionChange, stacklevel=2) + return target(x=current.x) + + +@dispatch +def represent_as( + current: Cartesian2DVector, target: type[RadialVector], /, **kwargs: Any +) -> RadialVector: + """Cartesian2DVector -> RadialVector. + + The `x` and `y` coordinates are converted to the radial coordinate `r`. + + Examples + -------- + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> x = cx.Cartesian2DVector.constructor(Quantity([1.0, 2.0], "km")) + + >>> with warnings.catch_warnings(): + ... warnings.simplefilter("ignore") + ... x2 = cx.represent_as(x, cx.RadialVector, z=Quantity(14, "km")) + >>> x2 + RadialVector(r=Distance(value=f32[], unit=Unit("km"))) + >>> x2.r + Distance(Array(2.236068, dtype=float32), unit='km') + + """ + warn("irreversible dimension change", IrreversibleDimensionChange, stacklevel=2) + return target(r=xp.sqrt(current.x**2 + current.y**2)) + + +# ----------------------------------------------- +# 3D + + +@dispatch +def represent_as( + current: Cartesian2DVector, + target: type[Cartesian3DVector], + /, + *, + z: Quantity = Quantity(0.0, u.m), + **kwargs: Any, +) -> Cartesian3DVector: + """Cartesian2DVector -> Cartesian3DVector. + + The `x` and `y` coordinates are converted to the `x` and `y` coordinates of + the 3D system. The `z` coordinate is a keyword argument and defaults to 0. + + Examples + -------- + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> x = cx.Cartesian2DVector.constructor(Quantity([1.0, 2.0], "km")) + + >>> x2 = cx.represent_as(x, cx.Cartesian3DVector, z=Quantity(14, "km")) + >>> x2 + Cartesian3DVector( x=Quantity[...](value=f32[], unit=Unit("km")), + y=Quantity[...](value=f32[], unit=Unit("km")), + z=Quantity[...](value=f32[], unit=Unit("km")) ) + >>> x2.z + Quantity['length'](Array(14., dtype=float32), unit='km') + + """ + return target(x=current.x, y=current.y, z=z) + + +# ============================================================================= +# PolarVector + +# ----------------------------------------------- +# 1D + + +@dispatch +def represent_as( + current: PolarVector, target: type[Cartesian1DVector], /, **kwargs: Any +) -> Cartesian1DVector: + """PolarVector -> Cartesian1DVector. + + Examples + -------- + >>> import warnings + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> x = cx.PolarVector(r=Quantity(1.0, "km"), phi=Quantity(10.0, "deg")) + + >>> with warnings.catch_warnings(): + ... warnings.simplefilter("ignore") + ... x2 = cx.represent_as(x, cx.Cartesian1DVector) + >>> x2 + Cartesian1DVector( x=Quantity[...](value=f32[], unit=Unit("km")) ) + >>> x2.x + Quantity['length'](Array(0.9848077, dtype=float32), unit='km') + + """ + warn("irreversible dimension change", IrreversibleDimensionChange, stacklevel=2) + return target(x=current.r * xp.cos(current.phi)) + + +@dispatch +def represent_as( + current: PolarVector, target: type[RadialVector], /, **kwargs: Any +) -> RadialVector: + """PolarVector -> RadialVector. + + Examples + -------- + >>> import warnings + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> x = cx.PolarVector(r=Quantity(1.0, "km"), phi=Quantity(10.0, "deg")) + + >>> with warnings.catch_warnings(): + ... warnings.simplefilter("ignore") + ... x2 = cx.represent_as(x, cx.RadialVector) + >>> x2 + RadialVector(r=Distance(value=f32[], unit=Unit("km"))) + >>> x2.r + Distance(Array(1., dtype=float32), unit='km') + + """ + warn("irreversible dimension change", IrreversibleDimensionChange, stacklevel=2) + return target(r=current.r) + + +# ----------------------------------------------- +# 3D + + +@dispatch +def represent_as( + current: PolarVector, + target: type[SphericalVector], + /, + theta: Quantity["angle"] = Quantity(0.0, u.radian), # type: ignore[name-defined] + **kwargs: Any, +) -> SphericalVector: + """PolarVector -> SphericalVector. + + Examples + -------- + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> x = cx.PolarVector(r=Quantity(1.0, "km"), phi=Quantity(10.0, "deg")) + + >>> x2 = cx.represent_as(x, cx.SphericalVector, theta=Quantity(14, "deg")) + >>> x2 + SphericalVector( r=Distance(value=f32[], unit=Unit("km")), + phi=Quantity[...](value=f32[], unit=Unit("deg")), + theta=Quantity[...](value=f32[], unit=Unit("deg")) ) + >>> x2.theta + Quantity['angle'](Array(14., dtype=float32), unit='deg') + + """ + return target(r=current.r, theta=theta, phi=current.phi) + + +@dispatch +def represent_as( + current: PolarVector, + target: type[MathSphericalVector], + /, + phi: Quantity["angle"] = Quantity(0.0, u.radian), # type: ignore[name-defined] + **kwargs: Any, +) -> MathSphericalVector: + """PolarVector -> MathSphericalVector. + + Examples + -------- + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> x = cx.PolarVector(r=Quantity(1.0, "km"), phi=Quantity(10.0, "deg")) + + >>> x2 = cx.represent_as(x, cx.MathSphericalVector, phi=Quantity(14, "deg")) + >>> x2 + MathSphericalVector( r=Distance(value=f32[], unit=Unit("km")), + theta=Quantity[...](value=f32[], unit=Unit("deg")), + phi=Quantity[...](value=f32[], unit=Unit("deg")) ) + >>> x2.phi + Quantity['angle'](Array(14., dtype=float32), unit='deg') + + """ + return target(r=current.r, phi=phi, theta=current.phi) + + +@dispatch +def represent_as( + current: PolarVector, + target: type[CylindricalVector], + /, + *, + z: Quantity["length"] = Quantity(0.0, u.m), # type: ignore[name-defined] + **kwargs: Any, +) -> CylindricalVector: + """PolarVector -> CylindricalVector. + + Examples + -------- + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> x = cx.PolarVector(r=Quantity(1.0, "km"), phi=Quantity(10.0, "deg")) + + >>> x2 = cx.represent_as(x, cx.CylindricalVector, z=Quantity(14, "km")) + >>> x2 + CylindricalVector( rho=Quantity[...](value=f32[], unit=Unit("km")), + phi=Quantity[...](value=f32[], unit=Unit("deg")), + z=Quantity[...](value=f32[], unit=Unit("km")) ) + >>> x2.z + Quantity['length'](Array(14., dtype=float32), unit='km') + + """ + return target(rho=current.r, phi=current.phi, z=z) diff --git a/src/coordinax/_transform/d3.py b/src/coordinax/_transform/d3.py new file mode 100644 index 00000000..2fcfeae6 --- /dev/null +++ b/src/coordinax/_transform/d3.py @@ -0,0 +1,497 @@ +"""Transformations between representations.""" + +__all__ = ["represent_as"] + +from typing import Any +from warnings import warn + +from plum import dispatch + +import quaxed.array_api as xp + +from coordinax._d1.builtin import Cartesian1DVector, RadialVector +from coordinax._d2.base import Abstract2DVector +from coordinax._d2.builtin import Cartesian2DVector, PolarVector +from coordinax._d3.builtin import Cartesian3DVector, CylindricalVector +from coordinax._d3.sphere import MathSphericalVector, SphericalVector +from coordinax._exceptions import IrreversibleDimensionChange + +# ============================================================================= +# Cartesian3DVector + + +# ----------------------------------------------- +# 1D + + +@dispatch +def represent_as( + current: Cartesian3DVector, target: type[Cartesian1DVector], /, **kwargs: Any +) -> Cartesian1DVector: + """Cartesian3DVector -> Cartesian1DVector. + + The `y` and `z` coordinates are dropped. + + Examples + -------- + >>> import warnings + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> x = cx.Cartesian3DVector.constructor(Quantity([1.0, 2.0, 3.0], "km")) + + >>> with warnings.catch_warnings(): + ... warnings.simplefilter("ignore") + ... x2 = cx.represent_as(x, cx.Cartesian1DVector) + >>> x2 + Cartesian1DVector( + x=Quantity[PhysicalType('length')](value=f32[], unit=Unit("km")) + ) + + """ + warn("irreversible dimension change", IrreversibleDimensionChange, stacklevel=2) + return target(x=current.x) + + +@dispatch +def represent_as( + current: Cartesian3DVector, target: type[RadialVector], /, **kwargs: Any +) -> RadialVector: + """Cartesian3DVector -> RadialVector. + + Examples + -------- + >>> import warnings + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> x = cx.Cartesian3DVector.constructor(Quantity([1.0, 2.0, 3.0], "km")) + + >>> with warnings.catch_warnings(): + ... warnings.simplefilter("ignore") + ... x2 = cx.represent_as(x, cx.RadialVector) + >>> x2 + RadialVector(r=Distance(value=f32[], unit=Unit("km"))) + + """ + warn("irreversible dimension change", IrreversibleDimensionChange, stacklevel=2) + return target(r=xp.sqrt(current.x**2 + current.y**2 + current.z**2)) + + +# ----------------------------------------------- +# 2D + + +@dispatch +def represent_as( + current: Cartesian3DVector, target: type[Cartesian2DVector], /, **kwargs: Any +) -> Cartesian2DVector: + """Cartesian3DVector -> Cartesian2DVector. + + Examples + -------- + >>> import warnings + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> x = cx.Cartesian3DVector.constructor(Quantity([1.0, 2.0, 3.0], "km")) + + >>> with warnings.catch_warnings(): + ... warnings.simplefilter("ignore") + ... x2 = cx.represent_as(x, cx.Cartesian2DVector) + >>> x2 + Cartesian2DVector( x=Quantity[...](value=f32[], unit=Unit("km")), + y=Quantity[...](value=f32[], unit=Unit("km")) ) + + """ + warn("irreversible dimension change", IrreversibleDimensionChange, stacklevel=2) + return target(x=current.x, y=current.y) + + +@dispatch.multi( + (Cartesian3DVector, type[PolarVector]), +) +def represent_as( + current: Cartesian3DVector, target: type[Abstract2DVector], /, **kwargs: Any +) -> Abstract2DVector: + """Cartesian3DVector -> Cartesian2D -> Abstract2DVector. + + Examples + -------- + >>> import warnings + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> x = cx.Cartesian3DVector.constructor(Quantity([1.0, 2.0, 3.0], "km")) + + >>> with warnings.catch_warnings(): + ... warnings.simplefilter("ignore") + ... x2 = cx.represent_as(x, cx.PolarVector) + >>> x2 + PolarVector( r=Distance(value=f32[], unit=Unit("km")), + phi=Quantity[...](value=f32[], unit=Unit("rad")) ) + + """ + warn("irreversible dimension change", IrreversibleDimensionChange, stacklevel=2) + cart2 = represent_as(current, Cartesian2DVector) + return represent_as(cart2, target) + + +# ============================================================================= +# CylindricalVector + + +# ----------------------------------------------- +# 1D + + +@dispatch +def represent_as( + current: CylindricalVector, target: type[Cartesian1DVector], /, **kwargs: Any +) -> Cartesian1DVector: + """CylindricalVector -> Cartesian1DVector. + + Examples + -------- + >>> import warnings + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> x = cx.CylindricalVector(rho=Quantity(1.0, "km"), phi=Quantity(10.0, "deg"), + ... z=Quantity(14, "km")) + + >>> with warnings.catch_warnings(): + ... warnings.simplefilter("ignore") + ... x2 = cx.represent_as(x, cx.Cartesian1DVector) + >>> x2 + Cartesian1DVector( x=Quantity[...](value=f32[], unit=Unit("km")) ) + + """ + warn("irreversible dimension change", IrreversibleDimensionChange, stacklevel=2) + return target(x=current.rho * xp.cos(current.phi)) + + +@dispatch +def represent_as( + current: CylindricalVector, target: type[RadialVector], /, **kwargs: Any +) -> RadialVector: + """CylindricalVector -> RadialVector. + + Examples + -------- + >>> import warnings + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> x = cx.CylindricalVector(rho=Quantity(1.0, "km"), phi=Quantity(10.0, "deg"), + ... z=Quantity(14, "km")) + + >>> with warnings.catch_warnings(): + ... warnings.simplefilter("ignore") + ... x2 = cx.represent_as(x, cx.RadialVector) + >>> x2 + RadialVector(r=Distance(value=f32[], unit=Unit("km"))) + + """ + warn("irreversible dimension change", IrreversibleDimensionChange, stacklevel=2) + return target(r=current.rho) + + +# ----------------------------------------------- +# 2D + + +@dispatch +def represent_as( + current: CylindricalVector, target: type[Cartesian2DVector], /, **kwargs: Any +) -> Cartesian2DVector: + """CylindricalVector -> Cartesian2DVector. + + Examples + -------- + >>> import warnings + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> x = cx.CylindricalVector(rho=Quantity(1.0, "km"), phi=Quantity(10.0, "deg"), + ... z=Quantity(14, "km")) + + >>> with warnings.catch_warnings(): + ... warnings.simplefilter("ignore") + ... x2 = cx.represent_as(x, cx.Cartesian2DVector) + >>> x2 + Cartesian2DVector( x=Quantity[...](value=f32[], unit=Unit("km")), + y=Quantity[...](value=f32[], unit=Unit("km")) ) + + """ + warn("irreversible dimension change", IrreversibleDimensionChange, stacklevel=2) + x = current.rho * xp.cos(current.phi) + y = current.rho * xp.sin(current.phi) + return target(x=x, y=y) + + +@dispatch +def represent_as( + current: CylindricalVector, target: type[PolarVector], /, **kwargs: Any +) -> PolarVector: + """CylindricalVector -> PolarVector. + + Examples + -------- + >>> import warnings + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> x = cx.CylindricalVector(rho=Quantity(1.0, "km"), phi=Quantity(10.0, "deg"), + ... z=Quantity(14, "km")) + + >>> with warnings.catch_warnings(): + ... warnings.simplefilter("ignore") + ... x2 = cx.represent_as(x, cx.PolarVector) + >>> x2 + PolarVector( r=Distance(value=f32[], unit=Unit("km")), + phi=Quantity[...](value=f32[], unit=Unit("deg")) ) + + """ + warn("irreversible dimension change", IrreversibleDimensionChange, stacklevel=2) + return target(r=current.rho, phi=current.phi) + + +# ============================================================================= +# SphericalVector + + +# ----------------------------------------------- +# 1D + + +@dispatch +def represent_as( + current: SphericalVector, target: type[Cartesian1DVector], /, **kwargs: Any +) -> Cartesian1DVector: + """SphericalVector -> Cartesian1DVector. + + Examples + -------- + >>> import warnings + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> x = cx.SphericalVector(r=Quantity(1.0, "km"), phi=Quantity(10.0, "deg"), + ... theta=Quantity(14, "deg")) + + >>> with warnings.catch_warnings(): + ... warnings.simplefilter("ignore") + ... x2 = cx.represent_as(x, cx.Cartesian1DVector) + >>> x2 + Cartesian1DVector( x=Quantity[...](value=f32[], unit=Unit("km")) ) + + """ + warn("irreversible dimension change", IrreversibleDimensionChange, stacklevel=2) + return target(x=current.r * xp.sin(current.theta) * xp.cos(current.phi)) + + +@dispatch +def represent_as( + current: SphericalVector, target: type[RadialVector], /, **kwargs: Any +) -> RadialVector: + """SphericalVector -> RadialVector. + + Examples + -------- + >>> import warnings + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> x = cx.SphericalVector(r=Quantity(1.0, "km"), phi=Quantity(10.0, "deg"), + ... theta=Quantity(14, "deg")) + + >>> with warnings.catch_warnings(): + ... warnings.simplefilter("ignore") + ... x2 = cx.represent_as(x, cx.RadialVector) + >>> x2 + RadialVector(r=Distance(value=f32[], unit=Unit("km"))) + + """ + warn("irreversible dimension change", IrreversibleDimensionChange, stacklevel=2) + return target(r=current.r) + + +# ----------------------------------------------- +# 2D + + +@dispatch +def represent_as( + current: SphericalVector, target: type[Cartesian2DVector], /, **kwargs: Any +) -> Cartesian2DVector: + """SphericalVector -> Cartesian2DVector. + + Examples + -------- + >>> import warnings + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> x = cx.SphericalVector(r=Quantity(1.0, "km"), phi=Quantity(10.0, "deg"), + ... theta=Quantity(14, "deg")) + + >>> with warnings.catch_warnings(): + ... warnings.simplefilter("ignore") + ... x2 = cx.represent_as(x, cx.Cartesian2DVector) + >>> x2 + Cartesian2DVector( x=Quantity[...](value=f32[], unit=Unit("km")), + y=Quantity[...](value=f32[], unit=Unit("km")) ) + + """ + warn("irreversible dimension change", IrreversibleDimensionChange, stacklevel=2) + x = current.r * xp.sin(current.theta) * xp.cos(current.phi) + y = current.r * xp.sin(current.theta) * xp.sin(current.phi) + return target(x=x, y=y) + + +@dispatch +def represent_as( + current: SphericalVector, target: type[PolarVector], /, **kwargs: Any +) -> PolarVector: + """SphericalVector -> PolarVector. + + Examples + -------- + >>> import warnings + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> x = cx.SphericalVector(r=Quantity(1.0, "km"), phi=Quantity(10.0, "deg"), + ... theta=Quantity(14, "deg")) + + >>> with warnings.catch_warnings(): + ... warnings.simplefilter("ignore") + ... x2 = cx.represent_as(x, cx.PolarVector) + >>> x2 + PolarVector( r=Distance(value=f32[], unit=Unit("km")), + phi=Quantity[...](value=f32[], unit=Unit("deg")) ) + + """ + warn("irreversible dimension change", IrreversibleDimensionChange, stacklevel=2) + return target(r=current.r * xp.sin(current.theta), phi=current.phi) + + +# ============================================================================= +# MathSphericalVector + + +# ----------------------------------------------- +# 1D + + +@dispatch +def represent_as( + current: MathSphericalVector, target: type[Cartesian1DVector], /, **kwargs: Any +) -> Cartesian1DVector: + """MathSphericalVector -> Cartesian1DVector. + + Examples + -------- + >>> import warnings + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> x = cx.MathSphericalVector(r=Quantity(1.0, "km"), theta=Quantity(10.0, "deg"), + ... phi=Quantity(14, "deg")) + + >>> with warnings.catch_warnings(): + ... warnings.simplefilter("ignore") + ... x2 = cx.represent_as(x, cx.Cartesian1DVector) + >>> x2 + Cartesian1DVector( x=Quantity[...](value=f32[], unit=Unit("km")) ) + + """ + warn("irreversible dimension change", IrreversibleDimensionChange, stacklevel=2) + return target(x=current.r * xp.sin(current.phi) * xp.cos(current.theta)) + + +@dispatch +def represent_as( + current: MathSphericalVector, target: type[RadialVector], /, **kwargs: Any +) -> RadialVector: + """MathSphericalVector -> RadialVector. + + Examples + -------- + >>> import warnings + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> x = cx.MathSphericalVector(r=Quantity(1.0, "km"), theta=Quantity(10.0, "deg"), + ... phi=Quantity(14, "deg")) + + >>> with warnings.catch_warnings(): + ... warnings.simplefilter("ignore") + ... x2 = cx.represent_as(x, cx.RadialVector) + >>> x2 + RadialVector(r=Distance(value=f32[], unit=Unit("km"))) + + """ + warn("irreversible dimension change", IrreversibleDimensionChange, stacklevel=2) + return target(r=current.r) + + +# ----------------------------------------------- +# 2D + + +@dispatch +def represent_as( + current: MathSphericalVector, target: type[Cartesian2DVector], /, **kwargs: Any +) -> Cartesian2DVector: + """MathSphericalVector -> Cartesian2DVector. + + Examples + -------- + >>> import warnings + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> x = cx.MathSphericalVector(r=Quantity(1.0, "km"), theta=Quantity(10.0, "deg"), + ... phi=Quantity(14, "deg")) + + >>> with warnings.catch_warnings(): + ... warnings.simplefilter("ignore") + ... x2 = cx.represent_as(x, cx.Cartesian2DVector) + >>> x2 + Cartesian2DVector( x=Quantity[...](value=f32[], unit=Unit("km")), + y=Quantity[...](value=f32[], unit=Unit("km")) ) + + """ + warn("irreversible dimension change", IrreversibleDimensionChange, stacklevel=2) + x = current.r * xp.sin(current.phi) * xp.cos(current.theta) + y = current.r * xp.sin(current.phi) * xp.sin(current.theta) + return target(x=x, y=y) + + +@dispatch +def represent_as( + current: MathSphericalVector, target: type[PolarVector], /, **kwargs: Any +) -> PolarVector: + """MathSphericalVector -> PolarVector. + + Examples + -------- + >>> import warnings + >>> from unxt import Quantity + >>> import coordinax as cx + + >>> x = cx.MathSphericalVector(r=Quantity(1.0, "km"), theta=Quantity(10.0, "deg"), + ... phi=Quantity(14, "deg")) + + >>> with warnings.catch_warnings(): + ... warnings.simplefilter("ignore") + ... x2 = cx.represent_as(x, cx.PolarVector) + >>> x2 + PolarVector( r=Distance(value=f32[], unit=Unit("km")), + phi=Quantity[...](value=f32[], unit=Unit("deg")) ) + + """ + warn("irreversible dimension change", IrreversibleDimensionChange, stacklevel=2) + return target(r=current.r * xp.sin(current.phi), phi=current.theta) diff --git a/src/coordinax/_transform/differentials.py b/src/coordinax/_transform/differentials.py new file mode 100644 index 00000000..aee99c92 --- /dev/null +++ b/src/coordinax/_transform/differentials.py @@ -0,0 +1,166 @@ +"""Transformations between representations.""" + +__all__ = ["represent_as"] + +from math import prod +from typing import Any + +import jax +from plum import dispatch + +import quaxed.array_api as xp +from unxt import Quantity + +from coordinax._base_dif import AbstractVectorDifferential +from coordinax._base_vec import AbstractVector +from coordinax._d1.base import Abstract1DVectorDifferential +from coordinax._d2.base import Abstract2DVectorDifferential +from coordinax._d3.base import Abstract3DVectorDifferential +from coordinax._utils import dataclass_items + + +# TODO: implement for cross-representations +@dispatch.multi( # type: ignore[misc] + # N-D -> N-D + ( + Abstract1DVectorDifferential, + type[Abstract1DVectorDifferential], # type: ignore[misc] + AbstractVector | Quantity["length"], + ), + ( + Abstract2DVectorDifferential, + type[Abstract2DVectorDifferential], # type: ignore[misc] + AbstractVector | Quantity["length"], + ), + ( + Abstract3DVectorDifferential, + type[Abstract3DVectorDifferential], # type: ignore[misc] + AbstractVector | Quantity["length"], + ), +) +def represent_as( + current: AbstractVectorDifferential, + target: type[AbstractVectorDifferential], + position: AbstractVector | Quantity["length"], + /, + **kwargs: Any, +) -> AbstractVectorDifferential: + """AbstractVectorDifferential -> Cartesian -> AbstractVectorDifferential. + + This is the base case for the transformation of vector differentials. + + Parameters + ---------- + current : AbstractVectorDifferential + The vector differential to transform. + target : type[AbstractVectorDifferential] + The target type of the vector differential. + position : AbstractVector + The position vector used to transform the differential. + **kwargs : Any + Additional keyword arguments. + + Examples + -------- + >>> import coordinax as cx + >>> from unxt import Quantity + + Let's start in 1D: + + >>> q = cx.Cartesian1DVector(x=Quantity(1.0, "km")) + >>> p = cx.CartesianDifferential1D(d_x=Quantity(1.0, "km/s")) + >>> cx.represent_as(p, cx.RadialDifferential, q) + RadialDifferential( d_r=Quantity[...]( value=f32[], unit=Unit("km / s") ) ) + + Now in 2D: + + >>> q = cx.Cartesian2DVector.constructor(Quantity([1.0, 2.0], "km")) + >>> p = cx.CartesianDifferential2D.constructor(Quantity([1.0, 2.0], "km/s")) + >>> cx.represent_as(p, cx.PolarDifferential, q) + PolarDifferential( + d_r=Quantity[...]( value=f32[], unit=Unit("km / s") ), + d_phi=Quantity[...]( value=f32[], unit=Unit("rad / s") ) + ) + + And in 3D: + + >>> q = cx.Cartesian3DVector.constructor(Quantity([1.0, 2.0, 3.0], "km")) + >>> p = cx.CartesianDifferential3D.constructor(Quantity([1.0, 2.0, 3.0], "km/s")) + >>> cx.represent_as(p, cx.SphericalDifferential, q) + SphericalDifferential( + d_r=Quantity[...]( value=f32[], unit=Unit("km / s") ), + d_phi=Quantity[...]( value=f32[], unit=Unit("rad / s") ), + d_theta=Quantity[...]( value=f32[], unit=Unit("rad / s") ) + ) + + If given a position as a Quantity, it will be converted to the appropriate + Cartesian vector: + + >>> p = cx.CartesianDifferential3D.constructor(Quantity([1.0, 2.0, 3.0], "km/s")) + >>> cx.represent_as(p, cx.SphericalDifferential, Quantity([1.0, 2.0, 3.0], "km")) + SphericalDifferential( + d_r=Quantity[...]( value=f32[], unit=Unit("km / s") ), + d_phi=Quantity[...]( value=f32[], unit=Unit("rad / s") ), + d_theta=Quantity[...]( value=f32[], unit=Unit("rad / s") ) + ) + + """ + # TODO: not require the shape munging / support more shapes + shape = current.shape + flat_shape = prod(shape) + + # Parse the position to an AbstractVector + if isinstance(position, AbstractVector): + posvec = position + else: # Q -> CartD + posvec = current.integral_cls._cartesian_cls.constructor( # noqa: SLF001 + position + ) + + posvec = posvec.reshape(flat_shape) # flattened + + # Start by transforming the position to the type required by the + # differential to construct the Jacobian. + current_pos = represent_as(posvec, current.integral_cls, **kwargs) + + # Takes the Jacobian through the representation transformation function. This + # returns a representation of the target type, where the value of each field the + # corresponding row of the Jacobian. The value of the field is a Quantity with + # the correct numerator unit (of the Jacobian row). The value is a Vector of the + # original type, with fields that are the columns of that row, but with only the + # denomicator's units. + jac_nested_vecs = jac_rep_as(current_pos, target.integral_cls) + + # This changes the Jacobian to be a dictionary of each row, with the value + # being that row's column as a dictionary, now with the correct units for + # each element: {row_i: {col_j: Quantity(value, row.unit / column.unit)}} + jac_rows = { + f"d_{k}": { + kk: Quantity(vv.value, unit=v.unit / vv.unit) + for kk, vv in dataclass_items(v.value) + } + for k, v in dataclass_items(jac_nested_vecs) + } + + # Now we can use the Jacobian to transform the differential. + flat_current = current.reshape(flat_shape) + return target( + **{ # Each field is the dot product of the row of the J and the diff column. + k: xp.sum( # Doing the dot product. + xp.stack( + tuple( + j_c * getattr(flat_current, f"d_{kk}") + for kk, j_c in j_r.items() + ) + ), + axis=0, + ) + for k, j_r in jac_rows.items() + } + ).reshape(shape) + + +# TODO: situate this better to show how represent_as is used +jac_rep_as = jax.jit( + jax.vmap(jax.jacfwd(represent_as), in_axes=(0, None)), static_argnums=(1,) +) diff --git a/tests/test_base.py b/tests/test_base.py index 5438d948..48d3ee01 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -47,8 +47,6 @@ # 2D Cartesian2DVector, PolarVector, - # LnPolarVector, - # Log10PolarVector, # 3D Cartesian3DVector, SphericalVector, diff --git a/tests/test_d1.py b/tests/test_d1.py index deec458a..5a49ec2e 100644 --- a/tests/test_d1.py +++ b/tests/test_d1.py @@ -61,22 +61,6 @@ def test_cartesian1d_to_polar(self, vector): assert qnp.array_equal(polar.r, Quantity([1, 2, 3, 4], "kpc")) assert qnp.array_equal(polar.phi, Quantity([0, 1, 2, 3], "rad")) - # def test_cartesian1d_to_lnpolar(self, vector): - # """Test ``coordinax.represent_as(LnPolarVector)``.""" - # lnpolar = vector.to_lnpolar(phi=Quantity([0, 1, 2, 3], "rad")) - - # assert isinstance(lnpolar, LnPolarVector) - # assert lnpolar.lnr == xp.log(Quantity([1, 2, 3, 4], "kpc")) - # assert qnp.array_equal(lnpolar.phi, Quantity([0, 1, 2, 3], "rad")) - - # def test_cartesian1d_to_log10polar(self, vector): - # """Test ``coordinax.represent_as(Log10PolarVector)``.""" - # log10polar = vector.to_log10polar(phi=Quantity([0, 1, 2, 3], "rad")) - - # assert isinstance(log10polar, Log10PolarVector) - # assert log10polar.log10r == xp.log10(Quantity([1, 2, 3, 4], "kpc")) - # assert qnp.array_equal(log10polar.phi, Quantity([0, 1, 2, 3], "rad")) - def test_cartesian1d_to_cartesian3d(self, vector): """Test ``coordinax.represent_as(Cartesian3DVector)``.""" cart3d = vector.represent_as( @@ -163,14 +147,6 @@ def test_radial_to_polar(self, vector): assert qnp.array_equal(polar.r, Quantity([1, 2, 3, 4], "kpc")) assert qnp.array_equal(polar.phi, Quantity([0, 1, 2, 3], "rad")) - # def test_radial_to_lnpolar(self, vector): - # """Test ``coordinax.represent_as(LnPolarVector)``.""" - # assert False - - # def test_radial_to_log10polar(self, vector): - # """Test ``coordinax.represent_as(Log10PolarVector)``.""" - # assert False - def test_radial_to_cartesian3d(self, vector): """Test ``coordinax.represent_as(Cartesian3DVector)``.""" cart3d = vector.represent_as( diff --git a/tests/test_d2.py b/tests/test_d2.py index db90775f..75bcfd83 100644 --- a/tests/test_d2.py +++ b/tests/test_d2.py @@ -70,14 +70,6 @@ def test_cartesian2d_to_polar(self, vector): atol=Quantity(1e-8, "deg"), ) - # def test_cartesian2d_to_lnpolar(self, vector): - # """Test ``coordinax.represent_as(LnPolarVector)``.""" - # assert False - - # def test_cartesian2d_to_log10polar(self, vector): - # """Test ``coordinax.represent_as(Log10PolarVector)``.""" - # assert False - def test_cartesian2d_to_cartesian3d(self, vector): """Test ``coordinax.represent_as(Cartesian3DVector)``.""" cart3d = vector.represent_as( @@ -185,14 +177,6 @@ def test_polar_to_polar(self, vector): newvec = cx.represent_as(vector, cx.PolarVector) assert newvec is vector - # def test_polar_to_lnpolar(self, vector): - # """Test ``coordinax.represent_as(LnPolarVector)``.""" - # assert False - - # def test_polar_to_log10polar(self, vector): - # """Test ``coordinax.represent_as(Log10PolarVector) - # assert False - def test_polar_to_cartesian3d(self, vector): """Test ``coordinax.represent_as(Cartesian3DVector)``.""" cart3d = vector.represent_as( diff --git a/tests/test_d3.py b/tests/test_d3.py index af84de20..a6f84803 100644 --- a/tests/test_d3.py +++ b/tests/test_d3.py @@ -115,16 +115,6 @@ def test_cartesian3d_to_polar(self, vector): polar.phi, Quantity([1.3734008, 1.2490457, 1.1659045, 1.1071488], "rad") ) - # @pytest.mark.filterwarnings("ignore:Irreversible dimension change") - # def test_cartesian3d_to_lnpolar(self, vector): - # """Test ``coordinax.represent_as(LnPolarVector)``.""" - # assert False - - # @pytest.mark.filterwarnings("ignore:Irreversible dimension change") - # def test_cartesian3d_to_log10polar(self, vector): - # """Test ``coordinax.represent_as(Log10PolarVector)``.""" - # assert False - def test_cartesian3d_to_cartesian3d(self, vector): """Test ``coordinax.represent_as(Cartesian3DVector)``.""" # Jit can copy @@ -191,40 +181,40 @@ def test_cartesian3d_to_cylindrical_astropy(self, vector, apyvector): assert np.allclose(convert(cyl.phi, APYQuantity), apycyl.phi) -class TestSphericalVector(Abstract3DVectorTest): - """Test :class:`coordinax.SphericalVector`.""" +class TestCylindricalVector(Abstract3DVectorTest): + """Test :class:`coordinax.CylindricalVector`.""" @pytest.fixture(scope="class") - def vector(self) -> cx.SphericalVector: + def vector(self) -> cx.AbstractVector: """Return a vector.""" - return cx.SphericalVector( - r=Quantity([1, 2, 3, 4], "kpc"), - phi=Quantity([0, 65, 135, 270], "deg"), - theta=Quantity([0, 36, 142, 180], "deg"), + return cx.CylindricalVector( + rho=Quantity([1, 2, 3, 4], "kpc"), + phi=Quantity([0, 1, 2, 3], "rad"), + z=Quantity([9, 10, 11, 12], "m"), ) @pytest.fixture(scope="class") def apyvector(self, vector: cx.AbstractVector): """Return an Astropy vector.""" - return convert(vector, apyc.PhysicsSphericalRepresentation) + return convert(vector, apyc.CylindricalRepresentation) # ========================================================================== # represent_as @pytest.mark.filterwarnings("ignore:Irreversible dimension change") - def test_spherical_to_cartesian1d(self, vector): + def test_cylindrical_to_cartesian1d(self, vector): """Test ``coordinax.represent_as(Cartesian1DVector)``.""" cart1d = vector.represent_as(cx.Cartesian1DVector) assert isinstance(cart1d, cx.Cartesian1DVector) assert qnp.allclose( cart1d.x, - Quantity([0, 0.49681753, -1.3060151, -4.1700245e-15], "kpc"), + Quantity([1.0, 1.0806047, -1.2484405, -3.95997], "kpc"), atol=Quantity(1e-8, "kpc"), ) @pytest.mark.filterwarnings("ignore:Irreversible dimension change") - def test_spherical_to_radial(self, vector): + def test_cylindrical_to_radial(self, vector): """Test ``coordinax.represent_as(RadialVector)``.""" radial = vector.represent_as(cx.RadialVector) @@ -232,59 +222,41 @@ def test_spherical_to_radial(self, vector): assert qnp.array_equal(radial.r, Quantity([1, 2, 3, 4], "kpc")) @pytest.mark.filterwarnings("ignore:Irreversible dimension change") - def test_spherical_to_cartesian2d(self, vector): + def test_cylindrical_to_cartesian2d(self, vector): """Test ``coordinax.represent_as(Cartesian2DVector)``.""" - cart2d = vector.represent_as( - cx.Cartesian2DVector, y=Quantity([5, 6, 7, 8], "km") - ) + cart2d = vector.represent_as(cx.Cartesian2DVector) assert isinstance(cart2d, cx.Cartesian2DVector) assert qnp.array_equal( - cart2d.x, - Quantity([0, 0.49681753, -1.3060151, -4.1700245e-15], "kpc"), + cart2d.x, Quantity([1.0, 1.0806046, -1.2484405, -3.95997], "kpc") ) assert qnp.array_equal( - cart2d.y, Quantity([0.0, 1.0654287, 1.3060151, 3.4969111e-07], "kpc") + cart2d.y, Quantity([0.0, 1.6829419, 2.7278922, 0.56448], "kpc") ) @pytest.mark.filterwarnings("ignore:Irreversible dimension change") - def test_spherical_to_polar(self, vector): + def test_cylindrical_to_polar(self, vector): """Test ``coordinax.represent_as(PolarVector)``.""" - polar = vector.represent_as(cx.PolarVector, phi=Quantity([0, 1, 2, 3], "rad")) + polar = vector.represent_as(cx.PolarVector) assert isinstance(polar, cx.PolarVector) - assert qnp.array_equal( - polar.r, - Quantity([0.0, 1.1755705, 1.8469844, -3.4969111e-07], "kpc"), - ) - assert qnp.array_equal(polar.phi, Quantity([0.0, 65.0, 135.0, 270.0], "deg")) - - # @pytest.mark.filterwarnings("ignore:Irreversible dimension change") - # def test_spherical_to_lnpolar(self, vector): - # """Test ``coordinax.represent_as(LnPolarVector)``.""" - # assert False - - # @pytest.mark.filterwarnings("ignore:Irreversible dimension change") - # def test_spherical_to_log10polar(self, vector): - # """Test ``coordinax.represent_as(Log10PolarVector)``.""" - # assert False + assert qnp.array_equal(polar.r, Quantity([1, 2, 3, 4], "kpc")) + assert qnp.array_equal(polar.phi, Quantity([0, 1, 2, 3], "rad")) - def test_spherical_to_cartesian3d(self, vector): + def test_cylindrical_to_cartesian3d(self, vector): """Test ``coordinax.represent_as(Cartesian3DVector)``.""" cart3d = vector.represent_as(cx.Cartesian3DVector) assert isinstance(cart3d, cx.Cartesian3DVector) assert qnp.array_equal( - cart3d.x, Quantity([0, 0.49681753, -1.3060151, -4.1700245e-15], "kpc") - ) - assert qnp.array_equal( - cart3d.y, Quantity([0.0, 1.0654287, 1.3060151, 3.4969111e-07], "kpc") + cart3d.x, Quantity([1.0, 1.0806046, -1.2484405, -3.95997], "kpc") ) assert qnp.array_equal( - cart3d.z, Quantity([1.0, 1.618034, -2.3640323, -4.0], "kpc") + cart3d.y, Quantity([0.0, 1.6829419, 2.7278922, 0.56448], "kpc") ) + assert qnp.array_equal(cart3d.z, vector.z) - def test_spherical_to_cartesian3d_astropy(self, vector, apyvector): + def test_cylindrical_to_cartesian3d_astropy(self, vector, apyvector): """Test Astropy equivalence.""" cart3d = vector.represent_as(cx.Cartesian3DVector) @@ -293,94 +265,77 @@ def test_spherical_to_cartesian3d_astropy(self, vector, apyvector): assert np.allclose(convert(cart3d.y, APYQuantity), apycart3.y) assert np.allclose(convert(cart3d.z, APYQuantity), apycart3.z) - def test_spherical_to_spherical(self, vector): + def test_cylindrical_to_spherical(self, vector): """Test ``coordinax.represent_as(SphericalVector)``.""" - # Jit can copy - newvec = vector.represent_as(cx.SphericalVector) - assert newvec == vector + spherical = vector.represent_as(cx.SphericalVector) - # The normal `represent_as` method should return the same object - newvec = cx.represent_as(vector, cx.SphericalVector) - assert newvec is vector + assert isinstance(spherical, cx.SphericalVector) + assert qnp.array_equal(spherical.r, Quantity([1, 2, 3, 4], "kpc")) + assert qnp.array_equal(spherical.phi, Quantity([0, 1, 2, 3], "rad")) + assert qnp.array_equal(spherical.theta, Quantity(xp.full(4, xp.pi / 2), "rad")) - def test_spherical_to_spherical_astropy(self, vector, apyvector): + def test_cylindrical_to_spherical_astropy(self, vector, apyvector): """Test Astropy equivalence.""" sph = vector.represent_as(cx.SphericalVector) - apysph = apyvector.represent_as(apyc.PhysicsSphericalRepresentation) assert np.allclose(convert(sph.r, APYQuantity), apysph.r) assert np.allclose(convert(sph.theta, APYQuantity), apysph.theta) assert np.allclose(convert(sph.phi, APYQuantity), apysph.phi) - def test_spherical_to_cylindrical(self, vector): + def test_cylindrical_to_cylindrical(self, vector): """Test ``coordinax.represent_as(CylindricalVector)``.""" - cylindrical = vector.represent_as( - cx.CylindricalVector, z=Quantity([9, 10, 11, 12], "m") - ) + # Jit can copy + newvec = vector.represent_as(cx.CylindricalVector) + assert newvec == vector - assert isinstance(cylindrical, cx.CylindricalVector) - assert qnp.array_equal( - cylindrical.rho, - Quantity([0.0, 1.1755705, 1.8469844, 3.4969111e-07], "kpc"), - ) - assert qnp.array_equal( - cylindrical.phi, Quantity([0.0, 65.0, 135.0, 270.0], "deg") - ) - assert qnp.array_equal( - cylindrical.z, Quantity([1.0, 1.618034, -2.3640323, -4.0], "kpc") - ) + # The normal `represent_as` method should return the same object + newvec = cx.represent_as(vector, cx.CylindricalVector) + assert newvec is vector - def test_spherical_to_cylindrical_astropy(self, vector, apyvector): - """Test ``coordinax.represent_as(CylindricalVector)``.""" - cyl = vector.represent_as( - cx.CylindricalVector, z=Quantity([9, 10, 11, 12], "m") - ) + def test_cylindrical_to_cylindrical_astropy(self, vector, apyvector): + """Test Astropy equivalence.""" + cyl = vector.represent_as(cx.CylindricalVector) apycyl = apyvector.represent_as(apyc.CylindricalRepresentation) assert np.allclose(convert(cyl.rho, APYQuantity), apycyl.rho) assert np.allclose(convert(cyl.z, APYQuantity), apycyl.z) - - assert np.allclose(convert(cyl.phi[:-1], APYQuantity), apycyl.phi[:-1]) - # There's a 'bug' in Astropy where at the origin phi is always 90, or at - # least doesn't keep its value. - with pytest.raises(AssertionError): # TODO: Fix this - assert np.allclose(convert(cyl.phi[-1], APYQuantity), apycyl.phi[-1]) + assert np.allclose(convert(cyl.phi, APYQuantity), apycyl.phi) -class TestCylindricalVector(Abstract3DVectorTest): - """Test :class:`coordinax.CylindricalVector`.""" +class TestSphericalVector(Abstract3DVectorTest): + """Test :class:`coordinax.SphericalVector`.""" @pytest.fixture(scope="class") - def vector(self) -> cx.AbstractVector: + def vector(self) -> cx.SphericalVector: """Return a vector.""" - return cx.CylindricalVector( - rho=Quantity([1, 2, 3, 4], "kpc"), - phi=Quantity([0, 1, 2, 3], "rad"), - z=Quantity([9, 10, 11, 12], "m"), + return cx.SphericalVector( + r=Quantity([1, 2, 3, 4], "kpc"), + phi=Quantity([0, 65, 135, 270], "deg"), + theta=Quantity([0, 36, 142, 180], "deg"), ) @pytest.fixture(scope="class") def apyvector(self, vector: cx.AbstractVector): """Return an Astropy vector.""" - return convert(vector, apyc.CylindricalRepresentation) + return convert(vector, apyc.PhysicsSphericalRepresentation) # ========================================================================== # represent_as @pytest.mark.filterwarnings("ignore:Irreversible dimension change") - def test_cylindrical_to_cartesian1d(self, vector): + def test_spherical_to_cartesian1d(self, vector): """Test ``coordinax.represent_as(Cartesian1DVector)``.""" cart1d = vector.represent_as(cx.Cartesian1DVector) assert isinstance(cart1d, cx.Cartesian1DVector) assert qnp.allclose( cart1d.x, - Quantity([1.0, 1.0806047, -1.2484405, -3.95997], "kpc"), + Quantity([0, 0.49681753, -1.3060151, -4.1700245e-15], "kpc"), atol=Quantity(1e-8, "kpc"), ) @pytest.mark.filterwarnings("ignore:Irreversible dimension change") - def test_cylindrical_to_radial(self, vector): + def test_spherical_to_radial(self, vector): """Test ``coordinax.represent_as(RadialVector)``.""" radial = vector.represent_as(cx.RadialVector) @@ -388,51 +343,49 @@ def test_cylindrical_to_radial(self, vector): assert qnp.array_equal(radial.r, Quantity([1, 2, 3, 4], "kpc")) @pytest.mark.filterwarnings("ignore:Irreversible dimension change") - def test_cylindrical_to_cartesian2d(self, vector): + def test_spherical_to_cartesian2d(self, vector): """Test ``coordinax.represent_as(Cartesian2DVector)``.""" - cart2d = vector.represent_as(cx.Cartesian2DVector) + cart2d = vector.represent_as( + cx.Cartesian2DVector, y=Quantity([5, 6, 7, 8], "km") + ) assert isinstance(cart2d, cx.Cartesian2DVector) assert qnp.array_equal( - cart2d.x, Quantity([1.0, 1.0806046, -1.2484405, -3.95997], "kpc") + cart2d.x, + Quantity([0, 0.49681753, -1.3060151, -4.1700245e-15], "kpc"), ) assert qnp.array_equal( - cart2d.y, Quantity([0.0, 1.6829419, 2.7278922, 0.56448], "kpc") + cart2d.y, Quantity([0.0, 1.0654287, 1.3060151, 3.4969111e-07], "kpc") ) @pytest.mark.filterwarnings("ignore:Irreversible dimension change") - def test_cylindrical_to_polar(self, vector): + def test_spherical_to_polar(self, vector): """Test ``coordinax.represent_as(PolarVector)``.""" - polar = vector.represent_as(cx.PolarVector) + polar = vector.represent_as(cx.PolarVector, phi=Quantity([0, 1, 2, 3], "rad")) assert isinstance(polar, cx.PolarVector) - assert qnp.array_equal(polar.r, Quantity([1, 2, 3, 4], "kpc")) - assert qnp.array_equal(polar.phi, Quantity([0, 1, 2, 3], "rad")) - - # @pytest.mark.filterwarnings("ignore:Irreversible dimension change") - # def test_cylindrical_to_lnpolar(self, vector): - # """Test ``coordinax.represent_as(LnPolarVector)``.""" - # assert False - - # @pytest.mark.filterwarnings("ignore:Irreversible dimension change") - # def test_cylindrical_to_log10polar(self, vector): - # """Test ``coordinax.represent_as(Log10PolarVector)``.""" - # assert False + assert qnp.array_equal( + polar.r, + Quantity([0.0, 1.1755705, 1.8469844, -3.4969111e-07], "kpc"), + ) + assert qnp.array_equal(polar.phi, Quantity([0.0, 65.0, 135.0, 270.0], "deg")) - def test_cylindrical_to_cartesian3d(self, vector): + def test_spherical_to_cartesian3d(self, vector): """Test ``coordinax.represent_as(Cartesian3DVector)``.""" cart3d = vector.represent_as(cx.Cartesian3DVector) assert isinstance(cart3d, cx.Cartesian3DVector) assert qnp.array_equal( - cart3d.x, Quantity([1.0, 1.0806046, -1.2484405, -3.95997], "kpc") + cart3d.x, Quantity([0, 0.49681753, -1.3060151, -4.1700245e-15], "kpc") ) assert qnp.array_equal( - cart3d.y, Quantity([0.0, 1.6829419, 2.7278922, 0.56448], "kpc") + cart3d.y, Quantity([0.0, 1.0654287, 1.3060151, 3.4969111e-07], "kpc") + ) + assert qnp.array_equal( + cart3d.z, Quantity([1.0, 1.618034, -2.3640323, -4.0], "kpc") ) - assert qnp.array_equal(cart3d.z, vector.z) - def test_cylindrical_to_cartesian3d_astropy(self, vector, apyvector): + def test_spherical_to_cartesian3d_astropy(self, vector, apyvector): """Test Astropy equivalence.""" cart3d = vector.represent_as(cx.Cartesian3DVector) @@ -441,41 +394,83 @@ def test_cylindrical_to_cartesian3d_astropy(self, vector, apyvector): assert np.allclose(convert(cart3d.y, APYQuantity), apycart3.y) assert np.allclose(convert(cart3d.z, APYQuantity), apycart3.z) - def test_cylindrical_to_spherical(self, vector): + def test_spherical_to_cylindrical(self, vector): + """Test ``coordinax.represent_as(CylindricalVector)``.""" + cyl = vector.represent_as( + cx.CylindricalVector, z=Quantity([9, 10, 11, 12], "m") + ) + + assert isinstance(cyl, cx.CylindricalVector) + assert qnp.array_equal( + cyl.rho, + Quantity([0.0, 1.1755705, 1.8469844, 3.4969111e-07], "kpc"), + ) + assert qnp.array_equal(cyl.phi, Quantity([0.0, 65.0, 135.0, 270.0], "deg")) + assert qnp.array_equal( + cyl.z, Quantity([1.0, 1.618034, -2.3640323, -4.0], "kpc") + ) + + def test_spherical_to_cylindrical_astropy(self, vector, apyvector): + """Test ``coordinax.represent_as(CylindricalVector)``.""" + cyl = vector.represent_as( + cx.CylindricalVector, z=Quantity([9, 10, 11, 12], "m") + ) + + apycyl = apyvector.represent_as(apyc.CylindricalRepresentation) + assert np.allclose(convert(cyl.rho, APYQuantity), apycyl.rho) + assert np.allclose(convert(cyl.z, APYQuantity), apycyl.z) + + assert np.allclose(convert(cyl.phi[:-1], APYQuantity), apycyl.phi[:-1]) + # There's a 'bug' in Astropy where at the origin phi is always 90, or at + # least doesn't keep its value. + with pytest.raises(AssertionError): # TODO: Fix this + assert np.allclose(convert(cyl.phi[-1], APYQuantity), apycyl.phi[-1]) + + def test_spherical_to_spherical(self, vector): """Test ``coordinax.represent_as(SphericalVector)``.""" - spherical = vector.represent_as(cx.SphericalVector) + # Jit can copy + newvec = vector.represent_as(cx.SphericalVector) + assert newvec == vector - assert isinstance(spherical, cx.SphericalVector) - assert qnp.array_equal(spherical.r, Quantity([1, 2, 3, 4], "kpc")) - assert qnp.array_equal(spherical.phi, Quantity([0, 1, 2, 3], "rad")) - assert qnp.array_equal(spherical.theta, Quantity(xp.full(4, xp.pi / 2), "rad")) + # The normal `represent_as` method should return the same object + newvec = cx.represent_as(vector, cx.SphericalVector) + assert newvec is vector - def test_cylindrical_to_spherical_astropy(self, vector, apyvector): + def test_spherical_to_spherical_astropy(self, vector, apyvector): """Test Astropy equivalence.""" sph = vector.represent_as(cx.SphericalVector) + apysph = apyvector.represent_as(apyc.PhysicsSphericalRepresentation) assert np.allclose(convert(sph.r, APYQuantity), apysph.r) assert np.allclose(convert(sph.theta, APYQuantity), apysph.theta) assert np.allclose(convert(sph.phi, APYQuantity), apysph.phi) - def test_cylindrical_to_cylindrical(self, vector): - """Test ``coordinax.represent_as(CylindricalVector)``.""" - # Jit can copy - newvec = vector.represent_as(cx.CylindricalVector) - assert newvec == vector + def test_spherical_to_mathspherical(self, vector): + """Test ``coordinax.represent_as(MathSphericalVector)``.""" + newvec = cx.represent_as(vector, cx.MathSphericalVector) + assert qnp.array_equal(newvec.r, vector.r) + assert qnp.array_equal(newvec.phi, vector.theta) + assert qnp.array_equal(newvec.theta, vector.phi) - # The normal `represent_as` method should return the same object - newvec = cx.represent_as(vector, cx.CylindricalVector) - assert newvec is vector + def test_spherical_to_lonlatspherical(self, vector): + """Test ``coordinax.represent_as(LonLatSphericalVector)``.""" + llsph = vector.represent_as( + cx.LonLatSphericalVector, z=Quantity([9, 10, 11, 12], "m") + ) - def test_cylindrical_to_cylindrical_astropy(self, vector, apyvector): + assert isinstance(llsph, cx.LonLatSphericalVector) + assert qnp.array_equal(llsph.distance, vector.r) + assert qnp.array_equal(llsph.lon, vector.phi) + assert qnp.array_equal(llsph.lat, Quantity(90, "deg") - vector.theta) + + def test_spherical_to_lonlatspherical_astropy(self, vector, apyvector): """Test Astropy equivalence.""" - cyl = vector.represent_as(cx.CylindricalVector) + llsph = vector.represent_as(cx.LonLatSphericalVector) - apycyl = apyvector.represent_as(apyc.CylindricalRepresentation) - assert np.allclose(convert(cyl.rho, APYQuantity), apycyl.rho) - assert np.allclose(convert(cyl.z, APYQuantity), apycyl.z) - assert np.allclose(convert(cyl.phi, APYQuantity), apycyl.phi) + apycart3 = apyvector.represent_as(apyc.SphericalRepresentation) + assert np.allclose(convert(llsph.distance, APYQuantity), apycart3.distance) + assert np.allclose(convert(llsph.lon, APYQuantity), apycart3.lon) + assert np.allclose(convert(llsph.lat, APYQuantity), apycart3.lat) class Abstract3DVectorDifferentialTest(AbstractVectorDifferentialTest): @@ -650,46 +645,42 @@ def test_cartesian3d_to_cylindrical_astropy( assert np.allclose(convert(cyl.d_phi, APYQuantity), apycyl.d_phi) -class TestSphericalDifferential(Abstract3DVectorDifferentialTest): - """Test :class:`coordinax.SphericalDifferential`.""" +class TestCylindricalDifferential(Abstract3DVectorDifferentialTest): + """Test :class:`coordinax.CylindricalDifferential`.""" @pytest.fixture(scope="class") - def difntl(self) -> cx.SphericalDifferential: + def difntl(self) -> cx.CylindricalDifferential: """Return a differential.""" - return cx.SphericalDifferential( - d_r=Quantity([5, 6, 7, 8], "km/s"), + return cx.CylindricalDifferential( + d_rho=Quantity([5, 6, 7, 8], "km/s"), d_phi=Quantity([9, 10, 11, 12], "mas/yr"), - d_theta=Quantity([13, 14, 15, 16], "mas/yr"), + d_z=Quantity([13, 14, 15, 16], "km/s"), ) @pytest.fixture(scope="class") - def vector(self) -> cx.SphericalVector: + def vector(self) -> cx.CylindricalVector: """Return a vector.""" - return cx.SphericalVector( - r=Quantity([1, 2, 3, 4], "kpc"), - phi=Quantity([0, 42, 160, 270], "deg"), - theta=Quantity([3, 63, 90, 179.5], "deg"), + return cx.CylindricalVector( + rho=Quantity([1, 2, 3, 4], "kpc"), + phi=Quantity([0, 1, 2, 3], "rad"), + z=Quantity([9, 10, 11, 12], "kpc"), ) @pytest.fixture(scope="class") - def apydifntl( - self, difntl: cx.SphericalDifferential - ) -> apyc.PhysicsSphericalDifferential: + def apydifntl(self, difntl: cx.CylindricalDifferential): """Return an Astropy differential.""" - return convert(difntl, apyc.PhysicsSphericalDifferential) + return convert(difntl, apyc.CylindricalDifferential) @pytest.fixture(scope="class") - def apyvector( - self, vector: cx.SphericalVector - ) -> apyc.PhysicsSphericalRepresentation: + def apyvector(self, vector: cx.CylindricalVector) -> apyc.CylindricalRepresentation: """Return an Astropy vector.""" - return convert(vector, apyc.PhysicsSphericalRepresentation) + return convert(vector, apyc.CylindricalRepresentation) # ========================================================================== @pytest.mark.xfail(reason="Not implemented") @pytest.mark.filterwarnings("ignore:Explicitly requested dtype") - def test_spherical_to_cartesian1d(self, difntl, vector): + def test_cylindrical_to_cartesian1d(self, difntl, vector): """Test ``coordinax.represent_as(Cartesian1DVector)``.""" cart1d = difntl.represent_as(cx.CartesianDifferential1D, vector) @@ -698,7 +689,7 @@ def test_spherical_to_cartesian1d(self, difntl, vector): @pytest.mark.xfail(reason="Not implemented") @pytest.mark.filterwarnings("ignore:Explicitly requested dtype") - def test_spherical_to_radial(self, difntl, vector): + def test_cylindrical_to_radial(self, difntl, vector): """Test ``coordinax.represent_as(RadialVector)``.""" radial = difntl.represent_as(cx.RadialVector, vector) @@ -707,7 +698,7 @@ def test_spherical_to_radial(self, difntl, vector): @pytest.mark.xfail(reason="Not implemented") @pytest.mark.filterwarnings("ignore:Explicitly requested dtype") - def test_spherical_to_cartesian2d(self, difntl, vector): + def test_cylindrical_to_cartesian2d(self, difntl, vector): """Test ``coordinax.represent_as(Cartesian2DVector)``.""" cart2d = difntl.represent_as(cx.CartesianDifferential2D, vector) @@ -717,7 +708,7 @@ def test_spherical_to_cartesian2d(self, difntl, vector): @pytest.mark.xfail(reason="Not implemented") @pytest.mark.filterwarnings("ignore:Explicitly requested dtype") - def test_spherical_to_polar(self, difntl, vector): + def test_cylindrical_to_polar(self, difntl, vector): """Test ``coordinax.represent_as(PolarVector)``.""" polar = difntl.represent_as(cx.PolarVector, vector) @@ -725,49 +716,49 @@ def test_spherical_to_polar(self, difntl, vector): assert qnp.array_equal(polar.d_r, Quantity([1, 2, 3, 4], "km/s")) assert qnp.array_equal(polar.d_phi, Quantity([5, 6, 7, 8], "mas/yr")) - def test_spherical_to_cartesian3d(self, difntl, vector): + def test_cylindrical_to_cartesian3d(self, difntl, vector, apydifntl, apyvector): """Test ``coordinax.represent_as(Cartesian3DVector)``.""" cart3d = difntl.represent_as(cx.CartesianDifferential3D, vector) assert isinstance(cart3d, cx.CartesianDifferential3D) - assert qnp.allclose( - cart3d.d_x, - Quantity([61.803337, -7.770853, -60.081947, 1.985678], "km/s"), - atol=Quantity(1e-8, "km/s"), + assert qnp.array_equal( + cart3d.d_x, Quantity([5.0, -76.537544, -145.15944, -40.03075], "km/s") ) - assert qnp.allclose( + assert qnp.array_equal( cart3d.d_y, - Quantity([2.2328734, 106.6765, -144.60716, 303.30875], "km/s"), - atol=Quantity(1e-8, "km/s"), - ) - assert qnp.allclose( - cart3d.d_z, - Quantity([1.7678856, -115.542175, -213.32118, -10.647271], "km/s"), - atol=Quantity(1e-8, "km/s"), + Quantity([42.664234, 56.274563, -58.73506, -224.13647], "km/s"), ) - - def test_spherical_to_cartesian3d_astropy( - self, difntl, vector, apydifntl, apyvector - ): - """Test Astropy equivalence.""" - cart3d = difntl.represent_as(cx.CartesianDifferential3D, vector) + assert qnp.array_equal(cart3d.d_z, Quantity([13.0, 14.0, 15.0, 16.0], "km/s")) apycart3 = apydifntl.represent_as(apyc.CartesianDifferential, apyvector) assert np.allclose(convert(cart3d.d_x, APYQuantity), apycart3.d_x) assert np.allclose(convert(cart3d.d_y, APYQuantity), apycart3.d_y) assert np.allclose(convert(cart3d.d_z, APYQuantity), apycart3.d_z) - def test_spherical_to_spherical(self, difntl, vector): + def test_cylindrical_to_spherical(self, difntl, vector): """Test ``coordinax.represent_as(SphericalDifferential)``.""" - # Jit can copy - newvec = difntl.represent_as(cx.SphericalDifferential, vector) - assert newvec == difntl + dsph = difntl.represent_as(cx.SphericalDifferential, vector) - # The normal `represent_as` method should return the same object - newvec = cx.represent_as(difntl, cx.SphericalDifferential, vector) - assert newvec is difntl + assert isinstance(dsph, cx.SphericalDifferential) + assert qnp.array_equal( + dsph.d_r, + Quantity([13.472646, 14.904826, 16.313278, 17.708754], "km/s"), + ) + assert qnp.array_equal( + dsph.d_phi, + Quantity([42.664234, 47.404705, 52.145176, 56.885643], "km rad / (kpc s)"), + ) + assert qnp.allclose( + dsph.d_theta, + Quantity( + [0.3902412, 0.30769292, 0.24615361, 0.19999981], "km rad / (kpc s)" + ), + atol=Quantity(5e-7, "km rad / (kpc s)"), + ) - def test_spherical_to_spherical_astropy(self, difntl, vector, apydifntl, apyvector): + def test_cylindrical_to_spherical_astropy( + self, difntl, vector, apydifntl, apyvector + ): """Test Astropy equivalence.""" sph = difntl.represent_as(cx.SphericalDifferential, vector) apysph = apydifntl.represent_as(apyc.PhysicsSphericalDifferential, apyvector) @@ -775,30 +766,17 @@ def test_spherical_to_spherical_astropy(self, difntl, vector, apydifntl, apyvect assert np.allclose(convert(sph.d_theta, APYQuantity), apysph.d_theta) assert np.allclose(convert(sph.d_phi, APYQuantity), apysph.d_phi) - def test_spherical_to_cylindrical(self, difntl, vector): + def test_cylindrical_to_cylindrical(self, difntl, vector): """Test ``coordinax.represent_as(CylindricalDifferential)``.""" - cylindrical = difntl.represent_as(cx.CylindricalDifferential, vector) + # Jit can copy + newvec = difntl.represent_as(cx.CylindricalDifferential, vector) + assert newvec == difntl - assert isinstance(cylindrical, cx.CylindricalDifferential) - assert qnp.allclose( - cylindrical.d_rho, - Quantity([61.803337, 65.60564, 6.9999905, -303.30875], "km/s"), - atol=Quantity(1e-8, "km/s"), - ) - assert qnp.allclose( - cylindrical.d_phi, - Quantity([2444.4805, 2716.0894, 2987.6985, 3259.3074], "deg km / (kpc s)"), - atol=Quantity(1e-8, "mas/yr"), - ) - assert qnp.allclose( - cylindrical.d_z, - Quantity([1.7678856, -115.542175, -213.32118, -10.647271], "km/s"), - atol=Quantity(1e-8, "km/s"), - ) + # The normal `represent_as` method should return the same object + newvec = cx.represent_as(difntl, cx.CylindricalDifferential, vector) + assert newvec is difntl - def test_spherical_to_cylindrical_astropy( - self, difntl, vector, apydifntl, apyvector - ): + def test_cylindrical_to_cylindrical(self, difntl, vector, apydifntl, apyvector): """Test Astropy equivalence.""" cyl = difntl.represent_as(cx.CylindricalDifferential, vector) apycyl = apydifntl.represent_as(apyc.CylindricalDifferential, apyvector) @@ -807,42 +785,46 @@ def test_spherical_to_cylindrical_astropy( assert np.allclose(convert(cyl.d_phi, APYQuantity), apycyl.d_phi) -class TestCylindricalDifferential(Abstract3DVectorDifferentialTest): - """Test :class:`coordinax.CylindricalDifferential`.""" +class TestSphericalDifferential(Abstract3DVectorDifferentialTest): + """Test :class:`coordinax.SphericalDifferential`.""" @pytest.fixture(scope="class") - def difntl(self) -> cx.CylindricalDifferential: + def difntl(self) -> cx.SphericalDifferential: """Return a differential.""" - return cx.CylindricalDifferential( - d_rho=Quantity([5, 6, 7, 8], "km/s"), + return cx.SphericalDifferential( + d_r=Quantity([5, 6, 7, 8], "km/s"), d_phi=Quantity([9, 10, 11, 12], "mas/yr"), - d_z=Quantity([13, 14, 15, 16], "km/s"), + d_theta=Quantity([13, 14, 15, 16], "mas/yr"), ) @pytest.fixture(scope="class") - def vector(self) -> cx.CylindricalVector: + def vector(self) -> cx.SphericalVector: """Return a vector.""" - return cx.CylindricalVector( - rho=Quantity([1, 2, 3, 4], "kpc"), - phi=Quantity([0, 1, 2, 3], "rad"), - z=Quantity([9, 10, 11, 12], "kpc"), + return cx.SphericalVector( + r=Quantity([1, 2, 3, 4], "kpc"), + phi=Quantity([0, 42, 160, 270], "deg"), + theta=Quantity([3, 63, 90, 179.5], "deg"), ) @pytest.fixture(scope="class") - def apydifntl(self, difntl: cx.CylindricalDifferential): + def apydifntl( + self, difntl: cx.SphericalDifferential + ) -> apyc.PhysicsSphericalDifferential: """Return an Astropy differential.""" - return convert(difntl, apyc.CylindricalDifferential) + return convert(difntl, apyc.PhysicsSphericalDifferential) @pytest.fixture(scope="class") - def apyvector(self, vector: cx.CylindricalVector) -> apyc.CylindricalRepresentation: + def apyvector( + self, vector: cx.SphericalVector + ) -> apyc.PhysicsSphericalRepresentation: """Return an Astropy vector.""" - return convert(vector, apyc.CylindricalRepresentation) + return convert(vector, apyc.PhysicsSphericalRepresentation) # ========================================================================== @pytest.mark.xfail(reason="Not implemented") @pytest.mark.filterwarnings("ignore:Explicitly requested dtype") - def test_cylindrical_to_cartesian1d(self, difntl, vector): + def test_spherical_to_cartesian1d(self, difntl, vector): """Test ``coordinax.represent_as(Cartesian1DVector)``.""" cart1d = difntl.represent_as(cx.CartesianDifferential1D, vector) @@ -851,7 +833,7 @@ def test_cylindrical_to_cartesian1d(self, difntl, vector): @pytest.mark.xfail(reason="Not implemented") @pytest.mark.filterwarnings("ignore:Explicitly requested dtype") - def test_cylindrical_to_radial(self, difntl, vector): + def test_spherical_to_radial(self, difntl, vector): """Test ``coordinax.represent_as(RadialVector)``.""" radial = difntl.represent_as(cx.RadialVector, vector) @@ -860,7 +842,7 @@ def test_cylindrical_to_radial(self, difntl, vector): @pytest.mark.xfail(reason="Not implemented") @pytest.mark.filterwarnings("ignore:Explicitly requested dtype") - def test_cylindrical_to_cartesian2d(self, difntl, vector): + def test_spherical_to_cartesian2d(self, difntl, vector): """Test ``coordinax.represent_as(Cartesian2DVector)``.""" cart2d = difntl.represent_as(cx.CartesianDifferential2D, vector) @@ -870,7 +852,7 @@ def test_cylindrical_to_cartesian2d(self, difntl, vector): @pytest.mark.xfail(reason="Not implemented") @pytest.mark.filterwarnings("ignore:Explicitly requested dtype") - def test_cylindrical_to_polar(self, difntl, vector): + def test_spherical_to_polar(self, difntl, vector): """Test ``coordinax.represent_as(PolarVector)``.""" polar = difntl.represent_as(cx.PolarVector, vector) @@ -878,49 +860,80 @@ def test_cylindrical_to_polar(self, difntl, vector): assert qnp.array_equal(polar.d_r, Quantity([1, 2, 3, 4], "km/s")) assert qnp.array_equal(polar.d_phi, Quantity([5, 6, 7, 8], "mas/yr")) - def test_cylindrical_to_cartesian3d(self, difntl, vector, apydifntl, apyvector): + def test_spherical_to_cartesian3d(self, difntl, vector): """Test ``coordinax.represent_as(Cartesian3DVector)``.""" cart3d = difntl.represent_as(cx.CartesianDifferential3D, vector) assert isinstance(cart3d, cx.CartesianDifferential3D) - assert qnp.array_equal( - cart3d.d_x, Quantity([5.0, -76.537544, -145.15944, -40.03075], "km/s") + assert qnp.allclose( + cart3d.d_x, + Quantity([61.803337, -7.770853, -60.081947, 1.985678], "km/s"), + atol=Quantity(1e-8, "km/s"), ) - assert qnp.array_equal( + assert qnp.allclose( cart3d.d_y, - Quantity([42.664234, 56.274563, -58.73506, -224.13647], "km/s"), + Quantity([2.2328734, 106.6765, -144.60716, 303.30875], "km/s"), + atol=Quantity(1e-8, "km/s"), ) - assert qnp.array_equal(cart3d.d_z, Quantity([13.0, 14.0, 15.0, 16.0], "km/s")) + assert qnp.allclose( + cart3d.d_z, + Quantity([1.7678856, -115.542175, -213.32118, -10.647271], "km/s"), + atol=Quantity(1e-8, "km/s"), + ) + + def test_spherical_to_cartesian3d_astropy( + self, difntl, vector, apydifntl, apyvector + ): + """Test Astropy equivalence.""" + cart3d = difntl.represent_as(cx.CartesianDifferential3D, vector) apycart3 = apydifntl.represent_as(apyc.CartesianDifferential, apyvector) assert np.allclose(convert(cart3d.d_x, APYQuantity), apycart3.d_x) assert np.allclose(convert(cart3d.d_y, APYQuantity), apycart3.d_y) assert np.allclose(convert(cart3d.d_z, APYQuantity), apycart3.d_z) - def test_cylindrical_to_spherical(self, difntl, vector): - """Test ``coordinax.represent_as(SphericalDifferential)``.""" - dsph = difntl.represent_as(cx.SphericalDifferential, vector) + def test_spherical_to_cylindrical(self, difntl, vector): + """Test ``coordinax.represent_as(CylindricalDifferential)``.""" + cylindrical = difntl.represent_as(cx.CylindricalDifferential, vector) - assert isinstance(dsph, cx.SphericalDifferential) - assert qnp.array_equal( - dsph.d_r, - Quantity([13.472646, 14.904826, 16.313278, 17.708754], "km/s"), + assert isinstance(cylindrical, cx.CylindricalDifferential) + assert qnp.allclose( + cylindrical.d_rho, + Quantity([61.803337, 65.60564, 6.9999905, -303.30875], "km/s"), + atol=Quantity(1e-8, "km/s"), ) - assert qnp.array_equal( - dsph.d_phi, - Quantity([42.664234, 47.404705, 52.145176, 56.885643], "km rad / (kpc s)"), + assert qnp.allclose( + cylindrical.d_phi, + Quantity([2444.4805, 2716.0894, 2987.6985, 3259.3074], "deg km / (kpc s)"), + atol=Quantity(1e-8, "mas/yr"), ) assert qnp.allclose( - dsph.d_theta, - Quantity( - [0.3902412, 0.30769292, 0.24615361, 0.19999981], "km rad / (kpc s)" - ), - atol=Quantity(5e-7, "km rad / (kpc s)"), + cylindrical.d_z, + Quantity([1.7678856, -115.542175, -213.32118, -10.647271], "km/s"), + atol=Quantity(1e-8, "km/s"), ) - def test_cylindrical_to_spherical_astropy( + def test_spherical_to_cylindrical_astropy( self, difntl, vector, apydifntl, apyvector ): + """Test Astropy equivalence.""" + cyl = difntl.represent_as(cx.CylindricalDifferential, vector) + apycyl = apydifntl.represent_as(apyc.CylindricalDifferential, apyvector) + assert np.allclose(convert(cyl.d_rho, APYQuantity), apycyl.d_rho) + assert np.allclose(convert(cyl.d_z, APYQuantity), apycyl.d_z) + assert np.allclose(convert(cyl.d_phi, APYQuantity), apycyl.d_phi) + + def test_spherical_to_spherical(self, difntl, vector): + """Test ``coordinax.represent_as(SphericalDifferential)``.""" + # Jit can copy + newvec = difntl.represent_as(cx.SphericalDifferential, vector) + assert newvec == difntl + + # The normal `represent_as` method should return the same object + newvec = cx.represent_as(difntl, cx.SphericalDifferential, vector) + assert newvec is difntl + + def test_spherical_to_spherical_astropy(self, difntl, vector, apydifntl, apyvector): """Test Astropy equivalence.""" sph = difntl.represent_as(cx.SphericalDifferential, vector) apysph = apydifntl.represent_as(apyc.PhysicsSphericalDifferential, apyvector) @@ -928,20 +941,35 @@ def test_cylindrical_to_spherical_astropy( assert np.allclose(convert(sph.d_theta, APYQuantity), apysph.d_theta) assert np.allclose(convert(sph.d_phi, APYQuantity), apysph.d_phi) - def test_cylindrical_to_cylindrical(self, difntl, vector): - """Test ``coordinax.represent_as(CylindricalDifferential)``.""" - # Jit can copy - newvec = difntl.represent_as(cx.CylindricalDifferential, vector) - assert newvec == difntl + def test_spherical_to_lonlatspherical(self, difntl, vector): + """Test ``coordinax.represent_as(LonLatSphericalDifferential)``.""" + llsph = difntl.represent_as(cx.LonLatSphericalDifferential, vector) - # The normal `represent_as` method should return the same object - newvec = cx.represent_as(difntl, cx.CylindricalDifferential, vector) - assert newvec is difntl + assert isinstance(llsph, cx.LonLatSphericalDifferential) + assert qnp.array_equal(llsph.d_distance, difntl.d_r) + assert qnp.array_equal(llsph.d_lon, difntl.d_phi) + assert qnp.allclose( + llsph.d_lat, + Quantity([-13.0, -14.0, -15.0, -16.0], "mas/yr"), + atol=Quantity(1e-8, "mas/yr"), + ) - def test_cylindrical_to_cylindrical(self, difntl, vector, apydifntl, apyvector): + def test_spherical_to_lonlatspherical_astropy( + self, difntl, vector, apydifntl, apyvector + ): """Test Astropy equivalence.""" - cyl = difntl.represent_as(cx.CylindricalDifferential, vector) - apycyl = apydifntl.represent_as(apyc.CylindricalDifferential, apyvector) - assert np.allclose(convert(cyl.d_rho, APYQuantity), apycyl.d_rho) - assert np.allclose(convert(cyl.d_z, APYQuantity), apycyl.d_z) - assert np.allclose(convert(cyl.d_phi, APYQuantity), apycyl.d_phi) + cart3d = difntl.represent_as(cx.LonLatSphericalDifferential, vector) + + apycart3 = apydifntl.represent_as(apyc.SphericalDifferential, apyvector) + assert np.allclose(convert(cart3d.d_distance, APYQuantity), apycart3.d_distance) + assert np.allclose(convert(cart3d.d_lon, APYQuantity), apycart3.d_lon) + assert np.allclose(convert(cart3d.d_lat, APYQuantity), apycart3.d_lat) + + def test_spherical_to_mathspherical(self, difntl, vector): + """Test ``coordinax.represent_as(MathSpherical)``.""" + llsph = difntl.represent_as(cx.MathSphericalDifferential, vector) + + assert isinstance(llsph, cx.MathSphericalDifferential) + assert qnp.array_equal(llsph.d_r, difntl.d_r) + assert qnp.array_equal(llsph.d_theta, difntl.d_phi) + assert qnp.array_equal(llsph.d_phi, difntl.d_theta)