Skip to content

Commit

Permalink
vector norm convenience method
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman committed Feb 24, 2024
1 parent 27e7f9f commit b7dcbe7
Show file tree
Hide file tree
Showing 8 changed files with 82 additions and 1 deletion.
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,10 @@
]

[tool.mypy]
disable_error_code = ["no-redef"]
disable_error_code = [
"no-redef", # for plum-dispatch
"name-defined", # for jaxtyping
]
disallow_incomplete_defs = false
disallow_untyped_defs = false
enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"]
Expand Down
6 changes: 6 additions & 0 deletions src/vector/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,12 @@ def represent_as(self, target: type[VT], /, *args: Any, **kwargs: Any) -> VT:

return represent_as(self, target, **kwargs)

@abstractmethod
def norm(self) -> Quantity:
"""Return the norm of the vector."""
# TODO: make a generic method that works on all dimensions
raise NotImplementedError


class AbstractVectorDifferential(AbstractVectorBase):
"""Abstract representation of vector differentials in different systems."""
Expand Down
11 changes: 11 additions & 0 deletions src/vector/_d1/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,25 @@
__all__ = ["Abstract1DVector", "Abstract1DVectorDifferential"]


from functools import partial

import equinox as eqx
import jax
from jax_quantity import Quantity

from vector._base import AbstractVector, AbstractVectorDifferential


class Abstract1DVector(AbstractVector):
"""Abstract representation of 1D coordinates in different systems."""

@partial(jax.jit)
def norm(self) -> Quantity["length"]:
"""Return the norm of the vector."""
from .builtin import Cartesian1DVector # pylint: disable=C0415

Check warning on line 21 in src/vector/_d1/base.py

View check run for this annotation

Codecov / codecov/patch

src/vector/_d1/base.py#L21

Added line #L21 was not covered by tests

return self.represent_as(Cartesian1DVector).norm()

Check warning on line 23 in src/vector/_d1/base.py

View check run for this annotation

Codecov / codecov/patch

src/vector/_d1/base.py#L23

Added line #L23 was not covered by tests


class Abstract1DVectorDifferential(AbstractVectorDifferential):
"""Abstract representation of 1D differentials in different systems."""
Expand Down
8 changes: 8 additions & 0 deletions src/vector/_d1/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@
"RadialDifferential",
]

from functools import partial
from typing import ClassVar, final

import array_api_jax_compat as xp
import equinox as eqx
import jax

from vector._checks import check_r_non_negative
from vector._typing import BatchableLength, BatchableSpeed
Expand All @@ -30,6 +33,11 @@ class Cartesian1DVector(Abstract1DVector):
x: BatchableLength = eqx.field(converter=converter_quantity_array)
r"""X coordinate :math:`x \in (-\infty,+\infty)`."""

@partial(jax.jit)
def norm(self) -> BatchableLength:
"""Return the norm of the vector."""
return xp.abs(self.x)

Check warning on line 39 in src/vector/_d1/builtin.py

View check run for this annotation

Codecov / codecov/patch

src/vector/_d1/builtin.py#L39

Added line #L39 was not covered by tests


@final
class RadialVector(Abstract1DVector):
Expand Down
11 changes: 11 additions & 0 deletions src/vector/_d2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,25 @@
__all__ = ["Abstract2DVector", "Abstract2DVectorDifferential"]


from functools import partial

import equinox as eqx
import jax
from jax_quantity import Quantity

from vector._base import AbstractVector, AbstractVectorDifferential


class Abstract2DVector(AbstractVector):
"""Abstract representation of 2D coordinates in different systems."""

@partial(jax.jit)
def norm(self) -> Quantity["length"]:
"""Return the norm of the vector."""
from .builtin import Cartesian2DVector # pylint: disable=C0415

Check warning on line 21 in src/vector/_d2/base.py

View check run for this annotation

Codecov / codecov/patch

src/vector/_d2/base.py#L21

Added line #L21 was not covered by tests

return self.represent_as(Cartesian2DVector).norm()

Check warning on line 23 in src/vector/_d2/base.py

View check run for this annotation

Codecov / codecov/patch

src/vector/_d2/base.py#L23

Added line #L23 was not covered by tests


class Abstract2DVectorDifferential(AbstractVectorDifferential):
"""Abstract representation of 2D vector differentials."""
Expand Down
13 changes: 13 additions & 0 deletions src/vector/_d2/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@
"PolarDifferential",
]

from functools import partial
from typing import ClassVar, final

import array_api_jax_compat as xp
import equinox as eqx
import jax

