Skip to content

Commit

Permalink
feat: doc annotations (#4)
Browse files Browse the repository at this point in the history
* feat: doc annotations
* build: typing_extensions

Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman authored Jul 17, 2024
1 parent a563e96 commit 5ee462c
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 26 deletions.
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ classifiers = [
dynamic = ["version"]
dependencies = [
"jax",
"typing_extensions >= 4.8",
]

[project.optional-dependencies]
Expand Down Expand Up @@ -140,6 +141,9 @@ ignore = [
"ISC001", # Conflicts with formatter
]

[tool.ruff.lint.isort]
extra-standard-library = ["typing_extensions"]

[tool.ruff.lint.per-file-ignores]
"tests/**" = ["T20"]
"noxfile.py" = ["T20"]
Expand Down
40 changes: 14 additions & 26 deletions src/immutable_map_jax/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
__all__: list[str] = []

from collections.abc import ItemsView, Iterable, Iterator, KeysView, Mapping, ValuesView
from typing import Any, TypeVar, overload
from typing import Annotated, Any, TypeVar, overload
from typing_extensions import Doc

from jax.tree_util import register_pytree_node_class

Expand Down Expand Up @@ -108,33 +109,20 @@ def __repr__(self) -> str:
# ===========================================
# JAX PyTree

def tree_flatten(self) -> tuple[tuple[V, ...], tuple[K, ...]]:
"""Flatten dict to the values (and keys).
Returns
-------
tuple[V, ...] tuple[str, ...]
A pair of an iterable with the values to be flattened recursively,
and the keys to pass back to the unflattening recipe.
"""
return (tuple(self._data.values()), tuple(self._data.keys()))
def tree_flatten(
self,
) -> tuple[
Annotated[tuple[V, ...], Doc("The values.")],
Annotated[tuple[K, ...], Doc("The keys as auxiliary data.")],
]:
"""Flatten dict to the values (and keys)."""
return tuple(self._data.values()), tuple(self._data.keys())

@classmethod
def tree_unflatten(
cls,
aux_data: tuple[K, ...],
children: tuple[V, ...],
) -> "ImmutableMap": # type: ignore[type-arg] # TODO: upstream beartype fix for ImmutableMap[V]
"""Unflatten.
Params:
aux_data: the opaque data that was specified during flattening of the
current treedef.
children: the unflattened children
Returns
-------
a re-constructed object of the registered type, using the specified
children and auxiliary data.
"""
aux_data: Annotated[tuple[K, ...], Doc("The keys.")],
children: Annotated[tuple[V, ...], Doc("The values.")],
) -> "ImmutableMap[K, V]":
"""Unflatten into an ImmutableMap from the keys and values."""
return cls(tuple(zip(aux_data, children, strict=True)))

0 comments on commit 5ee462c

Please sign in to comment.