Skip to content

Commit

Permalink
feat: add array api properties and methods (#29)
Browse files Browse the repository at this point in the history
* feat: add array api properties and methods
* ci: add pylint ignore
* fix: type hints

Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman authored Feb 25, 2024
1 parent 29a21c6 commit a1334a6
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 1 deletion.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@
"design",
"fixme",
"function-redefined", # plum-dispatch
"invalid-name", # handled by ruff
"line-too-long",
"missing-function-docstring", # TODO: resolve
"missing-module-docstring",
Expand Down
51 changes: 50 additions & 1 deletion src/vector/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import equinox as eqx
import jax
import jax.numpy as jnp
from jax import Device
from jax_quantity import Quantity
from plum import dispatch

Expand Down Expand Up @@ -56,7 +57,21 @@ def constructor(
return cls(**obj)

# ===============================================================
# Array
# Array API

def __array_namespace__(self) -> "ArrayAPINamespace":
"""Return the array API namespace."""
return xp

@property
def mT(self) -> "Self": # noqa: N802
"""Transpose the vector."""
return replace(self, **{k: v.mT for k, v in dataclass_items(self)})

@property
def ndim(self) -> int:
"""Number of array dimensions (axes)."""
return len(self.shape)

@property
def shape(self) -> Any:
Expand All @@ -67,6 +82,25 @@ def shape(self) -> Any:
"""
return jnp.broadcast_shapes(*self.shapes.values())

@property
def size(self) -> int:
"""Total number of elements in the vector."""
return int(jnp.prod(xp.asarray(self.shape)))

@property
def T(self) -> "Self": # noqa: N802
"""Transpose the vector."""
return replace(self, **{k: v.T for k, v in dataclass_items(self)})

def to_device(self, device: None | Device = None) -> "Self":
"""Move the vector to a new device."""
return replace(
self, **{k: v.to_device(device) for k, v in dataclass_items(self)}
)

# ===============================================================
# Further array methods

def flatten(self) -> "Self":
"""Flatten the vector."""
return replace(self, **{k: v.flatten() for k, v in dataclass_items(self)})
Expand Down Expand Up @@ -117,11 +151,26 @@ def components(self) -> tuple[str, ...]:
"""Vector component names."""
return tuple(f.name for f in fields(self))

@property
def dtypes(self) -> Mapping[str, jnp.dtype]:
"""Get the dtypes of the vector's components."""
return MappingProxyType({k: v.dtype for k, v in dataclass_items(self)})

@property
def devices(self) -> Mapping[str, Device]:
"""Get the devices of the vector's components."""
return MappingProxyType({k: v.device for k, v in dataclass_items(self)})

@property
def shapes(self) -> Mapping[str, tuple[int, ...]]:
"""Get the shapes of the vector's components."""
return MappingProxyType({k: v.shape for k, v in dataclass_items(self)})

@property
def sizes(self) -> Mapping[str, int]:
"""Get the sizes of the vector's components."""
return MappingProxyType({k: v.size for k, v in dataclass_items(self)})

# ===============================================================
# Convenience methods

Expand Down

0 comments on commit a1334a6

Please sign in to comment.