Skip to content

Commit

Permalink
refactor-convert_to (#31)
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman authored Feb 25, 2024
1 parent a1334a6 commit 014cbce
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 73 deletions.
38 changes: 19 additions & 19 deletions src/vector/_d3/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import astropy.units as apyu
from jax_quantity import Quantity
from jaxtyping import Shaped
from plum import conversion_method
from plum import conversion_method, convert

from vector._utils import dataclass_values, full_shaped

Expand Down Expand Up @@ -53,9 +53,9 @@ def vec_diff_to_q(
def cart3_to_apycart3(obj: Cartesian3DVector, /) -> apyc.CartesianRepresentation:
"""`vector.Cartesian3DVector` -> `astropy.CartesianRepresentation`."""
return apyc.CartesianRepresentation(
x=obj.x.as_type(apyu.Quantity),
y=obj.y.as_type(apyu.Quantity),
z=obj.z.as_type(apyu.Quantity),
x=convert(obj.x, apyu.Quantity),
y=convert(obj.y, apyu.Quantity),
z=convert(obj.z, apyu.Quantity),
)


Expand All @@ -76,9 +76,9 @@ def apycart3_to_cart3(obj: apyc.CartesianRepresentation, /) -> Cartesian3DVector
def sph_to_apysph(obj: SphericalVector, /) -> apyc.PhysicsSphericalRepresentation:
"""`vector.SphericalVector` -> `astropy.PhysicsSphericalRepresentation`."""
return apyc.PhysicsSphericalRepresentation(
r=obj.r.as_type(apyu.Quantity),
phi=obj.phi.as_type(apyu.Quantity),
theta=obj.theta.as_type(apyu.Quantity),
r=convert(obj.r, apyu.Quantity),
phi=convert(obj.phi, apyu.Quantity),
theta=convert(obj.theta, apyu.Quantity),
)


Expand All @@ -99,9 +99,9 @@ def apysph_to_sph(obj: apyc.PhysicsSphericalRepresentation, /) -> SphericalVecto
def cyl_to_apycyl(obj: CylindricalVector, /) -> apyc.CylindricalRepresentation:
"""`vector.CylindricalVector` -> `astropy.CylindricalRepresentation`."""
return apyc.CylindricalRepresentation(
rho=obj.rho.as_type(apyu.Quantity),
phi=obj.phi.as_type(apyu.Quantity),
z=obj.z.as_type(apyu.Quantity),
rho=convert(obj.rho, apyu.Quantity),
phi=convert(obj.phi, apyu.Quantity),
z=convert(obj.z, apyu.Quantity),
)


Expand All @@ -123,9 +123,9 @@ def diffcart3_to_apycart3(
) -> apyc.CartesianDifferential:
"""`vector.CartesianDifferential3D` -> `astropy.CartesianDifferential`."""
return apyc.CartesianDifferential(
d_x=obj.d_x.as_type(apyu.Quantity),
d_y=obj.d_y.as_type(apyu.Quantity),
d_z=obj.d_z.as_type(apyu.Quantity),
d_x=convert(obj.d_x, apyu.Quantity),
d_y=convert(obj.d_y, apyu.Quantity),
d_z=convert(obj.d_z, apyu.Quantity),
)


Expand All @@ -152,9 +152,9 @@ def diffsph_to_apysph(
) -> apyc.PhysicsSphericalDifferential:
"""`vector.SphericalDifferential` -> `astropy.PhysicsSphericalDifferential`."""
return apyc.PhysicsSphericalDifferential(
d_r=obj.d_r.as_type(apyu.Quantity),
d_phi=obj.d_phi.as_type(apyu.Quantity),
d_theta=obj.d_theta.as_type(apyu.Quantity),
d_r=convert(obj.d_r, apyu.Quantity),
d_phi=convert(obj.d_phi, apyu.Quantity),
d_theta=convert(obj.d_theta, apyu.Quantity),
)


Expand All @@ -179,9 +179,9 @@ def apysph_to_diffsph(
def diffcyl_to_apycyl(obj: CylindricalDifferential, /) -> apyc.CylindricalDifferential:
"""`vector.CylindricalDifferential` -> `astropy.CylindricalDifferential`."""
return apyc.CylindricalDifferential(
d_rho=obj.d_rho.as_type(apyu.Quantity),
d_phi=obj.d_phi.as_type(apyu.Quantity),
d_z=obj.d_z.as_type(apyu.Quantity),
d_rho=convert(obj.d_rho, apyu.Quantity),
d_phi=convert(obj.d_phi, apyu.Quantity),
d_z=convert(obj.d_z, apyu.Quantity),
)


Expand Down
108 changes: 54 additions & 54 deletions tests/test_d3.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,9 @@ def test_cartesian3d_to_cartesian3d_astropy(self, vector, apyvector):
"""Test Astropy equivalence."""
newvec = vector.represent_as(Cartesian3DVector)

assert np.allclose(newvec.x.as_type(u.Quantity), apyvector.x)
assert np.allclose(newvec.y.as_type(u.Quantity), apyvector.y)
assert np.allclose(newvec.z.as_type(u.Quantity), apyvector.z)
assert np.allclose(convert(newvec.x, u.Quantity), apyvector.x)
assert np.allclose(convert(newvec.y, u.Quantity), apyvector.y)
assert np.allclose(convert(newvec.z, u.Quantity), apyvector.z)

def test_cartesian3d_to_spherical(self, vector):
"""Test ``vector.represent_as(SphericalVector)``."""
Expand All @@ -145,9 +145,9 @@ def test_cartesian3d_to_spherical_astropy(self, vector, apyvector):
sph = vector.represent_as(SphericalVector)

apysph = apyvector.represent_as(apyc.PhysicsSphericalRepresentation)
assert np.allclose(sph.r.as_type(u.Quantity), apysph.r)
assert np.allclose(sph.theta.as_type(u.Quantity), apysph.theta)
assert np.allclose(sph.phi.as_type(u.Quantity), apysph.phi)
assert np.allclose(convert(sph.r, u.Quantity), apysph.r)
assert np.allclose(convert(sph.theta, u.Quantity), apysph.theta)
assert np.allclose(convert(sph.phi, u.Quantity), apysph.phi)

def test_cartesian3d_to_cylindrical(self, vector):
"""Test ``vector.represent_as(CylindricalVector)``."""
Expand All @@ -166,9 +166,9 @@ def test_cartesian3d_to_cylindrical_astropy(self, vector, apyvector):
cyl = vector.represent_as(CylindricalVector)

apycyl = apyvector.represent_as(apyc.CylindricalRepresentation)
assert np.allclose(cyl.rho.as_type(u.Quantity), apycyl.rho)
assert np.allclose(cyl.z.as_type(u.Quantity), apycyl.z)
assert np.allclose(cyl.phi.as_type(u.Quantity), apycyl.phi)
assert np.allclose(convert(cyl.rho, u.Quantity), apycyl.rho)
assert np.allclose(convert(cyl.z, u.Quantity), apycyl.z)
assert np.allclose(convert(cyl.phi, u.Quantity), apycyl.phi)


class TestSphericalVector:
Expand Down Expand Up @@ -264,9 +264,9 @@ def test_spherical_to_cartesian3d_astropy(self, vector, apyvector):
cart3d = vector.represent_as(Cartesian3DVector)

apycart3 = apyvector.represent_as(apyc.CartesianRepresentation)
assert np.allclose(cart3d.x.as_type(u.Quantity), apycart3.x)
assert np.allclose(cart3d.y.as_type(u.Quantity), apycart3.y)
assert np.allclose(cart3d.z.as_type(u.Quantity), apycart3.z)
assert np.allclose(convert(cart3d.x, u.Quantity), apycart3.x)
assert np.allclose(convert(cart3d.y, u.Quantity), apycart3.y)
assert np.allclose(convert(cart3d.z, u.Quantity), apycart3.z)

def test_spherical_to_spherical(self, vector):
"""Test ``vector.represent_as(SphericalVector)``."""
Expand All @@ -283,9 +283,9 @@ def test_spherical_to_spherical_astropy(self, vector, apyvector):
sph = vector.represent_as(SphericalVector)

apysph = apyvector.represent_as(apyc.PhysicsSphericalRepresentation)
assert np.allclose(sph.r.as_type(u.Quantity), apysph.r)
assert np.allclose(sph.theta.as_type(u.Quantity), apysph.theta)
assert np.allclose(sph.phi.as_type(u.Quantity), apysph.phi)
assert np.allclose(convert(sph.r, u.Quantity), apysph.r)
assert np.allclose(convert(sph.theta, u.Quantity), apysph.theta)
assert np.allclose(convert(sph.phi, u.Quantity), apysph.phi)

def test_spherical_to_cylindrical(self, vector):
"""Test ``vector.represent_as(CylindricalVector)``."""
Expand All @@ -308,11 +308,11 @@ def test_spherical_to_cylindrical_astropy(self, vector, apyvector):
cyl = vector.represent_as(CylindricalVector, z=Quantity([9, 10, 11, 12], u.m))

apycyl = apyvector.represent_as(apyc.CylindricalRepresentation)
assert np.allclose(cyl.rho.as_type(u.Quantity), apycyl.rho)
assert np.allclose(cyl.z.as_type(u.Quantity), apycyl.z)
assert np.allclose(convert(cyl.rho, u.Quantity), apycyl.rho)
assert np.allclose(convert(cyl.z, u.Quantity), apycyl.z)

with pytest.raises(AssertionError): # TODO: Fix this
assert np.allclose(cyl.phi.as_type(u.Quantity), apycyl.phi)
assert np.allclose(convert(cyl.phi, u.Quantity), apycyl.phi)


class TestCylindricalVector:
Expand Down Expand Up @@ -405,9 +405,9 @@ def test_cylindrical_to_cartesian3d_astropy(self, vector, apyvector):
cart3d = vector.represent_as(Cartesian3DVector)

apycart3 = apyvector.represent_as(apyc.CartesianRepresentation)
assert np.allclose(cart3d.x.as_type(u.Quantity), apycart3.x)
assert np.allclose(cart3d.y.as_type(u.Quantity), apycart3.y)
assert np.allclose(cart3d.z.as_type(u.Quantity), apycart3.z)
assert np.allclose(convert(cart3d.x, u.Quantity), apycart3.x)
assert np.allclose(convert(cart3d.y, u.Quantity), apycart3.y)
assert np.allclose(convert(cart3d.z, u.Quantity), apycart3.z)

def test_cylindrical_to_spherical(self, vector):
"""Test ``vector.represent_as(SphericalVector)``."""
Expand All @@ -422,9 +422,9 @@ def test_cylindrical_to_spherical_astropy(self, vector, apyvector):
"""Test Astropy equivalence."""
sph = vector.represent_as(SphericalVector)
apysph = apyvector.represent_as(apyc.PhysicsSphericalRepresentation)
assert np.allclose(sph.r.as_type(u.Quantity), apysph.r)
assert np.allclose(sph.theta.as_type(u.Quantity), apysph.theta)
assert np.allclose(sph.phi.as_type(u.Quantity), apysph.phi)
assert np.allclose(convert(sph.r, u.Quantity), apysph.r)
assert np.allclose(convert(sph.theta, u.Quantity), apysph.theta)
assert np.allclose(convert(sph.phi, u.Quantity), apysph.phi)

def test_cylindrical_to_cylindrical(self, vector):
"""Test ``vector.represent_as(CylindricalVector)``."""
Expand All @@ -441,9 +441,9 @@ def test_cylindrical_to_cylindrical_astropy(self, vector, apyvector):
cyl = vector.represent_as(CylindricalVector)

apycyl = apyvector.represent_as(apyc.CylindricalRepresentation)
assert np.allclose(cyl.rho.as_type(u.Quantity), apycyl.rho)
assert np.allclose(cyl.z.as_type(u.Quantity), apycyl.z)
assert np.allclose(cyl.phi.as_type(u.Quantity), apycyl.phi)
assert np.allclose(convert(cyl.rho, u.Quantity), apycyl.rho)
assert np.allclose(convert(cyl.z, u.Quantity), apycyl.z)
assert np.allclose(convert(cyl.phi, u.Quantity), apycyl.phi)


class Abstract3DVectorDifferentialTest(AbstractVectorDifferentialTest):
Expand Down Expand Up @@ -538,9 +538,9 @@ def test_cartesian3d_to_cartesian3d_astropy(
cart3 = difntl.represent_as(CartesianDifferential3D, vector)

apycart3 = apydifntl.represent_as(apyc.CartesianDifferential, apyvector)
assert np.allclose(cart3.d_x.as_type(u.Quantity), apycart3.d_x)
assert np.allclose(cart3.d_y.as_type(u.Quantity), apycart3.d_y)
assert np.allclose(cart3.d_z.as_type(u.Quantity), apycart3.d_z)
assert np.allclose(convert(cart3.d_x, u.Quantity), apycart3.d_x)
assert np.allclose(convert(cart3.d_y, u.Quantity), apycart3.d_y)
assert np.allclose(convert(cart3.d_z, u.Quantity), apycart3.d_z)

def test_cartesian3d_to_spherical(self, difntl, vector):
"""Test ``vector.represent_as(SphericalDifferential)``."""
Expand All @@ -566,16 +566,16 @@ def test_cartesian3d_to_spherical_astropy(
sph = difntl.represent_as(SphericalDifferential, vector)

apysph = apydifntl.represent_as(apyc.PhysicsSphericalDifferential, apyvector)
assert np.allclose(sph.d_r.as_type(u.Quantity), apysph.d_r)
assert np.allclose(convert(sph.d_r, u.Quantity), apysph.d_r)
with pytest.raises(AssertionError): # TODO: fixme
assert np.allclose(
sph.d_theta.as_type(u.Quantity).to(u.mas / u.Myr),
convert(sph.d_theta, u.Quantity).to(u.mas / u.Myr),
apysph.d_theta.to(u.mas / u.Myr),
atol=1e-9,
)
with pytest.raises(AssertionError): # TODO: fixme
assert np.allclose(
sph.d_phi.as_type(u.Quantity).to(u.mas / u.Myr),
convert(sph.d_phi, u.Quantity).to(u.mas / u.Myr),
apysph.d_phi.to(u.mas / u.Myr),
atol=1e-7,
)
Expand All @@ -600,10 +600,10 @@ def test_cartesian3d_to_spherical_astropy(
"""Test Astropy equivalence."""
cyl = difntl.represent_as(CylindricalDifferential, vector)
apycyl = apydifntl.represent_as(apyc.CylindricalDifferential, apyvector)
assert np.allclose(cyl.d_rho.as_type(u.Quantity), apycyl.d_rho)
assert np.allclose(cyl.d_z.as_type(u.Quantity), apycyl.d_z)
assert np.allclose(convert(cyl.d_rho, u.Quantity), apycyl.d_rho)
assert np.allclose(convert(cyl.d_z, u.Quantity), apycyl.d_z)
with pytest.raises(AssertionError): # TODO: fixme
assert np.allclose(cyl.d_phi.as_type(u.Quantity), apycyl.d_phi)
assert np.allclose(convert(cyl.d_phi, u.Quantity), apycyl.d_phi)


class TestSphericalDifferential(Abstract3DVectorDifferentialTest):
Expand Down Expand Up @@ -704,9 +704,9 @@ def test_spherical_to_cartesian3d_astropy(
cart3d = difntl.represent_as(CartesianDifferential3D, vector)

apycart3 = apydifntl.represent_as(apyc.CartesianDifferential, apyvector)
assert np.allclose(cart3d.d_x.as_type(u.Quantity), apycart3.d_x)
assert np.allclose(cart3d.d_y.as_type(u.Quantity), apycart3.d_y)
assert np.allclose(cart3d.d_z.as_type(u.Quantity), apycart3.d_z)
assert np.allclose(convert(cart3d.d_x, u.Quantity), apycart3.d_x)
assert np.allclose(convert(cart3d.d_y, u.Quantity), apycart3.d_y)
assert np.allclose(convert(cart3d.d_z, u.Quantity), apycart3.d_z)

def test_spherical_to_spherical(self, difntl, vector):
"""Test ``vector.represent_as(SphericalDifferential)``."""
Expand All @@ -722,9 +722,9 @@ def test_spherical_to_spherical_astropy(self, difntl, vector, apydifntl, apyvect
"""Test Astropy equivalence."""
sph = difntl.represent_as(SphericalDifferential, vector)
apysph = apydifntl.represent_as(apyc.PhysicsSphericalDifferential, apyvector)
assert np.allclose(sph.d_r.as_type(u.Quantity), apysph.d_r)
assert np.allclose(sph.d_theta.as_type(u.Quantity), apysph.d_theta)
assert np.allclose(sph.d_phi.as_type(u.Quantity), apysph.d_phi)
assert np.allclose(convert(sph.d_r, u.Quantity), apysph.d_r)
assert np.allclose(convert(sph.d_theta, u.Quantity), apysph.d_theta)
assert np.allclose(convert(sph.d_phi, u.Quantity), apysph.d_phi)

def test_spherical_to_cylindrical(self, difntl, vector):
"""Test ``vector.represent_as(CylindricalDifferential)``."""
Expand All @@ -747,9 +747,9 @@ def test_spherical_to_cylindrical_astropy(
"""Test Astropy equivalence."""
cyl = difntl.represent_as(CylindricalDifferential, vector)
apycyl = apydifntl.represent_as(apyc.CylindricalDifferential, apyvector)
assert np.allclose(cyl.d_rho.as_type(u.Quantity), apycyl.d_rho)
assert np.allclose(cyl.d_z.as_type(u.Quantity), apycyl.d_z)
assert np.allclose(cyl.d_phi.as_type(u.Quantity), apycyl.d_phi)
assert np.allclose(convert(cyl.d_rho, u.Quantity), apycyl.d_rho)
assert np.allclose(convert(cyl.d_z, u.Quantity), apycyl.d_z)
assert np.allclose(convert(cyl.d_phi, u.Quantity), apycyl.d_phi)


class TestCylindricalDifferential(Abstract3DVectorDifferentialTest):
Expand Down Expand Up @@ -838,9 +838,9 @@ def test_cylindrical_to_cartesian3d(self, difntl, vector, apydifntl, apyvector):
assert array_equal(cart3d.d_z, Quantity([9, 10, 11, 12], u.km / u.s))

apycart3 = apydifntl.represent_as(apyc.CartesianDifferential, apyvector)
assert np.allclose(cart3d.d_x.as_type(u.Quantity), apycart3.d_x)
assert np.allclose(cart3d.d_y.as_type(u.Quantity), apycart3.d_y)
assert np.allclose(cart3d.d_z.as_type(u.Quantity), apycart3.d_z)
assert np.allclose(convert(cart3d.d_x, u.Quantity), apycart3.d_x)
assert np.allclose(convert(cart3d.d_y, u.Quantity), apycart3.d_y)
assert np.allclose(convert(cart3d.d_z, u.Quantity), apycart3.d_z)

def test_cylindrical_to_spherical(self, difntl, vector):
"""Test ``vector.represent_as(SphericalDifferential)``."""
Expand All @@ -863,10 +863,10 @@ def test_cylindrical_to_spherical_astropy(
"""Test Astropy equivalence."""
sph = difntl.represent_as(SphericalDifferential, vector)
apysph = apydifntl.represent_as(apyc.PhysicsSphericalDifferential, apyvector)
assert np.allclose(sph.d_r.as_type(u.Quantity), apysph.d_r)
assert np.allclose(convert(sph.d_r, u.Quantity), apysph.d_r)
with pytest.raises(AssertionError):
assert np.allclose(sph.d_theta.as_type(u.Quantity), apysph.d_theta)
assert np.allclose(sph.d_phi.as_type(u.Quantity), apysph.d_phi)
assert np.allclose(convert(sph.d_theta, u.Quantity), apysph.d_theta)
assert np.allclose(convert(sph.d_phi, u.Quantity), apysph.d_phi)

def test_cylindrical_to_cylindrical(self, difntl, vector):
"""Test ``vector.represent_as(CylindricalDifferential)``."""
Expand All @@ -882,6 +882,6 @@ def test_cylindrical_to_cylindrical(self, difntl, vector, apydifntl, apyvector):
"""Test Astropy equivalence."""
cyl = difntl.represent_as(CylindricalDifferential, vector)
apycyl = apydifntl.represent_as(apyc.CylindricalDifferential, apyvector)
assert np.allclose(cyl.d_rho.as_type(u.Quantity), apycyl.d_rho)
assert np.allclose(cyl.d_z.as_type(u.Quantity), apycyl.d_z)
assert np.allclose(cyl.d_phi.as_type(u.Quantity), apycyl.d_phi)
assert np.allclose(convert(cyl.d_rho, u.Quantity), apycyl.d_rho)
assert np.allclose(convert(cyl.d_z, u.Quantity), apycyl.d_z)
assert np.allclose(convert(cyl.d_phi, u.Quantity), apycyl.d_phi)

0 comments on commit 014cbce

Please sign in to comment.