from vector._checks import check_phi_range, check_r_non_negative
from vector._typing import (
Expand All @@ -40,6 +43,11 @@ class Cartesian2DVector(Abstract2DVector):
y: BatchableLength = eqx.field(converter=converter_quantity_array)
r"""Y coordinate :math:`y \in (-\infty,+\infty)`."""

@partial(jax.jit)
def norm(self) -> BatchableLength:
"""Return the norm of the vector."""
return xp.sqrt(self.x**2 + self.y**2)

Check warning on line 49 in src/vector/_d2/builtin.py

View check run for this annotation

Codecov / codecov/patch

src/vector/_d2/builtin.py#L49

Added line #L49 was not covered by tests


@final
class PolarVector(Abstract2DVector):
Expand All @@ -59,6 +67,11 @@ def __check_init__(self) -> None:
check_r_non_negative(self.r)
check_phi_range(self.phi)

@partial(jax.jit)
def norm(self) -> BatchableLength:
"""Return the norm of the vector."""
return self.r

Check warning on line 73 in src/vector/_d2/builtin.py

View check run for this annotation

Codecov / codecov/patch

src/vector/_d2/builtin.py#L73

Added line #L73 was not covered by tests


# class LnPolarVector(Abstract2DVector):
# """Log-polar vector representation."""
Expand Down
11 changes: 11 additions & 0 deletions src/vector/_d3/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,25 @@
__all__ = ["Abstract3DVector", "Abstract3DVectorDifferential"]


from functools import partial

import equinox as eqx
import jax
from jax_quantity import Quantity

from vector._base import AbstractVector, AbstractVectorDifferential


class Abstract3DVector(AbstractVector):
"""Abstract representation of 3D coordinates in different systems."""

@partial(jax.jit)
def norm(self) -> Quantity["length"]:
"""Return the norm of the vector."""
from .builtin import Cartesian3DVector # pylint: disable=C0415

Check warning on line 21 in src/vector/_d3/base.py

View check run for this annotation

Codecov / codecov/patch

src/vector/_d3/base.py#L21

Added line #L21 was not covered by tests

return self.represent_as(Cartesian3DVector).norm()

Check warning on line 23 in src/vector/_d3/base.py

View check run for this annotation

Codecov / codecov/patch

src/vector/_d3/base.py#L23

Added line #L23 was not covered by tests


class Abstract3DVectorDifferential(AbstractVectorDifferential):
"""Abstract representation of 3D vector differentials."""
Expand Down
18 changes: 18 additions & 0 deletions src/vector/_d3/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@
"CylindricalDifferential",
]

from functools import partial
from typing import ClassVar, final

import array_api_jax_compat as xp
import equinox as eqx
import jax

from vector._checks import check_phi_range, check_r_non_negative, check_theta_range
from vector._typing import (
Expand Down Expand Up @@ -43,6 +46,11 @@ class Cartesian3DVector(Abstract3DVector):
z: BatchableLength = eqx.field(converter=converter_quantity_array)
r"""Z coordinate :math:`z \in (-\infty,+\infty)`."""

@partial(jax.jit)
def norm(self) -> BatchableLength:
"""Return the norm of the vector."""
return xp.sqrt(self.x**2 + self.y**2 + self.z**2)

Check warning on line 52 in src/vector/_d3/builtin.py

View check run for this annotation

Codecov / codecov/patch

src/vector/_d3/builtin.py#L52

Added line #L52 was not covered by tests


@final
class SphericalVector(Abstract3DVector):
Expand All @@ -63,6 +71,11 @@ def __check_init__(self) -> None:
check_theta_range(self.theta)
check_phi_range(self.phi)

@partial(jax.jit)
def norm(self) -> BatchableLength:
"""Return the norm of the vector."""
return self.r

Check warning on line 77 in src/vector/_d3/builtin.py

View check run for this annotation

Codecov / codecov/patch

src/vector/_d3/builtin.py#L77

Added line #L77 was not covered by tests


@final
class CylindricalVector(Abstract3DVector):
Expand All @@ -82,6 +95,11 @@ def __check_init__(self) -> None:
check_r_non_negative(self.rho)
check_phi_range(self.phi)

@partial(jax.jit)
def norm(self) -> BatchableLength:
"""Return the norm of the vector."""
return xp.sqrt(self.rho**2 + self.z**2)

Check warning on line 101 in src/vector/_d3/builtin.py

View check run for this annotation

Codecov / codecov/patch

src/vector/_d3/builtin.py#L101

Added line #L101 was not covered by tests


##############################################################################
# Differential
Expand Down

0 comments on commit b7dcbe7

Please sign in to comment.