Skip to content

Commit

Permalink
✨ feat(vecs): initialize space from non-vectors (#357)
Browse files Browse the repository at this point in the history
  • Loading branch information
nstarman authored Jan 25, 2025
1 parent f89da37 commit 2df61b1
Showing 1 changed file with 52 additions and 16 deletions.
68 changes: 52 additions & 16 deletions src/coordinax/_src/vectors/space/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,36 +104,49 @@ class Space(AbstractVector, ImmutableMap[Dimension, AbstractVector]): # type: i
>>> w.mT.shapes
mappingproxy({'length': (2, 1), 'speed': (2, 1)})
There are convenience ways to initialize the vectors in the space:
>>> space = cx.Space.from_({"length": u.Quantity([1, 2, 3], "km"),
... "speed": u.Quantity([4, 5, 6], "km/s")})
>>> print(space)
Space({
'length': <CartesianPos3D (x[km], y[km], z[km])
[1 2 3]>,
'speed': <CartesianVel3D (d_x[km / s], d_y[km / s], d_z[km / s])
[4 5 6]>
})
"""

_data: dict[str, AbstractVector] = eqx.field(init=False)

def __init__( # pylint: disable=super-init-not-called # TODO: resolve this
self,
/,
*args: Mapping[DimensionLike, AbstractVector]
| tuple[DimensionLike, AbstractVector]
| Iterable[tuple[DimensionLike, AbstractVector]],
**kwargs: AbstractVector,
*args: Mapping[DimensionLike, Any]
| tuple[DimensionLike, Any]
| Iterable[tuple[DimensionLike, Any]],
**kwargs: Any,
) -> None:
# Process the input data
# Consolidate the inputs into a single dict, then process keys & values.
raw = dict(*args, **kwargs) # process the input data
keys = [_get_dimension_name(k) for k in raw]
keys = eqx.error_if(
keys,
len(keys) < len(raw),
f"Space(**input) contained duplicate keys {set(raw) - set(keys)}.",
)
# TODO: check the key dimension makes sense for the value

# Process the keys
dims = tuple(u.dimension(k) for k in raw)
keys = tuple(_get_dimension_name(dim) for dim in dims)
# Convert the values to vectors
values = tuple(vector(v) for v in raw.values())

# TODO: check the dimension makes sense for the value

# Check that the shapes are broadcastable
keys = eqx.error_if(
keys,
not _can_broadcast_shapes(*(v.shape for v in raw.values())),
values = eqx.error_if(
values,
not _can_broadcast_shapes(*map(jnp.shape, values)),
"vector shapes are not broadcastable.",
)

ImmutableMap.__init__(self, dict(zip(keys, raw.values(), strict=True)))
ImmutableMap.__init__(self, dict(zip(keys, values, strict=True)))

@classmethod
def _dimensionality(cls) -> int:
Expand Down Expand Up @@ -581,6 +594,29 @@ def vector(
return cls(length=q, speed=p, acceleration=a)


@dispatch
def vector(
cls: type[Space],
obj: Mapping[str, Any],
) -> Space:
"""Construct a Space from a Mapping.
Examples
--------
>>> import unxt as u
>>> import coordinax as cx
>>> space = cx.Space.from_({ 'length': u.Quantity([1, 2, 3], "m") })
>>> print(space)
Space({
'length': <CartesianPos3D (x[m], y[m], z[m])
[1 2 3]>
})
"""
return cls({k: vector(v) for k, v in obj.items()})


# ===============================================================
# Vector API dispatches

Expand Down

0 comments on commit 2df61b1

Please sign in to comment.