From b7dcbe778d661d8598016e26b9307f9996e3bc02 Mon Sep 17 00:00:00 2001
From: nstarman <nstarman@users.noreply.github.com>
Date: Sat, 24 Feb 2024 14:27:38 -0500
Subject: [PATCH] vector norm convenience method

Signed-off-by: nstarman <nstarman@users.noreply.github.com>
---
 pyproject.toml            |  5 ++++-
 src/vector/_base.py       |  6 ++++++
 src/vector/_d1/base.py    | 11 +++++++++++
 src/vector/_d1/builtin.py |  8 ++++++++
 src/vector/_d2/base.py    | 11 +++++++++++
 src/vector/_d2/builtin.py | 13 +++++++++++++
 src/vector/_d3/base.py    | 11 +++++++++++
 src/vector/_d3/builtin.py | 18 ++++++++++++++++++
 8 files changed, 82 insertions(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index d1d64b81..961d1b70 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -98,7 +98,10 @@
     ]
 
 [tool.mypy]
-  disable_error_code = ["no-redef"]
+  disable_error_code = [
+    "no-redef",  # for plum-dispatch
+    "name-defined",  # for jaxtyping
+  ]
   disallow_incomplete_defs = false
   disallow_untyped_defs = false
   enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"]
diff --git a/src/vector/_base.py b/src/vector/_base.py
index c7f71593..b8b28ee9 100644
--- a/src/vector/_base.py
+++ b/src/vector/_base.py
@@ -136,6 +136,12 @@ def represent_as(self, target: type[VT], /, *args: Any, **kwargs: Any) -> VT:
 
         return represent_as(self, target, **kwargs)
 
+    @abstractmethod
+    def norm(self) -> Quantity:
+        """Return the norm of the vector."""
+        # TODO: make a generic method that works on all dimensions
+        raise NotImplementedError
+
 
 class AbstractVectorDifferential(AbstractVectorBase):
     """Abstract representation of vector differentials in different systems."""
diff --git a/src/vector/_d1/base.py b/src/vector/_d1/base.py
index aaab2c8e..d06a7f06 100644
--- a/src/vector/_d1/base.py
+++ b/src/vector/_d1/base.py
@@ -3,7 +3,11 @@
 __all__ = ["Abstract1DVector", "Abstract1DVectorDifferential"]
 
 
+from functools import partial
+
 import equinox as eqx
+import jax
+from jax_quantity import Quantity
 
 from vector._base import AbstractVector, AbstractVectorDifferential
 
@@ -11,6 +15,13 @@
 class Abstract1DVector(AbstractVector):
     """Abstract representation of 1D coordinates in different systems."""
 
+    @partial(jax.jit)
+    def norm(self) -> Quantity["length"]:
+        """Return the norm of the vector."""
+        from .builtin import Cartesian1DVector  # pylint: disable=C0415
+
+        return self.represent_as(Cartesian1DVector).norm()
+
 
 class Abstract1DVectorDifferential(AbstractVectorDifferential):
     """Abstract representation of 1D differentials in different systems."""
diff --git a/src/vector/_d1/builtin.py b/src/vector/_d1/builtin.py
index 1b06d8ac..bc48c3a7 100644
--- a/src/vector/_d1/builtin.py
+++ b/src/vector/_d1/builtin.py
@@ -9,9 +9,12 @@
     "RadialDifferential",
 ]
 
+from functools import partial
 from typing import ClassVar, final
 
+import array_api_jax_compat as xp
 import equinox as eqx
+import jax
 
 from vector._checks import check_r_non_negative
 from vector._typing import BatchableLength, BatchableSpeed
@@ -30,6 +33,11 @@ class Cartesian1DVector(Abstract1DVector):
     x: BatchableLength = eqx.field(converter=converter_quantity_array)
     r"""X coordinate :math:`x \in (-\infty,+\infty)`."""
 
+    @partial(jax.jit)
+    def norm(self) -> BatchableLength:
+        """Return the norm of the vector."""
+        return xp.abs(self.x)
+
 
 @final
 class RadialVector(Abstract1DVector):
diff --git a/src/vector/_d2/base.py b/src/vector/_d2/base.py
index 563f27c5..fd0d7384 100644
--- a/src/vector/_d2/base.py
+++ b/src/vector/_d2/base.py
@@ -3,7 +3,11 @@
 __all__ = ["Abstract2DVector", "Abstract2DVectorDifferential"]
 
 
+from functools import partial
+
 import equinox as eqx
+import jax
+from jax_quantity import Quantity
 
 from vector._base import AbstractVector, AbstractVectorDifferential
 
@@ -11,6 +15,13 @@
 class Abstract2DVector(AbstractVector):
     """Abstract representation of 2D coordinates in different systems."""
 
