diff --git a/pyproject.toml b/pyproject.toml index f3f91e3..2ff3437 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ classifiers = [ dynamic = ["version"] dependencies = [ "jax", + "typing_extensions >= 4.8", ] [project.optional-dependencies] @@ -140,6 +141,9 @@ ignore = [ "ISC001", # Conflicts with formatter ] +[tool.ruff.lint.isort] +extra-standard-library = ["typing_extensions"] + [tool.ruff.lint.per-file-ignores] "tests/**" = ["T20"] "noxfile.py" = ["T20"] diff --git a/src/immutable_map_jax/_core.py b/src/immutable_map_jax/_core.py index 8138c4b..ef40058 100644 --- a/src/immutable_map_jax/_core.py +++ b/src/immutable_map_jax/_core.py @@ -8,9 +8,9 @@ from collections.abc import ItemsView, Iterable, Iterator, KeysView, Mapping, ValuesView from typing import Annotated, Any, TypeVar, overload +from typing_extensions import Doc from jax.tree_util import register_pytree_node_class -from typing_extensions import Doc _T = TypeVar("_T") K = TypeVar("K")