Skip to content

Commit

Permalink
Merge branch 'main' into jax_tree_util_legacy
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosgmartin committed Sep 20, 2024
1 parent e961623 commit 25c1149
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 3 deletions.
5 changes: 5 additions & 0 deletions docs/api/utilities.rst
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ Tree
tree_mul
tree_ones_like
tree_random_like
tree_split_key_like
tree_scalar_mul
tree_set
tree_sub
Expand Down Expand Up @@ -153,6 +154,10 @@ Tree ones like
~~~~~~~~~~~~~~
.. autofunction:: tree_ones_like

Tree with random keys
~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: tree_split_key_like

Tree with random values
~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: tree_random_like
Expand Down
1 change: 1 addition & 0 deletions optax/tree_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# pylint: disable=g-importing-member
from optax.tree_utils._casting import tree_cast
from optax.tree_utils._random import tree_random_like
from optax.tree_utils._random import tree_split_key_like
from optax.tree_utils._state_utils import NamedTupleKey
from optax.tree_utils._state_utils import tree_get
from optax.tree_utils._state_utils import tree_get_all_with_path
Expand Down
6 changes: 3 additions & 3 deletions optax/tree_utils/_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import jax


def _tree_rng_keys_split(
def tree_split_key_like(
rng_key: chex.PRNGKey, target_tree: chex.ArrayTree
) -> chex.ArrayTree:
"""Split keys to match structure of target tree.
Expand Down Expand Up @@ -66,9 +66,9 @@ def tree_random_like(
.. versionadded:: 0.2.1
"""
keys_tree = _tree_rng_keys_split(rng_key, target_tree)
keys_tree = tree_split_key_like(rng_key, target_tree)
return jax.tree.map(
lambda l, k: sampler(k, l.shape, dtype or l.dtype),
lambda leaf, key: sampler(key, leaf.shape, dtype or leaf.dtype),
target_tree,
keys_tree,
)
14 changes: 14 additions & 0 deletions optax/tree_utils/_random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import jax
import jax.numpy as jnp
import jax.random as jrd
import numpy as np
from optax import tree_utils as otu

# We consider samplers with varying input dtypes, we do not test all possible
Expand All @@ -48,6 +49,19 @@ def get_variable(type_var: str):

class RandomTest(chex.TestCase):

def test_tree_split_key_like(self):
rng_key = jrd.PRNGKey(0)
tree = {'a': jnp.zeros(2), 'b': {'c': [jnp.ones(3), jnp.zeros([4, 5])]}}
keys_tree = otu.tree_split_key_like(rng_key, tree)

with self.subTest('Test structure matches'):
self.assertEqual(jax.tree.structure(tree), jax.tree.structure(keys_tree))

with self.subTest('Test random key split'):
fst = jnp.stack(jax.tree.flatten(keys_tree)[0])
snd = jrd.split(rng_key, jax.tree.structure(tree).num_leaves)
np.testing.assert_array_equal(fst, snd)

@parameterized.product(
_SAMPLER_DTYPES,
type_var=['real_array', 'complex_array', 'pytree'],
Expand Down

0 comments on commit 25c1149

Please sign in to comment.