Skip to content

Commit

Permalink
feat: doc annotations
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman committed Jul 17, 2024
1 parent a563e96 commit 68528b7
Showing 1 changed file with 14 additions and 26 deletions.
40 changes: 14 additions & 26 deletions src/immutable_map_jax/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
__all__: list[str] = []

from collections.abc import ItemsView, Iterable, Iterator, KeysView, Mapping, ValuesView
from typing import Any, TypeVar, overload
from typing import Annotated, Any, TypeVar, overload

from jax.tree_util import register_pytree_node_class
from typing_extensions import Doc

_T = TypeVar("_T")
K = TypeVar("K")
Expand Down Expand Up @@ -108,33 +109,20 @@ def __repr__(self) -> str:
# ===========================================
# JAX PyTree

def tree_flatten(self) -> tuple[tuple[V, ...], tuple[K, ...]]:
"""Flatten dict to the values (and keys).
Returns
-------
tuple[V, ...] tuple[str, ...]
A pair of an iterable with the values to be flattened recursively,
and the keys to pass back to the unflattening recipe.
"""
return (tuple(self._data.values()), tuple(self._data.keys()))
def tree_flatten(
self,
) -> tuple[
Annotated[tuple[V, ...], Doc("The values.")],
Annotated[tuple[K, ...], Doc("The keys as auxiliary data.")],
]:
"""Flatten dict to the values (and keys)."""
return tuple(self._data.values()), tuple(self._data.keys())

@classmethod
def tree_unflatten(
cls,
aux_data: tuple[K, ...],
children: tuple[V, ...],
) -> "ImmutableMap": # type: ignore[type-arg] # TODO: upstream beartype fix for ImmutableMap[V]
"""Unflatten.
Params:
aux_data: the opaque data that was specified during flattening of the
current treedef.
children: the unflattened children
Returns
-------
a re-constructed object of the registered type, using the specified
children and auxiliary data.
"""
aux_data: Annotated[tuple[K, ...], Doc("The keys.")],
children: Annotated[tuple[V, ...], Doc("The values.")],
) -> "ImmutableMap[K, V]":
"""Unflatten into an ImmutableMap from the keys and values."""
return cls(tuple(zip(aux_data, children, strict=True)))

0 comments on commit 68528b7

Please sign in to comment.