Feature/refactor initializers (#599)
* Refactor initializers

* Rename type alias

* Fix new initializers

* Update init in the notebook

* Fix rest of the tests and types

* Fix docs phrasing

* Better docs.
michalk8 authored Nov 26, 2024
1 parent b479e5f commit 2ffd45f
Showing 16 changed files with 158 additions and 404 deletions.
10 changes: 7 additions & 3 deletions docs/tutorials/neural/400_MetaOT.ipynb
@@ -582,15 +582,19 @@
 "    ot_problem = linear_problem.LinearProblem(geom, a=a, b=b)\n",
 "    solver = sinkhorn.Sinkhorn(**sink_kwargs)\n",
 "\n",
-"    base_sink_out = solver(ot_problem, init=(None, None))\n",
+"    base_sink_out = solver(ot_problem, init=None)\n",
 "\n",
 "    init_dual_a = meta_initializer.init_dual_a(ot_problem, lse_mode=True)\n",
-"    meta_sink_out = solver(ot_problem, init=(init_dual_a, None))\n",
+"    meta_sink_out = solver(\n",
+"        ot_problem, init=(init_dual_a, jnp.zeros_like(init_dual_a))\n",
+"    )\n",
 "\n",
 "    init_dual_a = initializers.GaussianInitializer().init_dual_a(\n",
 "        ot_problem, lse_mode=True\n",
 "    )\n",
-"    gaus_sink_out = solver(ot_problem, init=(init_dual_a, None))\n",
+"    gaus_sink_out = solver(\n",
+"        ot_problem, init=(init_dual_a, jnp.zeros_like(init_dual_a))\n",
+"    )\n",
 "\n",
 "    error_log[\"base\"].append(base_sink_out.errors)\n",
 "    error_log[\"meta_ot\"].append(meta_sink_out.errors)\n",
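For context, a minimal sketch of the calling convention this hunk migrates to: `init=None` picks the solver's default initialization, while a warm start must now supply both potentials explicitly. The toy point cloud below is illustrative, not part of the notebook.

import jax
import jax.numpy as jnp

from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

rng = jax.random.PRNGKey(0)
x = jax.random.normal(rng, (32, 2))
geom = pointcloud.PointCloud(x, x + 1.0)
ot_problem = linear_problem.LinearProblem(geom)
solver = sinkhorn.Sinkhorn()

# Default initialization: `init=None` replaces the old `init=(None, None)`.
base_out = solver(ot_problem, init=None)

# Warm start: both potentials must now be provided, so the second one is
# filled in with zeros, mirroring the notebook change above.
f = jnp.zeros(geom.shape[0])
warm_out = solver(ot_problem, init=(f, jnp.zeros(geom.shape[1])))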
45 changes: 16 additions & 29 deletions src/ott/initializers/linear/initializers.py
@@ -32,7 +32,7 @@ class SinkhornInitializer(abc.ABC):
   """Base class for Sinkhorn initializers."""
 
   @abc.abstractmethod
-  def init_dual_a(
+  def init_fu(
       self,
       ot_prob: linear_problem.LinearProblem,
       lse_mode: bool,
@@ -50,7 +50,7 @@ def init_dual_a(
     """
 
   @abc.abstractmethod
-  def init_dual_b(
+  def init_gv(
       self,
       ot_prob: linear_problem.LinearProblem,
       lse_mode: bool,
@@ -70,8 +70,6 @@ def init_dual_b(
   def __call__(
       self,
       ot_prob: linear_problem.LinearProblem,
-      a: Optional[jnp.ndarray],
-      b: Optional[jnp.ndarray],
       lse_mode: bool,
       rng: Optional[jax.Array] = None,
   ) -> Tuple[jnp.ndarray, jnp.ndarray]:
@@ -90,25 +88,15 @@ def __call__(
       The initial potentials/scalings.
     """
     rng = utils.default_prng_key(rng)
-    rng_x, rng_y = jax.random.split(rng, 2)
-    n, m = ot_prob.geom.shape
-    if a is None:
-      a = self.init_dual_a(ot_prob, lse_mode=lse_mode, rng=rng_x)
-    if b is None:
-      b = self.init_dual_b(ot_prob, lse_mode=lse_mode, rng=rng_y)
-
-    assert a.shape == (
-        n,
-    ), f"Expected `f_u` to have shape `{n,}`, found `{a.shape}`."
-    assert b.shape == (
-        m,
-    ), f"Expected `g_v` to have shape `{m,}`, found `{b.shape}`."
+    rng_f, rng_g = jax.random.split(rng, 2)
+    fu = self.init_fu(ot_prob, lse_mode=lse_mode, rng=rng_f)
+    gv = self.init_gv(ot_prob, lse_mode=lse_mode, rng=rng_g)
 
     # cancel dual variables for zero weights
-    a = jnp.where(ot_prob.a > 0.0, a, -jnp.inf if lse_mode else 0.0)
-    b = jnp.where(ot_prob.b > 0.0, b, -jnp.inf if lse_mode else 0.0)
-
-    return a, b
+    mask_value = -jnp.inf if lse_mode else 0.0
+    fu = jnp.where(ot_prob.a > 0.0, fu, mask_value)
+    gv = jnp.where(ot_prob.b > 0.0, gv, mask_value)
+    return fu, gv
 
   def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:  # noqa: D102
     return [], {}
@@ -124,7 +112,7 @@ def tree_unflatten(  # noqa: D102
 class DefaultInitializer(SinkhornInitializer):
   """Default initialization of Sinkhorn dual potentials/primal scalings."""
 
-  def init_dual_a(  # noqa: D102
+  def init_fu(  # noqa: D102
       self,
       ot_prob: linear_problem.LinearProblem,
       lse_mode: bool,
@@ -133,7 +121,7 @@ def init_dual_a(  # noqa: D102
     del rng
     return jnp.zeros_like(ot_prob.a) if lse_mode else jnp.ones_like(ot_prob.a)
 
-  def init_dual_b(  # noqa: D102
+  def init_gv(  # noqa: D102
       self,
       ot_prob: linear_problem.LinearProblem,
       lse_mode: bool,
@@ -154,7 +142,7 @@ class GaussianInitializer(DefaultInitializer):
   to initialize Sinkhorn potentials/scalings.
   """
 
-  def init_dual_a(  # noqa: D102
+  def init_fu(  # noqa: D102
       self,
       ot_prob: linear_problem.LinearProblem,
       lse_mode: bool,
@@ -241,7 +229,7 @@ def cond_fn(state: Tuple[jnp.ndarray, float, int]) -> bool:
 
     return f_potential
 
-  def init_dual_a(
+  def init_fu(
       self,
       ot_prob: linear_problem.LinearProblem,
       lse_mode: bool,
@@ -304,9 +292,8 @@ class SubsampleInitializer(DefaultInitializer):
       :class:`~ott.geometry.pointcloud.PointCloud`.
     subsample_n_y: number of points to subsample from the second measure in
       :class:`~ott.geometry.pointcloud.PointCloud`.
-      If ``None``, use ``subsample_n_x``.
-    kwargs: Keyword arguments for
-      :class:`~ott.solvers.linear.sinkhorn.Sinkhorn`.
+      If :obj:`None`, use ``subsample_n_x``.
+    kwargs: Keyword arguments for :func:`~ott.solvers.linear.solve`.
   """
 
   def __init__(
@@ -320,7 +307,7 @@ def __init__(
     self.subsample_n_y = subsample_n_y or subsample_n_x
     self.sinkhorn_kwargs = kwargs
 
-  def init_dual_a(  # noqa: D102
+  def init_fu(  # noqa: D102
       self,
       ot_prob: linear_problem.LinearProblem,
       lse_mode: bool,
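A hypothetical usage sketch of the renamed API in this file, against a toy problem of my own: `init_dual_a`/`init_dual_b` become `init_fu`/`init_gv`, and `__call__` no longer accepts precomputed `a`/`b` overrides.

import jax

from ott.geometry import pointcloud
from ott.initializers.linear import initializers
from ott.problems.linear import linear_problem

rng = jax.random.PRNGKey(0)
x = jax.random.normal(rng, (16, 2))
ot_prob = linear_problem.LinearProblem(pointcloud.PointCloud(x, x + 0.5))

init = initializers.GaussianInitializer()
# Renamed: `init_fu` was `init_dual_a`, `init_gv` was `init_dual_b`.
f_u = init.init_fu(ot_prob, lse_mode=True)
# `__call__` now always derives both potentials itself, then masks the
# entries corresponding to zero weights.
f_u, g_v = init(ot_prob, lse_mode=True, rng=rng)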
69 changes: 8 additions & 61 deletions src/ott/initializers/linear/initializers_lr.py
@@ -38,8 +38,7 @@
 if TYPE_CHECKING:
   from ott.problems.linear import linear_problem
   from ott.problems.quadratic import quadratic_problem
-  from ott.solvers.linear import sinkhorn, sinkhorn_lr
-  from ott.solvers.quadratic import gromov_wasserstein_lr
+  from ott.solvers.linear import sinkhorn
 
 Problem_t = Union["linear_problem.LinearProblem",
                   "quadratic_problem.QuadraticProblem"]
@@ -96,7 +95,7 @@ def init_r(
     """Initialize the low-rank factor :math:`R`.
 
     Args:
-      ot_prob: Linear OT problem.
+      ot_prob: OT problem.
       rng: Random key for seeding.
       init_g: Initial value for :math:`g` factor.
       kwargs: Additional keyword arguments.
@@ -123,65 +122,16 @@ def init_g(
       Array of shape ``[rank,]``.
     """
 
-  @classmethod
-  def from_solver(
-      cls,
-      solver: Union["sinkhorn_lr.LRSinkhorn",
-                    "gromov_wasserstein_lr.LRGromovWasserstein"],
-      *,
-      kind: Literal["random", "rank2", "k-means", "generalized-k-means"],
-      **kwargs: Any,
-  ) -> "LRInitializer":
-    """Create a low-rank initializer from a linear or quadratic solver.
-
-    Args:
-      solver: Low-rank linear or quadratic solver.
-      kind: Which initializer to instantiate.
-      kwargs: Keyword arguments when creating the initializer.
-
-    Returns:
-      Low-rank initializer.
-    """
-    rank = solver.rank
-    sinkhorn_kwargs = {
-        "norm_error": solver._norm_error,
-        "lse_mode": solver.lse_mode,
-        "implicit_diff": solver.implicit_diff,
-        "use_danskin": solver.use_danskin
-    }
-
-    if kind == "random":
-      return RandomInitializer(rank, **kwargs)
-    if kind == "rank2":
-      return Rank2Initializer(rank, **kwargs)
-    if kind == "k-means":
-      return KMeansInitializer(rank, sinkhorn_kwargs=sinkhorn_kwargs, **kwargs)
-    if kind == "generalized-k-means":
-      return GeneralizedKMeansInitializer(
-          rank, sinkhorn_kwargs=sinkhorn_kwargs, **kwargs
-      )
-    raise NotImplementedError(f"Initializer `{kind}` is not implemented.")
-
   def __call__(
       self,
       ot_prob: Problem_t,
-      q: Optional[jnp.ndarray] = None,
-      r: Optional[jnp.ndarray] = None,
-      g: Optional[jnp.ndarray] = None,
       *,
       rng: Optional[jax.Array] = None,
-      **kwargs: Any
+      **kwargs: Any,
   ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
     """Initialize the factors :math:`Q`, :math:`R` and :math:`g`.
 
     Args:
       ot_prob: OT problem.
-      q: Factor of shape ``[n, rank]``. If `None`, it will be initialized
-        using :meth:`init_q`.
-      r: Factor of shape ``[m, rank]``. If `None`, it will be initialized
-        using :meth:`init_r`.
-      g: Factor of shape ``[rank,]``. If `None`, it will be initialized
-        using :meth:`init_g`.
       rng: Random key for seeding.
       kwargs: Additional keyword arguments for :meth:`init_q`, :meth:`init_r`
         and :meth:`init_g`.
@@ -190,14 +140,11 @@ def __call__(
       The factors :math:`Q`, :math:`R` and :math:`g`, respectively.
     """
     rng = utils.default_prng_key(rng)
-    rng1, rng2, rng3 = jax.random.split(rng, 3)
+    rng_g, rng_q, rng_r = jax.random.split(rng, 3)
 
-    if g is None:
-      g = self.init_g(ot_prob, rng1, **kwargs)
-    if q is None:
-      q = self.init_q(ot_prob, rng2, init_g=g, **kwargs)
-    if r is None:
-      r = self.init_r(ot_prob, rng3, init_g=g, **kwargs)
+    g = self.init_g(ot_prob, rng_g, **kwargs)
+    q = self.init_q(ot_prob, rng_q, init_g=g, **kwargs)
+    r = self.init_r(ot_prob, rng_r, init_g=g, **kwargs)
 
     assert g.shape == (self.rank,)
     assert q.shape == (ot_prob.a.shape[0], self.rank)
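Since `from_solver` is gone, a sketch of the replacement under my reading of this diff: construct the low-rank initializer directly and hand it to the solver. The `initializer=` argument on `LRSinkhorn` is an assumption about the surrounding release, not shown in this hunk.

import jax

from ott.geometry import pointcloud
from ott.initializers.linear import initializers_lr
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn_lr

rng = jax.random.PRNGKey(0)
x = jax.random.normal(rng, (64, 3))
ot_prob = linear_problem.LinearProblem(pointcloud.PointCloud(x, x + 0.5))

# Previously: LRInitializer.from_solver(solver, kind="k-means").
initializer = initializers_lr.KMeansInitializer(rank=4)
# The `q`, `r` and `g` overrides are gone; factors are always computed.
q, r, g = initializer(ot_prob, rng=rng)

solver = sinkhorn_lr.LRSinkhorn(rank=4, initializer=initializer)
out = solver(ot_prob)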
5 changes: 3 additions & 2 deletions src/ott/initializers/neural/meta_initializer.py
@@ -16,6 +16,7 @@
 
 import jax
 import jax.numpy as jnp
+import jax.tree_util as jtu
 
 import optax
 from flax import linen as nn
@@ -31,7 +32,7 @@
 __all__ = ["MetaInitializer"]
 
 
-@jax.tree_util.register_pytree_node_class
+@jtu.register_pytree_node_class
 class MetaInitializer(initializers.DefaultInitializer):
   """Meta OT Initializer with a fixed geometry :cite:`amos:22`.
@@ -133,7 +134,7 @@ def update(
     """
     return self.update_impl(state, a, b)
 
-  def init_dual_a(  # noqa: D102
+  def init_fu(  # noqa: D102
       self,
       ot_prob: linear_problem.LinearProblem,
       lse_mode: bool,
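A sketch of the renamed meta-initializer entry point; the constructor defaults (built-in meta model and optimizer) are assumed rather than shown in this hunk, and the toy geometry is illustrative.

import jax

from ott.geometry import pointcloud
from ott.initializers.neural import meta_initializer
from ott.problems.linear import linear_problem

rng = jax.random.PRNGKey(0)
x = jax.random.normal(rng, (10, 2))
geom = pointcloud.PointCloud(x, x + 1.0)
ot_prob = linear_problem.LinearProblem(geom)

meta_init = meta_initializer.MetaInitializer(geom)
# Renamed: `init_fu` was `init_dual_a`; `init_gv` falls back to the
# DefaultInitializer behavior this class inherits.
f_u = meta_init.init_fu(ot_prob, lse_mode=True)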
30 changes: 11 additions & 19 deletions src/ott/initializers/quadratic/initializers.py
@@ -14,8 +14,8 @@
 import abc
 from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple
 
-import jax
 import jax.numpy as jnp
+import jax.tree_util as jtu
 
 from ott.geometry import geometry
 
@@ -26,16 +26,9 @@
 __all__ = ["BaseQuadraticInitializer", "QuadraticInitializer"]
 
 
-@jax.tree_util.register_pytree_node_class
+@jtu.register_pytree_node_class
 class BaseQuadraticInitializer(abc.ABC):
-  """Base class for quadratic initializers.
-
-  Args:
-    kwargs: Keyword arguments.
-  """
-
-  def __init__(self, **kwargs: Any):
-    self._kwargs = kwargs
+  """Base class for quadratic initializers."""
 
   def __call__(
       self, quad_prob: "quadratic_problem.QuadraticProblem", **kwargs: Any
@@ -47,7 +40,7 @@ def __call__(
       kwargs: Additional keyword arguments.
 
     Returns:
-      Linear problem.
+      The linearized problem.
     """
     from ott.problems.linear import linear_problem
 
@@ -80,7 +73,7 @@ def _create_geometry(
     """
 
   def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:  # noqa: D102
-    return [], self._kwargs
+    return [], {}
 
   @classmethod
   def tree_unflatten(  # noqa: D102
@@ -89,6 +82,7 @@ def tree_unflatten(  # noqa: D102
   ) -> "BaseQuadraticInitializer":
     return cls(*children, **aux_data)
 
 
+@jtu.register_pytree_node_class
 class QuadraticInitializer(BaseQuadraticInitializer):
   r"""Initialize a linear problem locally around a selected coupling.
@@ -125,10 +119,8 @@ class QuadraticInitializer(BaseQuadraticInitializer):
     defaults to the product coupling :math:`ab^T`.
   """
 
-  def __init__(
-      self, init_coupling: Optional[jnp.ndarray] = None, **kwargs: Any
-  ):
-    super().__init__(**kwargs)
+  def __init__(self, init_coupling: Optional[jnp.ndarray] = None):
+    super().__init__()
     self.init_coupling = init_coupling
 
   def _create_geometry(
@@ -145,10 +137,10 @@ def _create_geometry(
       quad_prob: Quadratic OT problem.
       epsilon: Epsilon regularization.
       relative_epsilon: Flag, use `relative_epsilon` or not in geometry.
-      kwargs: Keyword arguments for :class:`~ott.geometry.geometry.Geometry`.
+      kwargs: Unused.
 
     Returns:
-      The initial geometry used to initialize the linearized problem.
+      Geometry used to initialize the linearized problem.
     """
     from ott.problems.quadratic import quadratic_problem
 
@@ -188,4 +180,4 @@ def _create_geometry(
     )
 
   def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:  # noqa: D102
-    return [self.init_coupling], self._kwargs
+    return [self.init_coupling], {}
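A usage sketch of the slimmed-down quadratic initializer after this change: the base class no longer stores `**kwargs`, and `QuadraticInitializer` keeps only `init_coupling`. Passing `epsilon` through `__call__` follows the `_create_geometry` signature documented above; the toy problem is illustrative.

import jax

from ott.geometry import pointcloud
from ott.initializers.quadratic import initializers as quad_init
from ott.problems.quadratic import quadratic_problem

rng_x, rng_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(rng_x, (20, 3))
y = jax.random.normal(rng_y, (30, 4))
quad_prob = quadratic_problem.QuadraticProblem(
    pointcloud.PointCloud(x, x), pointcloud.PointCloud(y, y)
)

# Product coupling ab^T, the documented default when `init_coupling=None`.
coupling = quad_prob.a[:, None] * quad_prob.b[None, :]
initializer = quad_init.QuadraticInitializer(init_coupling=coupling)
lin_prob = initializer(quad_prob, epsilon=1.0)  # the linearized problem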
(Diffs for the remaining 11 changed files are not shown.)
