From baf42361216d4845dc14ed863ee6a59651433e11 Mon Sep 17 00:00:00 2001 From: marcocuturi Date: Tue, 30 Jul 2024 11:52:34 +0900 Subject: [PATCH 1/2] correct convergence criterion for GW --- src/ott/solvers/quadratic/gromov_wasserstein.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/ott/solvers/quadratic/gromov_wasserstein.py b/src/ott/solvers/quadratic/gromov_wasserstein.py index 13bfc0f47..70be4c73a 100644 --- a/src/ott/solvers/quadratic/gromov_wasserstein.py +++ b/src/ott/solvers/quadratic/gromov_wasserstein.py @@ -261,7 +261,8 @@ def __call__( linear_state = out.linear_state.set_cost(linearization, True, True) iteration = jnp.sum(out.costs != -1) converged = jnp.logical_and( - iteration < self.max_iterations, jnp.all(out.linear_convergence) + iteration < self.max_iterations, + jnp.nanmean(out.linear_convergence) == 1.0 ) return out.set( linear_state=linear_state, geom=linearization.geom, converged=converged @@ -294,7 +295,7 @@ def init_state( return GWState( costs=-jnp.ones((num_iter,)), - linear_convergence=-jnp.ones((num_iter,)), + linear_convergence=jnp.zeros((num_iter,)) * jnp.nan, linear_state=linear_state, linear_pb=init, old_transport_mass=transport_mass, From 71bbd29947f9b4b37351e6958ba958a1c8b942b1 Mon Sep 17 00:00:00 2001 From: marcocuturi Date: Tue, 30 Jul 2024 22:12:38 +0900 Subject: [PATCH 2/2] fix following comment --- src/ott/solvers/quadratic/gromov_wasserstein.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ott/solvers/quadratic/gromov_wasserstein.py b/src/ott/solvers/quadratic/gromov_wasserstein.py index 70be4c73a..5b0f122be 100644 --- a/src/ott/solvers/quadratic/gromov_wasserstein.py +++ b/src/ott/solvers/quadratic/gromov_wasserstein.py @@ -295,7 +295,7 @@ def init_state( return GWState( costs=-jnp.ones((num_iter,)), - linear_convergence=jnp.zeros((num_iter,)) * jnp.nan, + linear_convergence=jnp.full((num_iter,), fill_value=jnp.nan), linear_state=linear_state, linear_pb=init, old_transport_mass=transport_mass,