diff --git a/optax/_src/utils.py b/optax/_src/utils.py
index 75a0ebb79..ba3780cf3 100644
--- a/optax/_src/utils.py
+++ b/optax/_src/utils.py
@@ -21,19 +21,15 @@
 from typing import Any, Optional, Sequence, Union
 
 import chex
-from etils import epy
 import jax
 import jax.numpy as jnp
+import jax.scipy.stats.norm as multivariate_normal
 from optax import tree_utils as otu
 from optax._src import base
 from optax._src import linear_algebra
 from optax._src import numerics
 
 
-with epy.lazy_imports():
-  import jax.scipy.stats.norm as multivariate_normal  # pylint: disable=g-import-not-at-top,ungrouped-imports
-
-
 def tile_second_to_last_dim(a: chex.Array) -> chex.Array:
   ones = jnp.ones_like(a)
   a = jnp.expand_dims(a, axis=-1)
diff --git a/optax/losses/_classification.py b/optax/losses/_classification.py
index 4abbf936f..bf3ad5d3b 100644
--- a/optax/losses/_classification.py
+++ b/optax/losses/_classification.py
@@ -260,11 +260,12 @@ def softmax_cross_entropy(
   Examples:
     >>> import optax
     >>> import jax.numpy as jnp
+    >>> jnp.set_printoptions(precision=4)
     >>> # example: batch_size = 2, num_classes = 3
     >>> logits = jnp.array([[1.2, -0.8, -0.5], [0.9, -1.2, 1.1]])
     >>> labels = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
     >>> print(optax.softmax_cross_entropy(logits, labels))
-    [0.2761297 2.951799 ]
+    [0.2761 2.9518]
 
   References:
     `Cross-entropy Loss `_,
@@ -329,15 +330,17 @@ def softmax_cross_entropy_with_integer_labels(
   Examples:
     >>> import optax
     >>> import jax.numpy as jnp
+    >>> jnp.set_printoptions(precision=4)
     >>> # example: batch_size = 2, num_classes = 3
     >>> logits = jnp.array([[1.2, -0.8, -0.5], [0.9, -1.2, 1.1]])
     >>> labels = jnp.array([0, 1])
     >>> print(optax.softmax_cross_entropy_with_integer_labels(logits, labels))
-    [0.2761297 2.951799 ]
+    [0.2761 2.9518]
 
     >>> import jax.numpy as jnp
     >>> import numpy as np
     >>> import optax
+    >>> jnp.set_printoptions(precision=4)
     >>> # example: batch_size = (1, 2), num_classes = 12 (i.e. 3 * 4)
     >>> shape = (1, 2, 3, 4)
     >>> logits = jnp.arange(np.prod(shape), dtype=jnp.float32).reshape(shape)
@@ -348,7 +351,7 @@ def softmax_cross_entropy_with_integer_labels(
     >>> cross_entropy = optax.softmax_cross_entropy_with_integer_labels(
     ...   logits, labels, axis=(2, 3))
     >>> print(cross_entropy)
-    [[6.458669 0.45866907]]
+    [[6.4587 0.4587]]
 
   References:
     `Cross-entropy Loss `_,
diff --git a/optax/losses/_self_supervised.py b/optax/losses/_self_supervised.py
index 5c3a5e22b..cb2d50548 100644
--- a/optax/losses/_self_supervised.py
+++ b/optax/losses/_self_supervised.py
@@ -133,13 +133,14 @@ def triplet_margin_loss(
 
   Examples:
     >>> import jax.numpy as jnp, optax, chex
+    >>> jnp.set_printoptions(precision=4)
     >>> anchors = jnp.array([[0.0, 0.0], [1.0, 1.0]])
     >>> positives = jnp.array([[0.1, 0.1], [1.1, 1.1]])
     >>> negatives = jnp.array([[1.0, 0.0], [0.0, 1.0]])
     >>> output = optax.losses.triplet_margin_loss(anchors, positives, negatives,
     ...   margin=1.0)
     >>> print(output)
-    [0.14142442 0.14142442]
+    [0.1414 0.1414]
 
   Args:
     anchors: An array of anchor embeddings, with shape [batch, feature_dim].
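Note on the doctest changes above: jnp.set_printoptions is JAX's re-export of
numpy.set_printoptions, so it only truncates the displayed digits; the computed
loss values are unchanged. A minimal standalone sketch of the effect, reusing
the loss values from the updated doctests (not part of the patch):

    import jax.numpy as jnp

    loss = jnp.array([0.2761297, 2.951799])
    print(loss)  # default print options: [0.2761297 2.951799 ]

    # With lower precision, the same array prints with four decimal
    # digits, matching the updated expected doctest outputs.
    jnp.set_printoptions(precision=4)
    print(loss)  # [0.2761 2.9518]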
diff --git a/optax/projections/_projections.py b/optax/projections/_projections.py
index 91d83ebb3..84e053546 100644
--- a/optax/projections/_projections.py
+++ b/optax/projections/_projections.py
@@ -148,10 +148,10 @@ def projection_simplex(tree: Any, scale: chex.Numeric = 1.0) -> Any:
     >>> import jax.numpy as jnp
     >>> from optax import tree_utils, projections
     >>> tree = {"w": jnp.array([2.5, 3.2]), "b": 0.5}
-    >>> tree_utils.tree_sum(tree)
+    >>> print(tree_utils.tree_sum(tree))
     6.2
     >>> new_tree = projections.projection_simplex(tree)
-    >>> tree_utils.tree_sum(new_tree)
+    >>> print(tree_utils.tree_sum(new_tree))
     1.0000002
 
   .. versionadded:: 0.2.3
@@ -212,11 +212,11 @@ def projection_l1_ball(tree: Any, scale: float = 1.0) -> Any:
     >>> import jax.numpy as jnp
     >>> from optax import tree_utils, projections
     >>> tree = {"w": jnp.array([2.5, 3.2]), "b": 0.5}
-    >>> tree_utils.tree_l1_norm(tree)
-    Array(6.2, dtype=float32)
+    >>> print(tree_utils.tree_l1_norm(tree))
+    6.2
     >>> new_tree = projections.projection_l1_ball(tree)
-    >>> tree_utils.tree_l1_norm(new_tree)
-    Array(1.0000002, dtype=float32)
+    >>> print(tree_utils.tree_l1_norm(new_tree))
+    1.0000002
 
   .. versionadded:: 0.2.4
   """
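Note on the projection doctests: a bare expression in a doctest is compared
against the value's repr, while print() goes through str(), which drops the
Array(...) wrapper and the dtype. A rough standalone sketch of the difference,
using the same tree as the doctests above (not part of the patch):

    import jax.numpy as jnp
    from optax import tree_utils

    tree = {"w": jnp.array([2.5, 3.2]), "b": 0.5}
    total = tree_utils.tree_l1_norm(tree)
    print(repr(total))  # what a bare doctest expression checks: Array(6.2, dtype=float32)
    print(total)        # what the print() form checks: 6.2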