+    @partial(jax.jit)
+    def norm(self) -> Quantity["length"]:
+        """Return the norm of the vector."""
+        from .builtin import Cartesian2DVector  # pylint: disable=C0415
+
+        return self.represent_as(Cartesian2DVector).norm()
+
 
 class Abstract2DVectorDifferential(AbstractVectorDifferential):
     """Abstract representation of 2D vector differentials."""
diff --git a/src/vector/_d2/builtin.py b/src/vector/_d2/builtin.py
index 2dc5b1f6..74b98b6f 100644
--- a/src/vector/_d2/builtin.py
+++ b/src/vector/_d2/builtin.py
@@ -11,9 +11,12 @@
     "PolarDifferential",
 ]
 
+from functools import partial
 from typing import ClassVar, final
 
+import array_api_jax_compat as xp
 import equinox as eqx
+import jax
 
 from vector._checks import check_phi_range, check_r_non_negative
 from vector._typing import (
@@ -40,6 +43,11 @@ class Cartesian2DVector(Abstract2DVector):
     y: BatchableLength = eqx.field(converter=converter_quantity_array)
     r"""Y coordinate :math:`y \in (-\infty,+\infty)`."""
 
+    @partial(jax.jit)
+    def norm(self) -> BatchableLength:
+        """Return the norm of the vector."""
+        return xp.sqrt(self.x**2 + self.y**2)
+
 
 @final
 class PolarVector(Abstract2DVector):
@@ -59,6 +67,11 @@ def __check_init__(self) -> None:
         check_r_non_negative(self.r)
         check_phi_range(self.phi)
 
+    @partial(jax.jit)
+    def norm(self) -> BatchableLength:
+        """Return the norm of the vector."""
+        return self.r
+
 
 # class LnPolarVector(Abstract2DVector):
 #     """Log-polar vector representation."""
diff --git a/src/vector/_d3/base.py b/src/vector/_d3/base.py
index c0f143eb..3fda3a0d 100644
--- a/src/vector/_d3/base.py
+++ b/src/vector/_d3/base.py
@@ -3,7 +3,11 @@
 __all__ = ["Abstract3DVector", "Abstract3DVectorDifferential"]
 
 
+from functools import partial
+
 import equinox as eqx
+import jax
+from jax_quantity import Quantity
 
 from vector._base import AbstractVector, AbstractVectorDifferential
 
@@ -11,6 +15,13 @@
 class Abstract3DVector(AbstractVector):
     """Abstract representation of 3D coordinates in different systems."""
 
+    @partial(jax.jit)
+    def norm(self) -> Quantity["length"]:
+        """Return the norm of the vector."""
+        from .builtin import Cartesian3DVector  # pylint: disable=C0415
+
+        return self.represent_as(Cartesian3DVector).norm()
+
 
 class Abstract3DVectorDifferential(AbstractVectorDifferential):
     """Abstract representation of 3D vector differentials."""
diff --git a/src/vector/_d3/builtin.py b/src/vector/_d3/builtin.py
index 1b82db76..a1882230 100644
--- a/src/vector/_d3/builtin.py
+++ b/src/vector/_d3/builtin.py
@@ -11,9 +11,12 @@
     "CylindricalDifferential",
 ]
 
+from functools import partial
 from typing import ClassVar, final
 
+import array_api_jax_compat as xp
 import equinox as eqx
+import jax
 
 from vector._checks import check_phi_range, check_r_non_negative, check_theta_range
 from vector._typing import (
@@ -43,6 +46,11 @@ class Cartesian3DVector(Abstract3DVector):
     z: BatchableLength = eqx.field(converter=converter_quantity_array)
     r"""Z coordinate :math:`z \in (-\infty,+\infty)`."""
 
+    @partial(jax.jit)
+    def norm(self) -> BatchableLength:
+        """Return the norm of the vector."""
+        return xp.sqrt(self.x**2 + self.y**2 + self.z**2)
+
 
 @final
 class SphericalVector(Abstract3DVector):
@@ -63,6 +71,11 @@ def __check_init__(self) -> None:
         check_theta_range(self.theta)
         check_phi_range(self.phi)
 
+    @partial(jax.jit)
+    def norm(self) -> BatchableLength:
+        """Return the norm of the vector."""
+        return self.r
+
 
 @final
 class CylindricalVector(Abstract3DVector):
@@ -82,6 +95,11 @@ def __check_init__(self) -> None:
         check_r_non_negative(self.rho)
         check_phi_range(self.phi)
 
+    @partial(jax.jit)
+    def norm(self) -> BatchableLength:
+        """Return the norm of the vector."""
+        return xp.sqrt(self.rho**2 + self.z**2)
+
 
 ##############################################################################
 # Differential