From b7dcbe778d661d8598016e26b9307f9996e3bc02 Mon Sep 17 00:00:00 2001 From: nstarman <nstarman@users.noreply.github.com> Date: Sat, 24 Feb 2024 14:27:38 -0500 Subject: [PATCH] vector norm convenience method Signed-off-by: nstarman <nstarman@users.noreply.github.com> --- pyproject.toml | 5 ++++- src/vector/_base.py | 6 ++++++ src/vector/_d1/base.py | 11 +++++++++++ src/vector/_d1/builtin.py | 8 ++++++++ src/vector/_d2/base.py | 11 +++++++++++ src/vector/_d2/builtin.py | 13 +++++++++++++ src/vector/_d3/base.py | 11 +++++++++++ src/vector/_d3/builtin.py | 18 ++++++++++++++++++ 8 files changed, 82 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index d1d64b81..961d1b70 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,7 +98,10 @@ ] [tool.mypy] - disable_error_code = ["no-redef"] + disable_error_code = [ + "no-redef", # for plum-dispatch + "name-defined", # for jaxtyping + ] disallow_incomplete_defs = false disallow_untyped_defs = false enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"] diff --git a/src/vector/_base.py b/src/vector/_base.py index c7f71593..b8b28ee9 100644 --- a/src/vector/_base.py +++ b/src/vector/_base.py @@ -136,6 +136,12 @@ def represent_as(self, target: type[VT], /, *args: Any, **kwargs: Any) -> VT: return represent_as(self, target, **kwargs) + @abstractmethod + def norm(self) -> Quantity: + """Return the norm of the vector.""" + # TODO: make a generic method that works on all dimensions + raise NotImplementedError + class AbstractVectorDifferential(AbstractVectorBase): """Abstract representation of vector differentials in different systems.""" diff --git a/src/vector/_d1/base.py b/src/vector/_d1/base.py index aaab2c8e..d06a7f06 100644 --- a/src/vector/_d1/base.py +++ b/src/vector/_d1/base.py @@ -3,7 +3,11 @@ __all__ = ["Abstract1DVector", "Abstract1DVectorDifferential"] +from functools import partial + import equinox as eqx +import jax +from jax_quantity import Quantity from vector._base import AbstractVector, AbstractVectorDifferential @@ -11,6 +15,13 @@ class Abstract1DVector(AbstractVector): """Abstract representation of 1D coordinates in different systems.""" + @partial(jax.jit) + def norm(self) -> Quantity["length"]: + """Return the norm of the vector.""" + from .builtin import Cartesian1DVector # pylint: disable=C0415 + + return self.represent_as(Cartesian1DVector).norm() + class Abstract1DVectorDifferential(AbstractVectorDifferential): """Abstract representation of 1D differentials in different systems.""" diff --git a/src/vector/_d1/builtin.py b/src/vector/_d1/builtin.py index 1b06d8ac..bc48c3a7 100644 --- a/src/vector/_d1/builtin.py +++ b/src/vector/_d1/builtin.py @@ -9,9 +9,12 @@ "RadialDifferential", ] +from functools import partial from typing import ClassVar, final +import array_api_jax_compat as xp import equinox as eqx +import jax from vector._checks import check_r_non_negative from vector._typing import BatchableLength, BatchableSpeed @@ -30,6 +33,11 @@ class Cartesian1DVector(Abstract1DVector): x: BatchableLength = eqx.field(converter=converter_quantity_array) r"""X coordinate :math:`x \in (-\infty,+\infty)`.""" + @partial(jax.jit) + def norm(self) -> BatchableLength: + """Return the norm of the vector.""" + return xp.abs(self.x) + @final class RadialVector(Abstract1DVector): diff --git a/src/vector/_d2/base.py b/src/vector/_d2/base.py index 563f27c5..fd0d7384 100644 --- a/src/vector/_d2/base.py +++ b/src/vector/_d2/base.py @@ -3,7 +3,11 @@ __all__ = ["Abstract2DVector", "Abstract2DVectorDifferential"] +from functools import partial + import equinox as eqx +import jax +from jax_quantity import Quantity from vector._base import AbstractVector, AbstractVectorDifferential @@ -11,6 +15,13 @@ class Abstract2DVector(AbstractVector): """Abstract representation of 2D coordinates in different systems.""" + @partial(jax.jit) + def norm(self) -> Quantity["length"]: + """Return the norm of the vector.""" + from .builtin import Cartesian2DVector # pylint: disable=C0415 + + return self.represent_as(Cartesian2DVector).norm() + class Abstract2DVectorDifferential(AbstractVectorDifferential): """Abstract representation of 2D vector differentials.""" diff --git a/src/vector/_d2/builtin.py b/src/vector/_d2/builtin.py index 2dc5b1f6..74b98b6f 100644 --- a/src/vector/_d2/builtin.py +++ b/src/vector/_d2/builtin.py @@ -11,9 +11,12 @@ "PolarDifferential", ] +from functools import partial from typing import ClassVar, final +import array_api_jax_compat as xp import equinox as eqx +import jax from vector._checks import check_phi_range, check_r_non_negative from vector._typing import ( @@ -40,6 +43,11 @@ class Cartesian2DVector(Abstract2DVector): y: BatchableLength = eqx.field(converter=converter_quantity_array) r"""Y coordinate :math:`y \in (-\infty,+\infty)`.""" + @partial(jax.jit) + def norm(self) -> BatchableLength: + """Return the norm of the vector.""" + return xp.sqrt(self.x**2 + self.y**2) + @final class PolarVector(Abstract2DVector): @@ -59,6 +67,11 @@ def __check_init__(self) -> None: check_r_non_negative(self.r) check_phi_range(self.phi) + @partial(jax.jit) + def norm(self) -> BatchableLength: + """Return the norm of the vector.""" + return self.r + # class LnPolarVector(Abstract2DVector): # """Log-polar vector representation.""" diff --git a/src/vector/_d3/base.py b/src/vector/_d3/base.py index c0f143eb..3fda3a0d 100644 --- a/src/vector/_d3/base.py +++ b/src/vector/_d3/base.py @@ -3,7 +3,11 @@ __all__ = ["Abstract3DVector", "Abstract3DVectorDifferential"] +from functools import partial + import equinox as eqx +import jax +from jax_quantity import Quantity from vector._base import AbstractVector, AbstractVectorDifferential @@ -11,6 +15,13 @@ class Abstract3DVector(AbstractVector): """Abstract representation of 3D coordinates in different systems.""" + @partial(jax.jit) + def norm(self) -> Quantity["length"]: + """Return the norm of the vector.""" + from .builtin import Cartesian3DVector # pylint: disable=C0415 + + return self.represent_as(Cartesian3DVector).norm() + class Abstract3DVectorDifferential(AbstractVectorDifferential): """Abstract representation of 3D vector differentials.""" diff --git a/src/vector/_d3/builtin.py b/src/vector/_d3/builtin.py index 1b82db76..a1882230 100644 --- a/src/vector/_d3/builtin.py +++ b/src/vector/_d3/builtin.py @@ -11,9 +11,12 @@ "CylindricalDifferential", ] +from functools import partial from typing import ClassVar, final +import array_api_jax_compat as xp import equinox as eqx +import jax from vector._checks import check_phi_range, check_r_non_negative, check_theta_range from vector._typing import ( @@ -43,6 +46,11 @@ class Cartesian3DVector(Abstract3DVector): z: BatchableLength = eqx.field(converter=converter_quantity_array) r"""Z coordinate :math:`z \in (-\infty,+\infty)`.""" + @partial(jax.jit) + def norm(self) -> BatchableLength: + """Return the norm of the vector.""" + return xp.sqrt(self.x**2 + self.y**2 + self.z**2) + @final class SphericalVector(Abstract3DVector): @@ -63,6 +71,11 @@ def __check_init__(self) -> None: check_theta_range(self.theta) check_phi_range(self.phi) + @partial(jax.jit) + def norm(self) -> BatchableLength: + """Return the norm of the vector.""" + return self.r + @final class CylindricalVector(Abstract3DVector): @@ -82,6 +95,11 @@ def __check_init__(self) -> None: check_r_non_negative(self.rho) check_phi_range(self.phi) + @partial(jax.jit) + def norm(self) -> BatchableLength: + """Return the norm of the vector.""" + return xp.sqrt(self.rho**2 + self.z**2) + ############################################################################## # Differential