Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update jax.tree.map to comply with jax 0.4.34 #1094

Merged
merged 3 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions optax/_src/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,11 @@ def test_stateless(self):

@base.stateless
def opt(g, p):
return jax.tree.map(lambda g_, p_: g_ + 0.1 * p_, g, p)
return jax.tree.map(
lambda g_, p_: None if g_ is None else g_ + 0.1 * p_,
g,
p,
is_leaf=lambda x: x is None)

state = opt.init(params)
update_fn = self.variant(opt.update)
Expand All @@ -156,7 +160,11 @@ def opt(g, _):

def test_init_returns_emptystate(self):
def weight_decay(g, p):
return jax.tree.map(lambda g_, p_: g_ + 0.1 * p_, g, p)
return jax.tree.map(
lambda g_, p_: None if g_ is None else g_ + 0.1 * p_,
g,
p,
is_leaf=lambda x: x is None)

opt = base.stateless(weight_decay)
state = opt.init(None) # pytype: disable=wrong-arg-types # numpy-scalars
Expand Down
26 changes: 21 additions & 5 deletions optax/_src/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,10 @@ def update_fn(updates, state, params=None):
# unclear why. Other Nadam implementations also omit the extra b2 factor.
nu_hat = otu.tree_bias_correction(nu, b2, count_inc)
updates = jax.tree.map(
lambda m, v: m / (jnp.sqrt(v + eps_root) + eps), mu_hat, nu_hat)
lambda m, v: None if m is None else m / (jnp.sqrt(v + eps_root) + eps),
mu_hat,
nu_hat,
is_leaf=lambda x: x is None)
mu = otu.tree_cast(mu, mu_dtype)
return updates, ScaleByAdamState(count=count_inc, mu=mu, nu=nu)

Expand Down Expand Up @@ -385,7 +388,10 @@ def update_fn(updates, state, params=None):
nu_hat = otu.tree_bias_correction(nu, b2, count_inc)
nu_max = jax.tree.map(jnp.maximum, state.nu_max, nu_hat)
updates = jax.tree.map(
lambda m, v: m / (jnp.sqrt(v + eps_root) + eps), mu_hat, nu_max)
lambda m, v: None if m is None else m / (jnp.sqrt(v + eps_root) + eps),
mu_hat,
nu_max,
is_leaf=lambda x: x is None)
mu = otu.tree_cast(mu, mu_dtype)
return updates, ScaleByAmsgradState(
count=count_inc,
Expand Down Expand Up @@ -640,7 +646,10 @@ def update_fn(updates, state, params=None):
mu_hat = otu.tree_bias_correction(mu, b1, count_inc)
nu_hat = otu.tree_bias_correction(nu, b2, count_inc)
updates = jax.tree.map(
lambda m, v: m / (jnp.sqrt(v) + eps), mu_hat, nu_hat)
lambda m, v: None if m is None else m / (jnp.sqrt(v) + eps),
mu_hat,
nu_hat,
is_leaf=lambda x: x is None)
return updates, ScaleByBeliefState(count=count_inc, mu=mu, nu=nu)

return base.GradientTransformation(init_fn, update_fn)
Expand Down Expand Up @@ -689,7 +698,10 @@ def update_fn(updates, state, params=None):
mu_hat = otu.tree_bias_correction(mu, b1, count_inc)
nu_hat = otu.tree_bias_correction(nu, b2, count_inc)
updates = jax.tree.map(
lambda m, v: m / (jnp.sqrt(v + eps_root) + eps), mu_hat, nu_hat)
lambda m, v: None if m is None else m / (jnp.sqrt(v + eps_root) + eps),
mu_hat,
nu_hat,
is_leaf=lambda x: x is None)
return updates, ScaleByAdamState(count=count_inc, mu=mu, nu=nu)

return base.GradientTransformation(init_fn, update_fn)
Expand Down Expand Up @@ -1166,7 +1178,11 @@ def init_mu(grads, params, mu, nu):

def update_mu(grads, params, mu, nu):
updates = jax.tree.map(mu_addition, grads, params, nu)
return jax.tree.map(lambda m, u: b1 * m + u, mu, updates)
return jax.tree.map(
lambda m, u: None if m is None else b1 * m + u,
mu,
updates,
is_leaf=lambda x: x is None)

def update_fn(updates, state, params):
count_inc = numerics.safe_increment(state.count)
Expand Down
14 changes: 10 additions & 4 deletions optax/_src/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,11 @@ def apply_updates(params: base.Params, updates: base.Updates) -> base.Params:
Updated parameters, with same structure, shape and type as `params`.
"""
return jax.tree.map(
lambda p, u: jnp.asarray(p + u).astype(jnp.asarray(p).dtype),
params, updates)
lambda p, u: (
None if p is None
else jnp.asarray(p + u).astype(jnp.asarray(p).dtype)
),
params, updates, is_leaf=lambda x: x is None)


def incremental_update(
Expand All @@ -66,8 +69,11 @@ def incremental_update(
an updated moving average `step_size*new+(1-step_size)*old` of the params.
"""
return jax.tree.map(
lambda new, old: step_size * new + (1.0 - step_size) * old,
new_tensors, old_tensors)
lambda new, old: (
None if new is None
else step_size * new + (1.0 - step_size) * old
),
new_tensors, old_tensors, is_leaf=lambda x: x is None)


