Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: negate diff #45

Merged
merged 1 commit into from
Feb 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions src/vector/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,31 @@ def integral_cls(cls) -> type["AbstractVectorDifferential"]:
"""
raise NotImplementedError

# ===============================================================
# Unary operations

def __neg__(self) -> "Self":
"""Negate the vector.

Examples
--------
>>> from jax_quantity import Quantity
>>> from vector import RadialDifferential
>>> dr = RadialDifferential(Quantity(1, "m/s"))
>>> -dr
RadialDifferential( d_r=Quantity[...]( value=f32[], unit=Unit("m / s") ) )

>>> from vector import PolarDifferential
>>> dp = PolarDifferential(Quantity(1, "m/s"), Quantity(1, "mas/yr"))
>>> neg_dp = -dp
>>> neg_dp.d_r
Quantity['speed'](Array(-1., dtype=float32), unit='m / s')
>>> neg_dp.d_phi
Quantity['angular frequency'](Array(-1., dtype=float32), unit='mas / yr')

"""
return replace(self, **{k: -v for k, v in dataclass_items(self)})

# ===============================================================
# Binary operations

Expand Down
18 changes: 14 additions & 4 deletions tests/test_d3.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import jax.numpy as jnp
import numpy as np
import pytest
from astropy.coordinates.tests.test_representation import representation_equal
from plum import convert

import array_api_jax_compat as xp
Expand Down Expand Up @@ -36,7 +37,7 @@ class Abstract3DVectorTest(AbstractVectorTest):
"""Test :class:`vector.Abstract3DVector`."""


class TestCartesian3DVector:
class TestCartesian3DVector(Abstract3DVectorTest):
"""Test :class:`vector.Cartesian3DVector`."""

@pytest.fixture(scope="class")
Expand All @@ -51,7 +52,7 @@ def vector(self) -> AbstractVector:
)

@pytest.fixture(scope="class")
def apyvector(self, vector: AbstractVector):
def apyvector(self, vector: AbstractVector) -> apyc.CartesianRepresentation:
"""Return an Astropy vector."""
return convert(vector, apyc.CartesianRepresentation)

Expand Down Expand Up @@ -171,7 +172,7 @@ def test_cartesian3d_to_cylindrical_astropy(self, vector, apyvector):
assert np.allclose(convert(cyl.phi, u.Quantity), apycyl.phi)


class TestSphericalVector:
class TestSphericalVector(Abstract3DVectorTest):
"""Test :class:`vector.SphericalVector`."""

@pytest.fixture(scope="class")
Expand Down Expand Up @@ -315,7 +316,7 @@ def test_spherical_to_cylindrical_astropy(self, vector, apyvector):
assert np.allclose(convert(cyl.phi, u.Quantity), apycyl.phi)


class TestCylindricalVector:
class TestCylindricalVector(Abstract3DVectorTest):
"""Test :class:`vector.CylindricalVector`."""

@pytest.fixture(scope="class")
Expand Down Expand Up @@ -449,6 +450,15 @@ def test_cylindrical_to_cylindrical_astropy(self, vector, apyvector):
class Abstract3DVectorDifferentialTest(AbstractVectorDifferentialTest):
"""Test :class:`vector.Abstract2DVectorDifferential`."""

# ==========================================================================
# Unary operations

def test_neg_compare_apy(
self, difntl: AbstractVector, apydifntl: apyc.BaseRepresentation
):
"""Test negation."""
assert all(representation_equal(convert(-difntl, type(apydifntl)), -apydifntl))


class TestCartesianDifferential3D(Abstract3DVectorDifferentialTest):
"""Test :class:`vector.CartesianDifferential3D`."""
Expand Down