diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 1c6d25843..93424e7cb 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -16,8 +16,6 @@ jobs: fast-tests: name: Fast tests Python ${{ matrix.python-version }} ${{ matrix.jax-version }} runs-on: ubuntu-latest - # allow tests using the latest JAX to fail - continue-on-error: ${{ matrix.jax-version == 'jax-latest' }} strategy: fail-fast: false matrix: diff --git a/docs/geometry.rst b/docs/geometry.rst index 407d8fca0..932278147 100644 --- a/docs/geometry.rst +++ b/docs/geometry.rst @@ -49,7 +49,7 @@ Geometries low_rank.LRCGeometry low_rank.LRKGeometry epsilon_scheduler.Epsilon - epsilon_scheduler.DEFAULT_SCALE + epsilon_scheduler.DEFAULT_EPSILON_SCALE Cost Functions -------------- diff --git a/pyproject.toml b/pyproject.toml index d1014be4d..dda716bfd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -190,7 +190,7 @@ skip_missing_interpreters = true extras = test # https://github.com/google/flax/issues/3329 - py{3.9,3.10,3.11,3.12,3.13},py3.10-jax-default: neural + py{3.9,3.10,3.11,3.12},py3.10-jax-default: neural pass_env = CUDA_*,PYTEST_*,CI commands_pre = gpu: python -I -m pip install "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html diff --git a/src/ott/experimental/mmsinkhorn.py b/src/ott/experimental/mmsinkhorn.py index 39a30eef3..ed9c27d77 100644 --- a/src/ott/experimental/mmsinkhorn.py +++ b/src/ott/experimental/mmsinkhorn.py @@ -305,7 +305,7 @@ def __call__( cost_t = cost_tensor(x_s, cost_fns) state = self.init_state(n_s) if epsilon is None: - epsilon = epsilon_scheduler.DEFAULT_SCALE * jnp.std(cost_t) + epsilon = epsilon_scheduler.DEFAULT_EPSILON_SCALE * jnp.std(cost_t) const = cost_t, a_s, epsilon out = run(const, self, state) return out.set(x_s=x_s, a_s=a_s, cost_fns=cost_fns, epsilon=epsilon) diff --git a/src/ott/geometry/epsilon_scheduler.py b/src/ott/geometry/epsilon_scheduler.py index 209c45c37..1b63fd4a8 100644 --- a/src/ott/geometry/epsilon_scheduler.py +++ b/src/ott/geometry/epsilon_scheduler.py @@ -11,95 +11,63 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional +from typing import Optional -import jax import jax.numpy as jnp +import jax.tree_util as jtu -__all__ = ["Epsilon", "DEFAULT_SCALE"] +__all__ = ["Epsilon", "DEFAULT_EPSILON_SCALE"] #: Scaling applied to statistic (mean/std) of cost to compute default epsilon. -DEFAULT_SCALE = 0.05 +DEFAULT_EPSILON_SCALE = 0.05 -@jax.tree_util.register_pytree_node_class +@jtu.register_pytree_node_class class Epsilon: - """Scheduler class for the regularization parameter epsilon. + r"""Scheduler class for the regularization parameter epsilon. - An epsilon scheduler outputs a regularization strength, to be used by in a - Sinkhorn-type algorithm, at any iteration count. That value is either the - final, targeted regularization, or one that is larger, obtained by - geometric decay of an initial value that is larger than the intended target. - Concretely, the value returned by such a scheduler will consider first - the max between ``target`` and ``init * target * decay ** iteration``. - If the ``scale_epsilon`` parameter is provided, that value is used to - multiply the max computed previously by ``scale_epsilon``. 
+ An epsilon scheduler outputs a regularization strength, to be used by the
+ :term:`Sinkhorn algorithm` or a variant, at any iteration count. That value is
+ either the final, targeted regularization, or one that is larger, obtained by
+ geometric decay of an initial multiplier.

 Args:
- target: the epsilon regularizer that is targeted. If :obj:`None`,
- use :obj:`DEFAULT_SCALE`, currently set at :math:`0.05`.
- scale_epsilon: if passed, used to multiply the regularizer, to rescale it.
- If :obj:`None`, use :math:`1`.
- init: initial value when using epsilon scheduling, understood as multiple
- of target value. if passed, ``int * decay ** iteration`` will be used
- to rescale target.
- decay: geometric decay factor, :math:`<1`.
+ target: The epsilon regularizer that is targeted.
+ init: Initial value when using epsilon scheduling, understood as a multiple
+ of the ``target``, following :math:`\text{init} \cdot \text{decay}^{\text{it}}`.
+ decay: Geometric decay factor, :math:`\leq 1`.
 """

- def __init__(
- self,
- target: Optional[float] = None,
- scale_epsilon: Optional[float] = None,
- init: float = 1.0,
- decay: float = 1.0
- ):
- self._target_init = target
- self._scale_epsilon = scale_epsilon
- self._init = init
- self._decay = decay
+ def __init__(self, target: jnp.ndarray, init: float = 1.0, decay: float = 1.0):
+ assert decay <= 1.0, f"Decay must be <= 1, found {decay}."
+ self.target = target
+ self.init = init
+ self.decay = decay

- @property
- def target(self) -> float:
- """Return the final regularizer value of scheduler."""
- target = DEFAULT_SCALE if self._target_init is None else self._target_init
- scale = 1.0 if self._scale_epsilon is None else self._scale_epsilon
- return scale * target
+ def __call__(self, it: Optional[int]) -> jnp.ndarray:
+ """Intermediate regularizer value at a given iteration number.

- def at(self, iteration: Optional[int] = 1) -> float:
- """Return (intermediate) regularizer value at a given iteration."""
- if iteration is None:
+ Args:
+ it: Current iteration. If :obj:`None`, return :attr:`target`.
+
+ Returns:
+ The epsilon value at the iteration.
+ """
+ if it is None:
 return self.target
- # check the decay is smaller than 1.0.
- decay = jnp.minimum(self._decay, 1.0)
 # the multiple is either 1.0 or a larger init value that is decayed.
- multiple = jnp.maximum(self._init * (decay ** iteration), 1.0)
+ multiple = jnp.maximum(self.init * (self.decay ** it), 1.0)
 return multiple * self.target

- def done(self, eps: float) -> bool:
- """Return whether the scheduler is done at a given value."""
- return eps == self.target
-
- def done_at(self, iteration: Optional[int]) -> bool:
- """Return whether the scheduler is done at a given iteration."""
- return self.done(self.at(iteration))
-
- def set(self, **kwargs: Any) -> "Epsilon":
- """Return a copy of self, with potential overwrites."""
- kwargs = {
- "target": self._target_init,
- "scale_epsilon": self._scale_epsilon,
- "init": self._init,
- "decay": self._decay,
- **kwargs
- }
- return Epsilon(**kwargs)
+ def __repr__(self) -> str:
+ return (
+ f"{self.__class__.__name__}(target={self.target:.4f}, "
+ f"init={self.init:.4f}, decay={self.decay:.4f})"
+ )

 def tree_flatten(self): # noqa: D102
- return (
- self._target_init, self._scale_epsilon, self._init, self._decay
- ), None
+ return (self.target,), {"init": self.init, "decay": self.decay}

 @classmethod
 def tree_unflatten(cls, aux_data, children): # noqa: D102
- del aux_data
- return cls(*children)
+ return cls(*children, **aux_data)
diff --git a/src/ott/geometry/geometry.py b/src/ott/geometry/geometry.py
index ab527722e..96b62f798 100644
--- a/src/ott/geometry/geometry.py
+++ b/src/ott/geometry/geometry.py
@@ -20,53 +20,48 @@
 import jax
 import jax.numpy as jnp
 import jax.scipy as jsp
+import jax.tree_util as jtu
 import numpy as np

 from ott import utils
-from ott.geometry import epsilon_scheduler
+from ott.geometry import epsilon_scheduler as eps_scheduler
 from ott.math import utils as mu

 __all__ = ["Geometry"]


-@jax.tree_util.register_pytree_node_class
+@jtu.register_pytree_node_class
 class Geometry:
 r"""Base class to define ground costs/kernels used in optimal transport.

 Optimal transport problems are intrinsically geometric: they compute an
 optimal way to transport mass from one configuration onto another. To define
- what is meant by optimality of transport requires defining a cost, of moving
- mass from one among several sources, towards one out of multiple targets.
- These sources and targets can be provided as points in vectors spaces, grids,
- or more generally exclusively described through a (dissimilarity) cost matrix,
- or almost equivalently, a (similarity) kernel matrix.
-
- Once that cost or kernel matrix is set, the ``Geometry`` class provides a
- basic operations to be run with the Sinkhorn algorithm.
+ what is meant by optimality of transport requires defining a
+ :term:`ground cost`, which quantifies how costly it is to move mass from
+ one among several source locations, towards one out of multiple
+ target locations. These source and target locations can be described as
+ points in vector spaces, grids, or, more generally, through a
+ (dissimilarity) cost matrix or, almost equivalently, a
+ (similarity) kernel matrix. This class describes such a
+ geometry and provides several useful methods to exploit it.

 Args:
 cost_matrix: Cost matrix of shape ``[n, m]``.
 kernel_matrix: Kernel matrix of shape ``[n, m]``.
- epsilon: Regularization parameter. If ``None`` and either
- ``relative_epsilon = True`` or ``relative_epsilon = None`` or
- ``relative_epsilon = str`` where ``str`` can be either ``mean`` or ``std``
- , this value defaults to a multiple of :attr:`std_cost_matrix`
- (or :attr:`mean_cost_matrix` if ``str`` is ``mean``), where that multiple
- is set as ``DEFAULT_SCALE`` in ``epsilon_scheduler.py```.
- If passed as a
- ``float``, then the regularizer that is ultimately used is either that
- ``float`` value (if ``relative_epsilon = False`` or ``None``) or that
- ``float`` times the :attr:`std_cost_matrix` (if
- ``relative_epsilon = True`` or ``relative_epsilon = `std```) or
- :attr:`mean_cost_matrix` (if ``relative_epsilon = `mean```). Look for
- :class:`~ott.geometry.epsilon_scheduler.Epsilon` when passed as a
- scheduler.
- relative_epsilon: when :obj:`False`, the parameter ``epsilon`` specifies the
- value of the entropic regularization parameter. When :obj:`True` or set
- to a string, ``epsilon`` refers to a fraction of the
- :attr:`std_cost_matrix` or :attr:`mean_cost_matrix`, which is computed
- adaptively from data, depending on whether it is set to ``mean`` or
- ``std``.
+ epsilon: Regularization parameter or a scheduler:
+
+ - if ``epsilon = None`` and ``relative_epsilon = None``, use
+ :math:`0.05 \cdot \text{std}(\text{cost matrix})`.
+ - if ``epsilon`` is a :class:`float` and ``relative_epsilon = None``,
+ it directly corresponds to the regularization strength.
+ - otherwise, ``epsilon`` multiplies the :attr:`mean_cost_matrix` or
+ :attr:`std_cost_matrix`, depending on the value of ``relative_epsilon``;
+ if ``epsilon = None``, the multiplier
+ :obj:`DEFAULT_EPSILON_SCALE = 0.05
+ <ott.geometry.epsilon_scheduler.DEFAULT_EPSILON_SCALE>` will be used.
+ relative_epsilon: Whether ``epsilon`` refers to a fraction of the
+ :attr:`mean_cost_matrix` or :attr:`std_cost_matrix`.
 scale_cost: option to rescale the cost matrix. Implemented scalings are
 'median', 'mean', 'std' and 'max_cost'. Alternatively, a float
 factor can be given to rescale the cost such that
 ``cost_matrix /= scale_cost``.
@@ -81,14 +76,14 @@ class Geometry:
 parameter that is meaningful. That parameter can be provided by the
 user, or assigned a default value through a simple rule,
 using for instance the :attr:`mean_cost_matrix` or the :attr:`std_cost_matrix`.
- """
+ """ # noqa: E501

 def __init__(
 self,
 cost_matrix: Optional[jnp.ndarray] = None,
 kernel_matrix: Optional[jnp.ndarray] = None,
- epsilon: Optional[Union[float, epsilon_scheduler.Epsilon]] = None,
- relative_epsilon: Optional[Union[bool, Literal["mean", "std"]]] = None,
+ epsilon: Optional[Union[float, eps_scheduler.Epsilon]] = None,
+ relative_epsilon: Optional[Literal["mean", "std"]] = None,
 scale_cost: Union[float, Literal["mean", "max_cost", "median",
 "std"]] = 1.0,
 src_mask: Optional[jnp.ndarray] = None,
@@ -96,13 +91,8 @@ def __init__(
 ):
 self._cost_matrix = cost_matrix
 self._kernel_matrix = kernel_matrix
-
- # needed for `copy_epsilon`, because of the `isinstance` check
- self._epsilon_init = epsilon if isinstance(
- epsilon, epsilon_scheduler.Epsilon
- ) else epsilon_scheduler.Epsilon(epsilon)
+ self._epsilon_init = epsilon
 self._relative_epsilon = relative_epsilon
-
 self._scale_cost = scale_cost
 self._src_mask = src_mask
@@ -150,7 +140,7 @@ def std_cost_matrix(self) -> float:
 to output :math:`\sigma`.
""" tmp = self._masked_geom().apply_square_cost(self._n_normed_ones).squeeze() - tmp = jnp.sum(tmp * self._m_normed_ones) - (self.mean_cost_matrix) ** 2 + tmp = jnp.sum(tmp * self._m_normed_ones) - (self.mean_cost_matrix ** 2) return jnp.sqrt(jax.nn.relu(tmp)) @property @@ -164,35 +154,36 @@ def kernel_matrix(self) -> jnp.ndarray: return self._kernel_matrix ** self.inv_scale_cost @property - def _epsilon(self) -> epsilon_scheduler.Epsilon: - (target, scale_eps, _, _), _ = self._epsilon_init.tree_flatten() - rel = self._relative_epsilon - - # If nothing passed, default to STD - if rel is None and target is None and scale_eps is None: - scale_eps = jax.lax.stop_gradient(self.std_cost_matrix) - # If instructions passed change, otherwise (notably if False) skip. - elif rel is not None: - if rel == "mean" or rel is True: # Legacy option. - scale_eps = jax.lax.stop_gradient(self.mean_cost_matrix) - elif rel == "std": - scale_eps = jax.lax.stop_gradient(self.std_cost_matrix) - # Avoid 0 std, since this would set epsilon to 0.0 and result in - # a division by 0. - scale_eps = jnp.where(scale_eps <= 0.0, 1.0, scale_eps) - - if isinstance(self._epsilon_init, epsilon_scheduler.Epsilon): - return self._epsilon_init.set(scale_epsilon=scale_eps) - - return epsilon_scheduler.Epsilon( - target=epsilon_scheduler.DEFAULT_SCALE if target is None else target, - scale_epsilon=scale_eps + def epsilon_scheduler(self) -> eps_scheduler.Epsilon: + """Epsilon scheduler.""" + if isinstance(self._epsilon_init, eps_scheduler.Epsilon): + return self._epsilon_init + + # no relative epsilon + if self._relative_epsilon is None: + if self._epsilon_init is not None: + return eps_scheduler.Epsilon(self._epsilon_init) + multiplier = eps_scheduler.DEFAULT_EPSILON_SCALE + scale = jax.lax.stop_gradient(self.std_cost_matrix) + return eps_scheduler.Epsilon(target=multiplier * scale) + + if self._relative_epsilon == "std": + scale = jax.lax.stop_gradient(self.std_cost_matrix) + elif self._relative_epsilon == "mean": + scale = jax.lax.stop_gradient(self.mean_cost_matrix) + else: + raise ValueError(f"Invalid relative epsilon: {self._relative_epsilon}.") + + multiplier = ( + eps_scheduler.DEFAULT_EPSILON_SCALE + if self._epsilon_init is None else self._epsilon_init ) + return eps_scheduler.Epsilon(target=multiplier * scale) @property def epsilon(self) -> float: """Epsilon regularization value.""" - return self._epsilon.target + return self.epsilon_scheduler.target @property def shape(self) -> Tuple[int, int]: @@ -257,20 +248,11 @@ def set_scale_cost(self, scale_cost: Union[float, str]) -> "Geometry": def copy_epsilon(self, other: "Geometry") -> "Geometry": """Copy the epsilon parameters from another geometry.""" - other_epsilon = other._epsilon children, aux_data = self.tree_flatten() - - new_children = [] - for child in children: - if isinstance(child, epsilon_scheduler.Epsilon): - child = child.set( - target=other_epsilon._target_init, - scale_epsilon=other_epsilon._scale_epsilon - ) - new_children.append(child) - - aux_data["relative_epsilon"] = False - return type(self).tree_unflatten(aux_data, new_children) + new_geom = type(self).tree_unflatten(aux_data, children) + new_geom._epsilon_init = other.epsilon_scheduler + new_geom._relative_epsilon = other._relative_epsilon # has no effect + return new_geom # The functions below are at the core of Sinkhorn iterations, they # are implemented here in their default form, either in lse (using directly @@ -412,7 +394,7 @@ def update_potential( Returns: new potential value, g if axis=0, f if 
axis is 1. """ - eps = self._epsilon.at(iteration) + eps = self.epsilon_scheduler(iteration) app_lse = self.apply_lse_kernel(f, g, eps, axis=axis)[0] return eps * log_marginal - jnp.where(jnp.isfinite(app_lse), app_lse, 0) @@ -434,7 +416,7 @@ def update_scaling( Returns: new scaling vector, of size num_b if axis=0, num_a if axis is 1. """ - eps = self._epsilon.at(iteration) + eps = self.epsilon_scheduler(iteration) app_kernel = self.apply_kernel(scaling, eps, axis=axis) return marginal / jnp.where(app_kernel > 0, app_kernel, 1.0) @@ -613,7 +595,6 @@ def apply_cost( app = functools.partial( self._apply_cost_to_vec, axis=axis, fn=fn, is_linear=is_linear ) - # TODO(michalk8): vmap over multiple dims? return jax.vmap(app, in_axes=1, out_axes=1)(arr) def _apply_cost_to_vec( @@ -931,7 +912,7 @@ def tree_flatten(self): # noqa: D102 self._src_mask, self._tgt_mask ), { "scale_cost": self._scale_cost, - "relative_epsilon": self._relative_epsilon + "relative_epsilon": self._relative_epsilon, } @classmethod diff --git a/src/ott/geometry/grid.py b/src/ott/geometry/grid.py index 6df230046..1755b72f6 100644 --- a/src/ott/geometry/grid.py +++ b/src/ott/geometry/grid.py @@ -78,6 +78,7 @@ def __init__( grid_dimension: Optional[int] = None, **kwargs: Any, ): + super().__init__(**kwargs) if ( grid_size is not None and x is not None and num_a is not None and grid_dimension is not None @@ -105,11 +106,10 @@ def __init__( self.kwargs = { "num_a": self.num_a, "grid_size": self.grid_size, - "grid_dimension": self.grid_dimension + "grid_dimension": self.grid_dimension, + "relative_epsilon": self._relative_epsilon, } - super().__init__(**kwargs) - @property def geometries(self) -> List[geometry.Geometry]: """Cost matrices along each dimension of the grid.""" @@ -365,9 +365,8 @@ def tree_flatten(self): # noqa: D102 @classmethod def tree_unflatten(cls, aux_data, children): # noqa: D102 - return cls( - x=children[0], cost_fns=children[1], epsilon=children[2], **aux_data - ) + x, cost_fns, epsilon = children + return cls(x, cost_fns=cost_fns, epsilon=epsilon, **aux_data) def to_LRCGeometry( self, diff --git a/src/ott/geometry/low_rank.py b/src/ott/geometry/low_rank.py index 7a776e17f..ec0d7d7b9 100644 --- a/src/ott/geometry/low_rank.py +++ b/src/ott/geometry/low_rank.py @@ -280,6 +280,7 @@ def tree_flatten(self): # noqa: D102 self._scale_factor, ), { "scale_cost": self._scale_cost, + "relative_epsilon": self._relative_epsilon, } @classmethod @@ -319,7 +320,7 @@ def __init__( epsilon: Optional[float] = None, **kwargs: Any ): - super().__init__(epsilon=epsilon, relative_epsilon=False, **kwargs) + super().__init__(epsilon=epsilon, relative_epsilon=None, **kwargs) self.k1 = k1 self.k2 = k2 diff --git a/src/ott/geometry/pointcloud.py b/src/ott/geometry/pointcloud.py index 7ac9f4425..fafc39b5b 100644 --- a/src/ott/geometry/pointcloud.py +++ b/src/ott/geometry/pointcloud.py @@ -15,6 +15,7 @@ import jax import jax.numpy as jnp +import jax.tree_util as jtu from ott import utils from ott.geometry import costs, geometry, low_rank @@ -23,30 +24,26 @@ __all__ = ["PointCloud"] -@jax.tree_util.register_pytree_node_class +@jtu.register_pytree_node_class class PointCloud(geometry.Geometry): """Defines geometry for 2 point clouds (possibly 1 vs itself). - Creates a geometry, specifying a cost function passed as CostFn type object. - When the number of points is large, setting the ``batch_size`` flag implies - that cost and kernel matrices used to update potentials or scalings - will be recomputed on the fly, rather than stored in memory. 
More precisely,
- when setting ``batch_size``, the cost function will be partially cached by
- storing norm values for each point in both point clouds, but the pairwise cost
- function evaluations won't be.
+ When the number of points is large, setting the :attr:`batch_size` flag
+ implies that cost and kernel matrices used to update potentials or scalings
+ will be recomputed on the fly, rather than stored in memory.

 Args:
- x : n x d array of n d-dimensional vectors
- y : m x d array of m d-dimensional vectors. If `None`, use ``x``.
- cost_fn: a CostFn function between two points in dimension d.
- batch_size: When ``None``, the cost matrix corresponding to that point cloud
- is computed, stored and later re-used at each application of
- :meth:`apply_lse_kernel`. When ``batch_size`` is a positive integer,
- computations are done in an online fashion, namely the cost matrix is
- recomputed at each call of the :meth:`apply_lse_kernel` step,
- ``batch_size`` lines at a time, used on a vector and discarded.
- The online computation is particularly useful for big point clouds
- whose cost matrix does not fit in memory.
+ x: Array of shape ``[n, d]``.
+ y: Array of shape ``[m, d]``. If :obj:`None`, use ``x``.
+ cost_fn: Cost function between two points in dimension :math:`d`.
+ batch_size: If :obj:`None`, the cost matrix corresponding to the
+ point cloud is computed, stored and later re-used at each application of
+ :meth:`apply_lse_kernel`. If ``batch_size`` is a positive integer,
+ computations are instead done in an online fashion: the cost matrix is
+ recomputed at each call of :meth:`apply_lse_kernel`,
+ ``batch_size`` rows at a time, applied to a vector and then discarded.
+ The online computation is particularly useful for big point clouds
+ whose cost matrix does not fit in memory.
 scale_cost: option to rescale the cost matrix. Implemented scalings are
 'median', 'mean', 'max_cost', 'max_norm' and 'max_bound'.
 Alternatively, a float factor can be given to rescale the cost such
@@ -255,7 +252,8 @@ def tree_flatten(self): # noqa: D102
 self.cost_fn,
 ), {
 "batch_size": self._batch_size,
- "scale_cost": self._scale_cost
+ "scale_cost": self._scale_cost,
+ "relative_epsilon": self._relative_epsilon,
 }

 @classmethod
diff --git a/src/ott/initializers/quadratic/initializers.py b/src/ott/initializers/quadratic/initializers.py
index f640921d1..eeaf2d97a 100644
--- a/src/ott/initializers/quadratic/initializers.py
+++ b/src/ott/initializers/quadratic/initializers.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import abc
-from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence, Tuple

 import jax.numpy as jnp
 import jax.tree_util as jtu
@@ -128,7 +128,7 @@ def _create_geometry(
 quad_prob: "quadratic_problem.QuadraticProblem",
 *,
 epsilon: float,
- relative_epsilon: Optional[bool] = None,
+ relative_epsilon: Optional[Literal["mean", "std"]] = None,
 **kwargs: Any,
 ) -> geometry.Geometry:
 """Compute initial geometry for linearization.
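
To make the refactored scheduler's contract concrete, here is a minimal usage sketch, based only on the `Epsilon` API introduced above; the numeric values are illustrative:

```python
import jax.numpy as jnp

from ott.geometry import epsilon_scheduler

# Decay from 10x the target down to the target itself; the value at
# iteration `it` is max(init * decay ** it, 1.0) * target.
eps = epsilon_scheduler.Epsilon(target=jnp.asarray(1e-2), init=10.0, decay=0.5)

for it in range(5):
  print(it, eps(it))  # 0.1, 0.05, 0.025, 0.0125, 0.01 (multiple clamped at 1)
print(eps(None))  # `None` returns the target directly: 0.01
```
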
diff --git a/src/ott/neural/methods/monge_gap.py b/src/ott/neural/methods/monge_gap.py
index 6dbf7c0d9..31da10f55 100644
--- a/src/ott/neural/methods/monge_gap.py
+++ b/src/ott/neural/methods/monge_gap.py
@@ -46,7 +46,7 @@ def monge_gap(
 reference_points: jnp.ndarray,
 cost_fn: Optional[costs.CostFn] = None,
 epsilon: Optional[float] = None,
- relative_epsilon: Optional[bool] = None,
+ relative_epsilon: Optional[Literal["mean", "std"]] = None,
 scale_cost: Union[float, Literal["mean", "max_cost", "median"]] = 1.0,
 return_output: bool = False,
 **kwargs: Any
@@ -111,7 +111,7 @@ def monge_gap_from_samples(
 target: jnp.ndarray,
 cost_fn: Optional[costs.CostFn] = None,
 epsilon: Optional[float] = None,
- relative_epsilon: Optional[bool] = None,
+ relative_epsilon: Optional[Literal["mean", "std"]] = None,
 scale_cost: Union[float, Literal["mean", "max_cost", "median"]] = 1.0,
 return_output: bool = False,
 **kwargs: Any
diff --git a/src/ott/problems/quadratic/quadratic_problem.py b/src/ott/problems/quadratic/quadratic_problem.py
index b16970aa6..a96f4e09c 100644
--- a/src/ott/problems/quadratic/quadratic_problem.py
+++ b/src/ott/problems/quadratic/quadratic_problem.py
@@ -234,7 +234,7 @@ def init_transport_mass(self) -> float:
 def update_lr_geom(
 self,
 lr_sink: "sinkhorn_lr.LRSinkhornOutput",
- relative_epsilon: Optional[bool] = None,
+ relative_epsilon: Optional[Literal["mean", "std"]] = None,
 ) -> geometry.Geometry:
 """Recompute (possibly LRC) linearization using LR Sinkhorn output."""
 marginal_1 = lr_sink.marginal(1)
@@ -270,7 +270,7 @@ def update_linearization(
 transport: Transport,
 epsilon: Optional[float] = None,
 old_transport_mass: float = 1.0,
- relative_epsilon: Optional[bool] = None,
+ relative_epsilon: Optional[Literal["mean", "std"]] = None,
 ) -> linear_problem.LinearProblem:
 """Update linearization of GW problem by updating cost matrix.
@@ -337,7 +337,7 @@ def update_lr_linearization(
 self,
 lr_sink: "sinkhorn_lr.LRSinkhornOutput",
 *,
- relative_epsilon: Optional[bool] = None,
+ relative_epsilon: Optional[Literal["mean", "std"]] = None,
 ) -> linear_problem.LinearProblem:
 """Update a Quad problem linearization using a LR Sinkhorn."""
 return linear_problem.LinearProblem(
diff --git a/src/ott/solvers/linear/discrete_barycenter.py b/src/ott/solvers/linear/discrete_barycenter.py
index 2c1277aa4..513147af7 100644
--- a/src/ott/solvers/linear/discrete_barycenter.py
+++ b/src/ott/solvers/linear/discrete_barycenter.py
@@ -191,7 +191,7 @@ def body_fn(iteration, const, state, compute_error):
 geom, a, weights = const
 errors, d, f_u, g_v = state
- eps = geom._epsilon.at(iteration)  # pylint: disable=protected-access
+ eps = geom.epsilon_scheduler(iteration)
 f_u = parallel_update(f_u, g_v, a, iteration)
 # kernel_f_u stands for K times potential u if running in scaling mode,
 # eps log K exp f / eps in lse mode.
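
The `relative_epsilon` changes above replace the old boolean flag with an explicit statistic name. A small sketch of the three code paths implemented by the new `epsilon_scheduler` property, on illustrative point clouds:

```python
import jax

from ott.geometry import pointcloud

rng1, rng2 = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(rng1, (50, 3))
y = jax.random.normal(rng2, (60, 3))

# 1) epsilon=None, relative_epsilon=None: 0.05 * std of the cost matrix.
geom_default = pointcloud.PointCloud(x, y)

# 2) float epsilon, relative_epsilon=None: used verbatim.
geom_absolute = pointcloud.PointCloud(x, y, epsilon=1e-2)
assert geom_absolute.epsilon == 1e-2

# 3) relative_epsilon set: epsilon multiplies the chosen cost statistic.
geom_relative = pointcloud.PointCloud(
  x, y, epsilon=0.1, relative_epsilon="mean"
)
print(geom_default.epsilon, geom_relative.epsilon)
```
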
diff --git a/src/ott/solvers/quadratic/gromov_wasserstein.py b/src/ott/solvers/quadratic/gromov_wasserstein.py index 5d6b7abe5..6affee0db 100644 --- a/src/ott/solvers/quadratic/gromov_wasserstein.py +++ b/src/ott/solvers/quadratic/gromov_wasserstein.py @@ -15,6 +15,7 @@ Any, Callable, Dict, + Literal, NamedTuple, Optional, Sequence, @@ -181,7 +182,7 @@ def __init__( self, linear_solver: sinkhorn.Sinkhorn, epsilon: float = 1.0, - relative_epsilon: Optional[bool] = None, + relative_epsilon: Optional[Literal["mean", "std"]] = None, initializer: Optional[quad_initializers.BaseQuadraticInitializer] = None, warm_start: bool = False, progress_fn: Optional[ProgressCallbackFn] = None, diff --git a/tests/geometry/geometry_test.py b/tests/geometry/geometry_test.py index ebb5703f4..06eff9fdd 100644 --- a/tests/geometry/geometry_test.py +++ b/tests/geometry/geometry_test.py @@ -28,7 +28,7 @@ class TestCostMeanStd: def test_cost_stdmean(self, rng: jax.Array, geom_type: str): """Test consistency of std evaluation.""" n, m, d = 5, 18, 10 - default_scale = epsilon_scheduler.DEFAULT_SCALE + default_scale = epsilon_scheduler.DEFAULT_EPSILON_SCALE rngs = jax.random.split(rng, 5) x = jax.random.normal(rngs[0], (n, d)) y = jax.random.normal(rngs[1], (m, d)) + 1 diff --git a/tests/solvers/linear/sinkhorn_test.py b/tests/solvers/linear/sinkhorn_test.py index a90ef078b..87df4ddfb 100644 --- a/tests/solvers/linear/sinkhorn_test.py +++ b/tests/solvers/linear/sinkhorn_test.py @@ -146,8 +146,8 @@ def test_autoepsilon(self): ) np.testing.assert_allclose( - geom_1._epsilon.at(2) * scale ** 2, - geom_2._epsilon.at(2), + geom_1.epsilon_scheduler(2) * scale ** 2, + geom_2.epsilon_scheduler(2), rtol=1e-3, atol=1e-3 ) @@ -167,9 +167,10 @@ def test_autoepsilon_with_decay( tau_b: float ): """Check that variations in init/decay work, and result in same solution.""" - epsilon = epsilon_scheduler.Epsilon(init=init, decay=decay) - geom1 = pointcloud.PointCloud(self.x, self.y, epsilon=epsilon) - geom2 = pointcloud.PointCloud(self.x, self.y) + geom = pointcloud.PointCloud(self.x, self.y) + target = epsilon_scheduler.DEFAULT_EPSILON_SCALE * geom.std_cost_matrix + epsilon = epsilon_scheduler.Epsilon(target, init=init, decay=decay) + geom_eps = pointcloud.PointCloud(self.x, self.y, epsilon=epsilon) run_fn = jax.jit( linear.solve, static_argnames=[ @@ -178,7 +179,7 @@ def test_autoepsilon_with_decay( ) out_1 = run_fn( - geom1, + geom_eps, self.a, self.b, tau_a=tau_a, @@ -188,7 +189,7 @@ def test_autoepsilon_with_decay( recenter_potentials=True ) out_2 = run_fn( - geom2, + geom, self.a, self.b, tau_a=tau_a, diff --git a/tests/solvers/quadratic/gw_test.py b/tests/solvers/quadratic/gw_test.py index d7dd5a9c6..d84a2aa85 100644 --- a/tests/solvers/quadratic/gw_test.py +++ b/tests/solvers/quadratic/gw_test.py @@ -433,7 +433,9 @@ def test_relative_epsilon( linear_solver = sinkhorn.Sinkhorn() solver = gromov_wasserstein.GromovWasserstein( - linear_solver, epsilon=eps, relative_epsilon=True + linear_solver, + epsilon=eps, + relative_epsilon="std", ) out = solver(prob) diff --git a/tests/tools/sinkhorn_divergence_test.py b/tests/tools/sinkhorn_divergence_test.py index f5995e21e..f893c70c3 100644 --- a/tests/tools/sinkhorn_divergence_test.py +++ b/tests/tools/sinkhorn_divergence_test.py @@ -150,7 +150,11 @@ def test_euclidean_autoepsilon(self): assert div > 0.0 assert len(out.potentials) == 3 assert len(out.geoms) == 3 - np.testing.assert_allclose(out.geoms[0].epsilon, out.geoms[1].epsilon) + + geom0, geom1, *_ = out.geoms + + assert 
geom0.epsilon_scheduler is not geom1.epsilon_scheduler + assert geom0.epsilon == geom1.epsilon def test_euclidean_autoepsilon_not_share_epsilon(self): rngs = jax.random.split(self.rng, 2)
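
Taken together, the test updates above imply the following migration for downstream code; a before/after sketch, with names mirroring the tests:

```python
import jax

from ott.geometry import epsilon_scheduler, pointcloud

x = jax.random.normal(jax.random.PRNGKey(0), (16, 2))
y = jax.random.normal(jax.random.PRNGKey(1), (19, 2))
geom = pointcloud.PointCloud(x, y)

# Before: geom._epsilon.at(2); after: the scheduler is public and callable.
eps_at_2 = geom.epsilon_scheduler(2)

# Before: Epsilon(init=10.0, decay=0.9) with an implicit target; after: the
# target must be passed explicitly, e.g. recreating the former default.
target = epsilon_scheduler.DEFAULT_EPSILON_SCALE * geom.std_cost_matrix
sched = epsilon_scheduler.Epsilon(target, init=10.0, decay=0.9)
geom_decay = pointcloud.PointCloud(x, y, epsilon=sched)

# Before: relative_epsilon=True; after: name the statistic explicitly.
geom_rel = pointcloud.PointCloud(x, y, epsilon=0.05, relative_epsilon="std")
```
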