def periodic_update(
Expand Down
10 changes: 9 additions & 1 deletion optax/_src/update_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# ==============================================================================
"""Tests for `update.py`."""

from absl.testing import absltest
from absl.testing import absltest, parameterized

import chex
import jax
Expand Down Expand Up @@ -78,6 +78,14 @@ def test_periodic_update(self):
chex.assert_trees_all_close(
params_2, new_params, atol=1e-10, rtol=1e-5)

@parameterized.named_parameters(
dict(testcase_name='apply_updates', operation=update.apply_updates),
dict(testcase_name='incremental_update',
operation=lambda x, y: update.incremental_update(x,y,1)),
)
def test_none_argument(self, operation):
x = jnp.array([1., 2., 3.])
operation(None, x)

if __name__ == '__main__':
absltest.main()
5 changes: 4 additions & 1 deletion optax/transforms/_adding.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,10 @@ def update_fn(updates, state, params):
if params is None:
raise ValueError(base.NO_PARAMS_MSG)
updates = jax.tree.map(
lambda g, p: g + weight_decay * p, updates, params)
lambda g, p: None if g is None else g + weight_decay * p,
updates,
params,
is_leaf=lambda x: x is None)
return updates, state

# If mask is not `None`, apply mask to the gradient transformation.
Expand Down
8 changes: 8 additions & 0 deletions optax/transforms/_adding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,14 @@ def test_add_noise_has_correct_variance_scaling(self):

chex.assert_trees_all_close(updates_i, updates_i_rescaled, rtol=1e-4)

def test_none_argument(self):
weights = (
jnp.ones((2,), dtype=jnp.float32),
dict(
a=jnp.ones((2,), dtype=jnp.float32),
b=jnp.ones((2,), dtype=jnp.float32),))
tf = _adding.add_decayed_weights(0.1, mask=None)
tf.update(None, 0, weights)

if __name__ == "__main__":
absltest.main()
3 changes: 2 additions & 1 deletion optax/transforms/_constraining.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ def update_fn(updates, state, params):
raise ValueError(base.NO_PARAMS_MSG)

updates = jax.tree.map(
lambda p, u: jnp.where((p + u) < 0., -p, u), params, updates)
lambda p, u: None if p is None else jnp.where((p + u) < 0., -p, u),
params, updates, is_leaf=lambda x: x is None)
return updates, state

return base.GradientTransformation(init_fn, update_fn)
Expand Down
6 changes: 6 additions & 0 deletions optax/transforms/_constraining_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,12 @@ def test_zero_nans(self):
(jnp.array(False), jnp.array(False), jnp.array(False))))
chex.assert_trees_all_close(updates, grads)

def test_none_arguments(self):
tf = _constraining.keep_params_nonnegative()
state = tf.init(jnp.array([1.,2.,3.]))
with self.assertRaises(ValueError):
tf.update(jnp.array([1.,2.,3.]), state, None)


if __name__ == '__main__':
absltest.main()
5 changes: 3 additions & 2 deletions optax/tree_utils/_tree_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,10 @@ def tree_add_scalar_mul(
"""
scalar = jnp.asarray(scalar)
return jax.tree.map(
lambda x, y: x + scalar.astype(x.dtype) * y,
lambda x, y: None if x is None else x + scalar.astype(x.dtype) * y,
tree_x,
tree_y)
tree_y,
is_leaf=lambda x: x is None)


_vdot = functools.partial(jnp.vdot, precision=jax.lax.Precision.HIGHEST)
Expand Down
14 changes: 14 additions & 0 deletions optax/tree_utils/_tree_math_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,20 @@ def test_empty_tree_reduce(self):
self.assertEqual(tu.tree_sum(tree), 0)
self.assertEqual(tu.tree_vdot(tree, tree), 0)

@parameterized.named_parameters(
dict(testcase_name='tree_add_scalar_mul',
operation=lambda m: tu.tree_add_scalar_mul(None, 1, m)),
dict(testcase_name='tree_update_moment',
operation=lambda m: tu.tree_update_moment(None, m, 1, 1)),
dict(testcase_name='tree_update_infinity_moment',
operation=lambda m: tu.tree_update_infinity_moment(None, m, 1, 1)),
dict(testcase_name='tree_update_moment_per_elem_norm',
operation=lambda m:
tu.tree_update_moment_per_elem_norm(None, m, 1, 1)),
)
def test_none_arguments(self, operation):
m = jnp.array([1.,2.,3.])
operation(m)

if __name__ == '__main__':
absltest.main()
Loading