Fixing doctests and removing etils.lazy_import for compatibility with Python 3.9 #1195

Merged: 1 commit, Feb 15, 2025
optax/_src/utils.py (6 changes: 1 addition & 5 deletions)
@@ -21,19 +21,15 @@
 from typing import Any, Optional, Sequence, Union

 import chex
-from etils import epy
 import jax
 import jax.numpy as jnp
+import jax.scipy.stats.norm as multivariate_normal
 from optax import tree_utils as otu
 from optax._src import base
 from optax._src import linear_algebra
 from optax._src import numerics


-with epy.lazy_imports():
-  import jax.scipy.stats.norm as multivariate_normal  # pylint: disable=g-import-not-at-top,ungrouped-imports
-
-
 def tile_second_to_last_dim(a: chex.Array) -> chex.Array:
   ones = jnp.ones_like(a)
   a = jnp.expand_dims(a, axis=-1)
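Note on the utils.py change above: per the PR title, the `epy.lazy_imports()` wrapper is dropped in favor of a plain `import jax.scipy.stats.norm as multivariate_normal` so the module imports cleanly on Python 3.9 without relying on `etils`. A minimal sketch (not part of the diff; the `logpdf` call is chosen only for illustration) of what the direct import provides:

import jax.numpy as jnp
import jax.scipy.stats.norm as multivariate_normal  # same alias the module uses

x = jnp.array([-1.0, 0.0, 1.0])
# The submodule is imported eagerly above; no lazy-import shim is needed.
# Log-density of a standard normal at each point.
print(multivariate_normal.logpdf(x, loc=0.0, scale=1.0))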
optax/losses/_classification.py (9 changes: 6 additions & 3 deletions)
@@ -260,11 +260,12 @@ def softmax_cross_entropy(
 Examples:
 >>> import optax
 >>> import jax.numpy as jnp
+>>> jnp.set_printoptions(precision=4)
 >>> # example: batch_size = 2, num_classes = 3
 >>> logits = jnp.array([[1.2, -0.8, -0.5], [0.9, -1.2, 1.1]])
 >>> labels = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
 >>> print(optax.softmax_cross_entropy(logits, labels))
-[0.2761297 2.951799 ]
+[0.2761 2.9518]

 References:
 `Cross-entropy Loss <https://en.wikipedia.org/wiki/Cross-entropy>`_,
@@ -329,15 +330,17 @@ def softmax_cross_entropy_with_integer_labels(
 Examples:
 >>> import optax
 >>> import jax.numpy as jnp
+>>> jnp.set_printoptions(precision=4)
 >>> # example: batch_size = 2, num_classes = 3
 >>> logits = jnp.array([[1.2, -0.8, -0.5], [0.9, -1.2, 1.1]])
 >>> labels = jnp.array([0, 1])
 >>> print(optax.softmax_cross_entropy_with_integer_labels(logits, labels))
-[0.2761297 2.951799 ]
+[0.2761 2.9518]

 >>> import jax.numpy as jnp
 >>> import numpy as np
 >>> import optax
+>>> jnp.set_printoptions(precision=4)
 >>> # example: batch_size = (1, 2), num_classes = 12 (i.e. 3 * 4)
 >>> shape = (1, 2, 3, 4)
 >>> logits = jnp.arange(np.prod(shape), dtype=jnp.float32).reshape(shape)
@@ -348,7 +351,7 @@ def softmax_cross_entropy_with_integer_labels(
 >>> cross_entropy = optax.softmax_cross_entropy_with_integer_labels(
 ... logits, labels, axis=(2, 3))
 >>> print(cross_entropy)
-[[6.458669 0.45866907]]
+[[6.4587 0.4587]]

 References:
 `Cross-entropy Loss <https://en.wikipedia.org/wiki/Cross-entropy>`_,
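The doctest updates above hinge on `jnp.set_printoptions(precision=4)`: with a fixed print precision, the expected output lines are short and stable across JAX versions. A standalone sketch of the first updated example (the value shown in the comment is illustrative and subject to float32 rounding):

import jax.numpy as jnp
import optax

jnp.set_printoptions(precision=4)
logits = jnp.array([[1.2, -0.8, -0.5], [0.9, -1.2, 1.1]])
labels = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
# Per-example cross entropy; prints roughly [0.2761 2.9518] at precision=4.
print(optax.softmax_cross_entropy(logits, labels))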
optax/losses/_self_supervised.py (3 changes: 2 additions & 1 deletion)
@@ -133,13 +133,14 @@ def triplet_margin_loss(

 Examples:
 >>> import jax.numpy as jnp, optax, chex
+>>> jnp.set_printoptions(precision=4)
 >>> anchors = jnp.array([[0.0, 0.0], [1.0, 1.0]])
 >>> positives = jnp.array([[0.1, 0.1], [1.1, 1.1]])
 >>> negatives = jnp.array([[1.0, 0.0], [0.0, 1.0]])
 >>> output = optax.losses.triplet_margin_loss(anchors, positives, negatives,
 ... margin=1.0)
 >>> print(output)
-[0.14142442 0.14142442]
+[0.1414 0.1414]

 Args:
 anchors: An array of anchor embeddings, with shape [batch, feature_dim].
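As a sanity check on the shortened expected output `[0.1414 0.1414]`, the value can be reproduced by hand assuming the standard triplet formulation max(d(a, p) - d(a, n) + margin, 0) with Euclidean distance (a sketch, not the library's implementation):

import jax.numpy as jnp

anchors = jnp.array([[0.0, 0.0], [1.0, 1.0]])
positives = jnp.array([[0.1, 0.1], [1.1, 1.1]])
negatives = jnp.array([[1.0, 0.0], [0.0, 1.0]])

d_pos = jnp.linalg.norm(anchors - positives, axis=-1)  # sqrt(0.02) ~ 0.1414
d_neg = jnp.linalg.norm(anchors - negatives, axis=-1)  # 1.0 for both rows
# max(0.1414 - 1.0 + 1.0, 0) = 0.1414, matching the doctest value.
print(jnp.maximum(d_pos - d_neg + 1.0, 0.0))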
optax/projections/_projections.py (12 changes: 6 additions & 6 deletions)
@@ -148,10 +148,10 @@ def projection_simplex(tree: Any, scale: chex.Numeric = 1.0) -> Any:
 >>> import jax.numpy as jnp
 >>> from optax import tree_utils, projections
 >>> tree = {"w": jnp.array([2.5, 3.2]), "b": 0.5}
->>> tree_utils.tree_sum(tree)
+>>> print(tree_utils.tree_sum(tree))
 6.2
 >>> new_tree = projections.projection_simplex(tree)
->>> tree_utils.tree_sum(new_tree)
+>>> print(tree_utils.tree_sum(new_tree))
 1.0000002

 .. versionadded:: 0.2.3
@@ -212,11 +212,11 @@ def projection_l1_ball(tree: Any, scale: float = 1.0) -> Any:
 >>> import jax.numpy as jnp
 >>> from optax import tree_utils, projections
 >>> tree = {"w": jnp.array([2.5, 3.2]), "b": 0.5}
->>> tree_utils.tree_l1_norm(tree)
-Array(6.2, dtype=float32)
+>>> print(tree_utils.tree_l1_norm(tree))
+6.2
 >>> new_tree = projections.projection_l1_ball(tree)
->>> tree_utils.tree_l1_norm(new_tree)
-Array(1.0000002, dtype=float32)
+>>> print(tree_utils.tree_l1_norm(new_tree))
+1.0000002

 .. versionadded:: 0.2.4
 """
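The projections doctests switch from echoing bare expressions, whose output is the full `Array(..., dtype=float32)` repr, to `print(...)`, which emits only the scalar value and keeps the expected output independent of repr formatting. A sketch of the updated pattern outside a docstring (printed values are subject to float32 rounding):

import jax.numpy as jnp
from optax import tree_utils, projections

tree = {"w": jnp.array([2.5, 3.2]), "b": 0.5}
# print() shows just the value, e.g. 6.2, rather than Array(6.2, dtype=float32).
print(tree_utils.tree_l1_norm(tree))
new_tree = projections.projection_l1_ball(tree)
# After projection onto the l1 ball the norm is ~1.0 (1.0000002 in the doctest).
print(tree_utils.tree_l1_norm(new_tree))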