From 5ee462c26d70dd8dea3a104eba726e2d42fb66c0 Mon Sep 17 00:00:00 2001 From: Nathaniel Starkman Date: Tue, 16 Jul 2024 22:38:32 -0400 Subject: [PATCH] feat: doc annotations (#4) * feat: doc annotations * build: typing_extensions Signed-off-by: nstarman --- pyproject.toml | 4 ++++ src/immutable_map_jax/_core.py | 40 ++++++++++++---------------------- 2 files changed, 18 insertions(+), 26 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f3f91e3..2ff3437 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ classifiers = [ dynamic = ["version"] dependencies = [ "jax", + "typing_extensions >= 4.8", ] [project.optional-dependencies] @@ -140,6 +141,9 @@ ignore = [ "ISC001", # Conflicts with formatter ] +[tool.ruff.lint.isort] +extra-standard-library = ["typing_extensions"] + [tool.ruff.lint.per-file-ignores] "tests/**" = ["T20"] "noxfile.py" = ["T20"] diff --git a/src/immutable_map_jax/_core.py b/src/immutable_map_jax/_core.py index daebd4c..ef40058 100644 --- a/src/immutable_map_jax/_core.py +++ b/src/immutable_map_jax/_core.py @@ -7,7 +7,8 @@ __all__: list[str] = [] from collections.abc import ItemsView, Iterable, Iterator, KeysView, Mapping, ValuesView -from typing import Any, TypeVar, overload +from typing import Annotated, Any, TypeVar, overload +from typing_extensions import Doc from jax.tree_util import register_pytree_node_class @@ -108,33 +109,20 @@ def __repr__(self) -> str: # =========================================== # JAX PyTree - def tree_flatten(self) -> tuple[tuple[V, ...], tuple[K, ...]]: - """Flatten dict to the values (and keys). - - Returns - ------- - tuple[V, ...] tuple[str, ...] - A pair of an iterable with the values to be flattened recursively, - and the keys to pass back to the unflattening recipe. - """ - return (tuple(self._data.values()), tuple(self._data.keys())) + def tree_flatten( + self, + ) -> tuple[ + Annotated[tuple[V, ...], Doc("The values.")], + Annotated[tuple[K, ...], Doc("The keys as auxiliary data.")], + ]: + """Flatten dict to the values (and keys).""" + return tuple(self._data.values()), tuple(self._data.keys()) @classmethod def tree_unflatten( cls, - aux_data: tuple[K, ...], - children: tuple[V, ...], - ) -> "ImmutableMap": # type: ignore[type-arg] # TODO: upstream beartype fix for ImmutableMap[V] - """Unflatten. - - Params: - aux_data: the opaque data that was specified during flattening of the - current treedef. - children: the unflattened children - - Returns - ------- - a re-constructed object of the registered type, using the specified - children and auxiliary data. - """ + aux_data: Annotated[tuple[K, ...], Doc("The keys.")], + children: Annotated[tuple[V, ...], Doc("The values.")], + ) -> "ImmutableMap[K, V]": + """Unflatten into an ImmutableMap from the keys and values.""" return cls(tuple(zip(aux_data, children, strict=True)))