Skip to content

Commit

Permalink
Fix global_norm() signature
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Feb 7, 2023
1 parent a77d69c commit 3d0c422
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion optax/_src/linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from optax._src import numerics


def global_norm(updates: base.Updates) -> base.Updates:
def global_norm(updates: base.PyTree) -> chex.Array:
"""Compute the global norm across a nested structure of tensors."""
return jnp.sqrt(sum(
jnp.sum(numerics.abs_sq(x)) for x in jax.tree_util.tree_leaves(updates)))
Expand Down

0 comments on commit 3d0c422

Please sign in to comment.