Skip to content

Commit

Permalink
Fix/h legendre tests (#529)
Browse files Browse the repository at this point in the history
* Handle `axis=None` in `lse`'s JVP

* Add `h_legendre` test

* Add more tests for `h_transform`

* Add more h_legendre tests

* Increase tolerance for `h_transform`

* Increase rtol/atol

* Remove old file
  • Loading branch information
michalk8 authored Apr 30, 2024
1 parent 62ab665 commit edbc621
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/ott/math/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,11 @@ def logsumexp_jvp(axis, keepdims, return_sign, primals, tangents):
if return_sign:
lse, sign = lse
lse = jnp.where(jnp.isfinite(lse), lse, 0.0)
centered_exp = jnp.exp(mat - jnp.expand_dims(lse, axis=axis))

if axis is not None:
centered_exp = jnp.exp(mat - jnp.expand_dims(lse, axis=axis))
else:
centered_exp = jnp.exp(mat - lse)

if b is None:
res = jnp.sum(centered_exp * tan_mat, axis=axis, keepdims=keepdims)
Expand Down
63 changes: 63 additions & 0 deletions tests/geometry/costs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import numpy as np

from ott.geometry import costs, pointcloud
from ott.math import utils as mu
from ott.solvers import linear

try:
Expand Down Expand Up @@ -128,6 +129,45 @@ def test_bures(self, rng: jax.Array):
np.testing.assert_equal(diffs.shape[0], max_iterations // inner_iterations)


class TestTIRegCost:

@pytest.mark.parametrize(
"cost_fn", [
costs.SqPNorm(p=1.0),
costs.SqPNorm(2.4),
costs.PNormP(p=1.1),
costs.PNormP(1.3),
costs.SqEuclidean()
]
)
def test_h_legendre(self, rng: jax.Array, cost_fn: costs.TICost):
x = jax.random.normal(rng, (15, 3))
h_transform = cost_fn.h_transform(mu.logsumexp)
h_transform = jax.jit(jax.vmap(jax.grad(h_transform)))

np.testing.assert_array_equal(jnp.isfinite(h_transform(x)), True)

@pytest.mark.parametrize("ridge", [1e-12, 1e-6])
def test_h_legendre_sqeucl(self, rng: jax.Array, ridge: float):
n, d = 12, 4
rngs = jax.random.split(rng, 2)
u = jnp.abs(jax.random.uniform(rngs[0], (d,)))
x = jax.random.normal(rngs[1], (n, d))

sqeucl = costs.SqEuclidean()
el_l2 = costs.ElasticL2(scaling_reg=0.0)

h_concave = lambda z: 0.5 * (-sqeucl.h(z) + jnp.dot(z, u))
h_concave_half = lambda z: -sqeucl.h(z) + jnp.dot(z, u)

pred = jax.jit(
jax.vmap(jax.grad(sqeucl.h_transform(h_concave, ridge=ridge)))
)
gt = jax.jit(jax.vmap(jax.grad(el_l2.h_transform(h_concave_half))))

np.testing.assert_allclose(pred(x), gt(x), rtol=1e-5, atol=1e-5)


@pytest.mark.fast()
class TestRegTICost:

Expand Down Expand Up @@ -216,6 +256,29 @@ def test_stronger_regularization_increases_sparsity(
for fwd in [False, True]:
np.testing.assert_array_equal(np.diff(sparsity[fwd]) > 0.0, True)

@pytest.mark.parametrize("d", [5, 10])
def test_h_legendre_elastic_l2(self, rng: jax.Array, d: int):
n, d = 13, d
rngs = jax.random.split(rng, 2)
x = jax.random.normal(rngs[0], (n, d))
u = jax.random.normal(rngs[1], (d,))

elastic_l2 = costs.ElasticL2(scaling_reg=0.0)
p_norm_p = costs.PNormP(p=2)

concave_fn = lambda z: -elastic_l2.h(z) + jnp.dot(z, u)

p_grad_h = jax.jit(
jax.vmap(jax.grad(p_norm_p.h_transform(concave_fn, tol=1e-5)))
)
elastic_grad_h = jax.vmap(
jax.grad(elastic_l2.h_transform(concave_fn, tol=1e-5))
)

np.testing.assert_allclose(
elastic_grad_h(x), p_grad_h(x), rtol=1e-4, atol=1e-4
)


@pytest.mark.skipif(ts_metrics is None, reason="Not supported for Python 3.11")
@pytest.mark.fast()
Expand Down

0 comments on commit edbc621

Please sign in to comment.