From a1334a6ae8c9e94102124e6bde2af488e8309926 Mon Sep 17 00:00:00 2001 From: Nathaniel Starkman Date: Sat, 24 Feb 2024 21:54:37 -0500 Subject: [PATCH] feat: add array api properties and methods (#29) * feat: add array api properties and methods * ci: add pylint ignore * fix: type hints Signed-off-by: nstarman --- pyproject.toml | 1 + src/vector/_base.py | 51 ++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 961d1b70..f5f8fa07 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -182,6 +182,7 @@ "design", "fixme", "function-redefined", # plum-dispatch + "invalid-name", # handled by ruff "line-too-long", "missing-function-docstring", # TODO: resolve "missing-module-docstring", diff --git a/src/vector/_base.py b/src/vector/_base.py index 0135b8ba..34bfce22 100644 --- a/src/vector/_base.py +++ b/src/vector/_base.py @@ -14,6 +14,7 @@ import equinox as eqx import jax import jax.numpy as jnp +from jax import Device from jax_quantity import Quantity from plum import dispatch @@ -56,7 +57,21 @@ def constructor( return cls(**obj) # =============================================================== - # Array + # Array API + + def __array_namespace__(self) -> "ArrayAPINamespace": + """Return the array API namespace.""" + return xp + + @property + def mT(self) -> "Self": # noqa: N802 + """Transpose the vector.""" + return replace(self, **{k: v.mT for k, v in dataclass_items(self)}) + + @property + def ndim(self) -> int: + """Number of array dimensions (axes).""" + return len(self.shape) @property def shape(self) -> Any: @@ -67,6 +82,25 @@ def shape(self) -> Any: """ return jnp.broadcast_shapes(*self.shapes.values()) + @property + def size(self) -> int: + """Total number of elements in the vector.""" + return int(jnp.prod(xp.asarray(self.shape))) + + @property + def T(self) -> "Self": # noqa: N802 + """Transpose the vector.""" + return replace(self, **{k: v.T for k, v in dataclass_items(self)}) + + def to_device(self, device: None | Device = None) -> "Self": + """Move the vector to a new device.""" + return replace( + self, **{k: v.to_device(device) for k, v in dataclass_items(self)} + ) + + # =============================================================== + # Further array methods + def flatten(self) -> "Self": """Flatten the vector.""" return replace(self, **{k: v.flatten() for k, v in dataclass_items(self)}) @@ -117,11 +151,26 @@ def components(self) -> tuple[str, ...]: """Vector component names.""" return tuple(f.name for f in fields(self)) + @property + def dtypes(self) -> Mapping[str, jnp.dtype]: + """Get the dtypes of the vector's components.""" + return MappingProxyType({k: v.dtype for k, v in dataclass_items(self)}) + + @property + def devices(self) -> Mapping[str, Device]: + """Get the devices of the vector's components.""" + return MappingProxyType({k: v.device for k, v in dataclass_items(self)}) + @property def shapes(self) -> Mapping[str, tuple[int, ...]]: """Get the shapes of the vector's components.""" return MappingProxyType({k: v.shape for k, v in dataclass_items(self)}) + @property + def sizes(self) -> Mapping[str, int]: + """Get the sizes of the vector's components.""" + return MappingProxyType({k: v.size for k, v in dataclass_items(self)}) + # =============================================================== # Convenience methods