From 3d0c4225e25db8f079bc95a7d1eaa62b3118a8ab Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Fri, 3 Feb 2023 01:11:32 -0800 Subject: [PATCH] Fix global_norm() signature --- optax/_src/linear_algebra.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optax/_src/linear_algebra.py b/optax/_src/linear_algebra.py index 420aff36f..7bd64300a 100644 --- a/optax/_src/linear_algebra.py +++ b/optax/_src/linear_algebra.py @@ -24,7 +24,7 @@ from optax._src import numerics -def global_norm(updates: base.Updates) -> base.Updates: +def global_norm(updates: base.PyTree) -> chex.Array: """Compute the global norm across a nested structure of tensors.""" return jnp.sqrt(sum( jnp.sum(numerics.abs_sq(x)) for x in jax.tree_util.tree_leaves(updates)))