Skip to content

Commit

Permalink
Fixing doctests and removing etils.lazy_import for compatibility with…
Browse files Browse the repository at this point in the history
… Python 3.9

PiperOrigin-RevId: 727340840
  • Loading branch information
rdyro authored and OptaxDev committed Feb 15, 2025
1 parent e8eeea2 commit 40d886b
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 15 deletions.
6 changes: 1 addition & 5 deletions optax/_src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,15 @@
from typing import Any, Optional, Sequence, Union

import chex
from etils import epy
import jax
import jax.numpy as jnp
import jax.scipy.stats.norm as multivariate_normal
from optax import tree_utils as otu
from optax._src import base
from optax._src import linear_algebra
from optax._src import numerics


with epy.lazy_imports():
import jax.scipy.stats.norm as multivariate_normal # pylint: disable=g-import-not-at-top,ungrouped-imports


def tile_second_to_last_dim(a: chex.Array) -> chex.Array:
ones = jnp.ones_like(a)
a = jnp.expand_dims(a, axis=-1)
Expand Down
9 changes: 6 additions & 3 deletions optax/losses/_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,11 +260,12 @@ def softmax_cross_entropy(
Examples:
>>> import optax
>>> import jax.numpy as jnp
>>> jnp.set_printoptions(precision=4)
>>> # example: batch_size = 2, num_classes = 3
>>> logits = jnp.array([[1.2, -0.8, -0.5], [0.9, -1.2, 1.1]])
>>> labels = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
>>> print(optax.softmax_cross_entropy(logits, labels))
[0.2761297 2.951799 ]
[0.2761 2.9518]
References:
`Cross-entropy Loss <https://en.wikipedia.org/wiki/Cross-entropy>`_,
Expand Down Expand Up @@ -329,15 +330,17 @@ def softmax_cross_entropy_with_integer_labels(
Examples:
>>> import optax
>>> import jax.numpy as jnp
>>> jnp.set_printoptions(precision=4)
>>> # example: batch_size = 2, num_classes = 3
>>> logits = jnp.array([[1.2, -0.8, -0.5], [0.9, -1.2, 1.1]])
>>> labels = jnp.array([0, 1])
>>> print(optax.softmax_cross_entropy_with_integer_labels(logits, labels))
[0.2761297 2.951799 ]
[0.2761 2.9518]
>>> import jax.numpy as jnp
>>> import numpy as np
>>> import optax
>>> jnp.set_printoptions(precision=4)
>>> # example: batch_size = (1, 2), num_classes = 12 (i.e. 3 * 4)
>>> shape = (1, 2, 3, 4)
>>> logits = jnp.arange(np.prod(shape), dtype=jnp.float32).reshape(shape)
Expand All @@ -348,7 +351,7 @@ def softmax_cross_entropy_with_integer_labels(
>>> cross_entropy = optax.softmax_cross_entropy_with_integer_labels(
... logits, labels, axis=(2, 3))
>>> print(cross_entropy)
[[6.458669 0.45866907]]
[[6.4587 0.4587]]
References:
`Cross-entropy Loss <https://en.wikipedia.org/wiki/Cross-entropy>`_,
Expand Down
3 changes: 2 additions & 1 deletion optax/losses/_self_supervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,13 +133,14 @@ def triplet_margin_loss(
Examples:
>>> import jax.numpy as jnp, optax, chex
>>> jnp.set_printoptions(precision=4)
>>> anchors = jnp.array([[0.0, 0.0], [1.0, 1.0]])
>>> positives = jnp.array([[0.1, 0.1], [1.1, 1.1]])
>>> negatives = jnp.array([[1.0, 0.0], [0.0, 1.0]])
>>> output = optax.losses.triplet_margin_loss(anchors, positives, negatives,
... margin=1.0)
>>> print(output)
[0.14142442 0.14142442]
[0.1414 0.1414]
Args:
anchors: An array of anchor embeddings, with shape [batch, feature_dim].
Expand Down
12 changes: 6 additions & 6 deletions optax/projections/_projections.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,10 +148,10 @@ def projection_simplex(tree: Any, scale: chex.Numeric = 1.0) -> Any:
>>> import jax.numpy as jnp
>>> from optax import tree_utils, projections
>>> tree = {"w": jnp.array([2.5, 3.2]), "b": 0.5}
>>> tree_utils.tree_sum(tree)
>>> print(tree_utils.tree_sum(tree))
6.2
>>> new_tree = projections.projection_simplex(tree)
>>> tree_utils.tree_sum(new_tree)
>>> print(tree_utils.tree_sum(new_tree))
1.0000002
.. versionadded:: 0.2.3
Expand Down Expand Up @@ -212,11 +212,11 @@ def projection_l1_ball(tree: Any, scale: float = 1.0) -> Any:
>>> import jax.numpy as jnp
>>> from optax import tree_utils, projections
>>> tree = {"w": jnp.array([2.5, 3.2]), "b": 0.5}
>>> tree_utils.tree_l1_norm(tree)
Array(6.2, dtype=float32)
>>> print(tree_utils.tree_l1_norm(tree))
6.2
>>> new_tree = projections.projection_l1_ball(tree)
>>> tree_utils.tree_l1_norm(new_tree)
Array(1.0000002, dtype=float32)
>>> print(tree_utils.tree_l1_norm(new_tree))
1.0000002
.. versionadded:: 0.2.4
"""
Expand Down

0 comments on commit 40d886b

Please sign in to comment.