diff --git a/optax/losses/_classification.py b/optax/losses/_classification.py index 68bda2146..4fb8eaa5f 100644 --- a/optax/losses/_classification.py +++ b/optax/losses/_classification.py @@ -159,8 +159,7 @@ def softmax_cross_entropy( distributions, with shape `[...]`. """ chex.assert_type([logits], float) - log_probs = jax.nn.log_softmax(logits, axis=-1) - return -jnp.where(labels == 0, 0, labels * log_probs).sum(axis=-1) + return -jnp.sum(labels * jax.nn.log_softmax(logits, axis=-1), axis=-1) def softmax_cross_entropy_with_integer_labels( diff --git a/optax/losses/_classification_test.py b/optax/losses/_classification_test.py index 1ad7d9db2..7d77906eb 100644 --- a/optax/losses/_classification_test.py +++ b/optax/losses/_classification_test.py @@ -29,37 +29,10 @@ class SoftmaxCrossEntropyTest(parameterized.TestCase): def setUp(self): super().setUp() - self.ys = np.array( - [ - [10.0, 1.0, -2.0], - [1.0, 4.0, 0.2], - [-np.inf, 0.0, 0.0], - [-np.inf, 0.0, 0.0], - [-np.inf, 0.0, -np.inf], - ], - dtype=np.float32, - ) - self.ts = np.array( - [ - [0.0, 1.0, 0.0], - [1.0, 0.0, 0.0], - [0.0, 0.5, 0.5], - [0.4, 0.3, 0.3], - [0.0, 1.0, 0.0], - ], - dtype=np.float32, - ) + self.ys = np.array([[10., 1., -2.], [1., 4., 0.2]], dtype=np.float32) + self.ts = np.array([[0., 1., 0.], [1., 0., 0.]], dtype=np.float32) # taken expected outputs from rlax. - self.exp = np.array( - [ - 9.00013, - 3.0696733, - 0.693147, - np.inf, - 0.0, - ], - dtype=np.float32, - ) + self.exp = np.array([9.00013, 3.0696733], dtype=np.float32) @chex.all_variants def test_scalar(self):