Skip to content

Commit

Permalink
C++ tree with path API
Browse files Browse the repository at this point in the history
* Make the tree_util.tree_flatten_with_path and tree_map_with_path APIs C++-based to speed up pytree flattening.

* Move all the key classes down to the C++ level while keeping the APIs unchanged.
  * Known small caveats: they are no longer Python dataclasses, and pattern matching might make pytype unhappy.

* defaultdict and OrderedDict are now registered via the keypath API.

PiperOrigin-RevId: 694219933
  • Loading branch information
IvyZX authored and Google-ML-Automation committed Nov 21, 2024
1 parent 26443bb commit 28f044a
Show file tree
Hide file tree
Showing 3 changed files with 287 additions and 45 deletions.
120 changes: 82 additions & 38 deletions jax/_src/tree_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@
from functools import partial
import operator as op
import textwrap
from typing import Any, NamedTuple, TypeVar, Union, overload
from typing import Any, NamedTuple, TypeVar, overload

from jax._src import traceback_util
from jax._src.lib import pytree
from jax._src.lib import xla_extension_version
from jax._src.util import safe_zip, set_module
from jax._src.util import unzip2

Expand Down Expand Up @@ -209,12 +210,21 @@ def all_leaves(iterable: Iterable[Any],

# Type variable for the iterable of children returned by a flatten function.
_Children = TypeVar("_Children", bound=Iterable[Any])
# Type variable for the auxiliary data returned alongside the children; it is
# bound to Hashable.
_AuxData = TypeVar("_AuxData", bound=Hashable)
# Type variable for a single entry of a key path; it is bound to Hashable.
KeyEntry = TypeVar("KeyEntry", bound=Hashable)
# A (key entry, value) pair, and an iterable of such pairs as produced by a
# flatten-with-keys function.
KeyLeafPair = tuple[KeyEntry, Any]
KeyLeafPairs = Iterable[KeyLeafPair]
# A key path is a variable-length tuple of key entries.
KeyPath = tuple[KeyEntry, ...]


@export
def register_pytree_node(nodetype: type[T],
flatten_func: Callable[[T], tuple[_Children, _AuxData]],
unflatten_func: Callable[[_AuxData, _Children], T]) -> None:
def register_pytree_node(
nodetype: type[T],
flatten_func: Callable[[T], tuple[_Children, _AuxData]],
unflatten_func: Callable[[_AuxData, _Children], T],
flatten_with_keys_func: (
Callable[[T], tuple[KeyLeafPairs, _AuxData]] | None
) = None,
) -> None:
"""Extends the set of types that are considered internal nodes in pytrees.
See :ref:`example usage <pytrees>`.
Expand Down Expand Up @@ -279,9 +289,20 @@ def register_pytree_node(nodetype: type[T],
>>> jax.jit(f)(m)
Array([1., 2., 3., 4., 5.], dtype=float32)
"""
default_registry.register_node(nodetype, flatten_func, unflatten_func)
none_leaf_registry.register_node(nodetype, flatten_func, unflatten_func)
dispatch_registry.register_node(nodetype, flatten_func, unflatten_func)
if xla_extension_version >= 298:
default_registry.register_node( # type: ignore[call-arg]
nodetype, flatten_func, unflatten_func, flatten_with_keys_func
)
none_leaf_registry.register_node( # type: ignore[call-arg]
nodetype, flatten_func, unflatten_func, flatten_with_keys_func
)
dispatch_registry.register_node( # type: ignore[call-arg]
nodetype, flatten_func, unflatten_func, flatten_with_keys_func
)
else:
default_registry.register_node(nodetype, flatten_func, unflatten_func)
none_leaf_registry.register_node(nodetype, flatten_func, unflatten_func)
dispatch_registry.register_node(nodetype, flatten_func, unflatten_func)
_registry[nodetype] = _RegistryEntry(flatten_func, unflatten_func)


Expand Down Expand Up @@ -452,21 +473,6 @@ def tree_all(tree: Any, *, is_leaf: Callable[[Any], bool] | None = None) -> bool
return all(tree_leaves(tree, is_leaf=is_leaf))


register_pytree_node(
collections.OrderedDict,
lambda x: (tuple(x.values()), tuple(x.keys())),
lambda keys, values: collections.OrderedDict(safe_zip(keys, values)))

def _flatten_defaultdict(d):
keys = tuple(sorted(d))
return tuple(d[k] for k in keys), (d.default_factory, keys)

register_pytree_node(
collections.defaultdict,
_flatten_defaultdict,
lambda s, values: collections.defaultdict(s[0], safe_zip(s[1], values)))


class _HashableCallableShim:
"""Object that delegates __call__, __hash__, and __eq__ to another object."""

Expand Down Expand Up @@ -578,11 +584,11 @@ def broadcast_prefix(prefix_tree: Any, full_tree: Any,


# flatten_one_level is not exported.
def flatten_one_level(pytree: Any) -> tuple[Iterable[Any], Hashable]:
def flatten_one_level(tree: Any) -> tuple[Iterable[Any], Hashable]:
"""Flatten the given pytree node by one level.
Args:
pytree: A valid pytree node, either built-in or registered via
tree: A valid pytree node, either built-in or registered via
:func:`register_pytree_node` or related functions.
Returns:
Expand All @@ -601,9 +607,9 @@ def flatten_one_level(pytree: Any) -> tuple[Iterable[Any], Hashable]:
>>> meta
('a', 'b')
"""
out = default_registry.flatten_one_level(pytree)
out = default_registry.flatten_one_level(tree)
if out is None:
raise ValueError(f"can't tree-flatten type: {type(pytree)}")
raise ValueError(f"can't tree-flatten type: {type(tree)}")
else:
return out

Expand Down Expand Up @@ -739,10 +745,12 @@ class FlattenedIndexKey():
def __str__(self):
return f'[<flat index {self.key}>]'

BuiltInKeyEntry = Union[SequenceKey, DictKey, GetAttrKey, FlattenedIndexKey]

KeyEntry = TypeVar("KeyEntry", bound=Hashable)
KeyPath = tuple[KeyEntry, ...]
# When the extension is new enough (version 298 or later), rebind the key
# classes to the implementations exposed by the `pytree` extension module.
if xla_extension_version >= 298:
  SequenceKey = pytree.SequenceKey  # type: ignore
  DictKey = pytree.DictKey  # type: ignore
  GetAttrKey = pytree.GetAttrKey  # type: ignore
  FlattenedIndexKey = pytree.FlattenedIndexKey  # type: ignore


@export
Expand All @@ -764,6 +772,7 @@ def keystr(keys: KeyPath):
return ''.join(map(str, keys))


# TODO(ivyzheng): remove this after _child_keys() also moved to C++.
class _RegistryWithKeypathsEntry(NamedTuple):
flatten_with_keys: Callable[..., Any]
unflatten_func: Callable[..., Any]
Expand All @@ -780,7 +789,6 @@ def flatten_with_keys(xs):
flatten_with_keys, _registry[ty].from_iter
)


_registry_with_keypaths: dict[type[Any], _RegistryWithKeypathsEntry] = {}

_register_keypaths(
Expand All @@ -803,13 +811,9 @@ def flatten_with_keys(xs):
@export
def register_pytree_with_keys(
nodetype: type[T],
flatten_with_keys: Callable[
[T], tuple[Iterable[tuple[KeyEntry, Any]], _AuxData]
],
flatten_with_keys: Callable[[T], tuple[Iterable[KeyLeafPair], _AuxData]],
unflatten_func: Callable[[_AuxData, Iterable[Any]], T],
flatten_func: None | (
Callable[[T], tuple[Iterable[Any], _AuxData]]
) = None,
flatten_func: None | (Callable[[T], tuple[Iterable[Any], _AuxData]]) = None,
):
"""Extends the set of types that are considered internal nodes in pytrees.
Expand Down Expand Up @@ -870,7 +874,9 @@ def flatten_func_impl(tree):
return [c for _, c in key_children], treedef
flatten_func = flatten_func_impl

register_pytree_node(nodetype, flatten_func, unflatten_func)
register_pytree_node(
nodetype, flatten_func, unflatten_func, flatten_with_keys
)
_registry_with_keypaths[nodetype] = _RegistryWithKeypathsEntry(
flatten_with_keys, unflatten_func
)
Expand Down Expand Up @@ -1092,6 +1098,40 @@ def flatten_func(x):
return nodetype


# Register OrderedDict and defaultdict as pytree nodes. Which registration
# entry point is used depends on the extension version.
if xla_extension_version < 298:
  # Older extension: register with plain flatten/unflatten functions.
  def _flatten_defaultdict(d):
    # Children are emitted in sorted-key order; the default factory and the
    # sorted keys travel as auxiliary data so the dict can be rebuilt.
    sorted_keys = tuple(sorted(d))
    children = tuple(d[key] for key in sorted_keys)
    return children, (d.default_factory, sorted_keys)

  register_pytree_node(
      collections.OrderedDict,
      lambda od: (tuple(od.values()), tuple(od.keys())),
      lambda keys, values: collections.OrderedDict(safe_zip(keys, values)),
  )

  register_pytree_node(
      collections.defaultdict,
      _flatten_defaultdict,
      lambda aux, values: collections.defaultdict(
          aux[0], safe_zip(aux[1], values)
      ),
  )
else:
  # Newer extension: register through the keypath-aware API, pairing every
  # child with a DictKey built from its dictionary key.
  def _flatten_defaultdict_with_keys(d):
    # Same sorted-key ordering as the plain flatten, but each child is
    # accompanied by its DictKey.
    sorted_keys = tuple(sorted(d))
    keyed_children = tuple((DictKey(key), d[key]) for key in sorted_keys)
    return keyed_children, (d.default_factory, sorted_keys)

  register_pytree_with_keys(
      collections.OrderedDict,
      lambda od: (
          tuple((DictKey(key), od[key]) for key in od.keys()),
          tuple(od.keys()),
      ),
      lambda keys, values: collections.OrderedDict(safe_zip(keys, values)),
  )

  register_pytree_with_keys(
      collections.defaultdict,
      _flatten_defaultdict_with_keys,
      lambda aux, values: collections.defaultdict(
          aux[0], safe_zip(aux[1], values)
      ),
  )


@export
def register_static(cls: type[H]) -> type[H]:
"""Registers `cls` as a pytree with no leaves.
Expand Down Expand Up @@ -1144,6 +1184,8 @@ def tree_flatten_with_path(
which contains a leaf and its key path. The second element is a treedef
representing the structure of the flattened tree.
"""
if xla_extension_version >= 298:
return default_registry.flatten_with_path(tree, is_leaf)
_, tree_def = tree_flatten(tree, is_leaf)
return _generate_key_paths(tree, is_leaf), tree_def

Expand All @@ -1164,13 +1206,15 @@ def tree_leaves_with_path(
- :func:`jax.tree_util.tree_leaves`
- :func:`jax.tree_util.tree_flatten_with_path`
"""
return _generate_key_paths(tree, is_leaf)
return tree_flatten_with_path(tree, is_leaf)[0]


# generate_key_paths is not exported.
def generate_key_paths(
    tree: Any, is_leaf: Callable[[Any], bool] | None = None
) -> list[tuple[KeyPath, Any]]:
  """Return the list of (key path, leaf) pairs for ``tree``.

  Args:
    tree: the pytree to traverse.
    is_leaf: optional predicate forwarded unchanged to the underlying
      traversal.

  Returns:
    A list of ``(key_path, leaf)`` tuples.
  """
  if xla_extension_version < 298:
    # Older extension: materialize the Python-level key-path generator,
    # starting from an empty key path.
    return list(_generate_key_paths_((), tree, is_leaf))
  # Newer extension: delegate to tree_leaves_with_path.
  return tree_leaves_with_path(tree, is_leaf)
_generate_key_paths = generate_key_paths  # alias for backward compat

Expand Down
48 changes: 42 additions & 6 deletions tests/package_structure_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,53 @@ class PackageStructureTest(jtu.JaxTestCase):
_mod("jax.errors", exclude=["JaxRuntimeError"]),
_mod(
"jax.numpy",
exclude=["array_repr", "array_str", "can_cast", "character", "complexfloating",
"dtype", "iinfo", "index_exp", "inexact", "integer", "iterable", "finfo",
"flexible", "floating", "generic", "get_printoptions", "ndarray", "ndim",
"number", "object_", "printoptions", "save", "savez", "set_printoptions",
"shape", "signedinteger", "size", "s_", "unsignedinteger", "ComplexWarning"]
exclude=[
"array_repr",
"array_str",
"can_cast",
"character",
"complexfloating",
"dtype",
"iinfo",
"index_exp",
"inexact",
"integer",
"iterable",
"finfo",
"flexible",
"floating",
"generic",
"get_printoptions",
"ndarray",
"ndim",
"number",
"object_",
"printoptions",
"save",
"savez",
"set_printoptions",
"shape",
"signedinteger",
"size",
"s_",
"unsignedinteger",
"ComplexWarning",
],
),
_mod("jax.numpy.linalg"),
_mod("jax.nn.initializers"),
_mod(
"jax.tree_util",
exclude=["PyTreeDef", "default_registry", "KeyEntry", "KeyPath"],
exclude=[
"PyTreeDef",
"default_registry",
"KeyEntry",
"KeyPath",
"DictKey",
"GetAttrKey",
"SequenceKey",
"FlattenedIndexKey",
],
),
])
def test_exported_names_match_module(self, module_name, include, exclude):
Expand Down
Loading

0 comments on commit 28f044a

Please sign in to comment.