From f14846cf920502773d7094744146c322b713b54d Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Fri, 29 Nov 2024 14:08:59 +0100 Subject: [PATCH 01/13] Don't install tensorstore on 3.13 yet --- .github/workflows/tests.yml | 2 -- pyproject.toml | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 1c6d25843..93424e7cb 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -16,8 +16,6 @@ jobs: fast-tests: name: Fast tests Python ${{ matrix.python-version }} ${{ matrix.jax-version }} runs-on: ubuntu-latest - # allow tests using the latest JAX to fail - continue-on-error: ${{ matrix.jax-version == 'jax-latest' }} strategy: fail-fast: false matrix: diff --git a/pyproject.toml b/pyproject.toml index d1014be4d..dda716bfd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -190,7 +190,7 @@ skip_missing_interpreters = true extras = test # https://github.com/google/flax/issues/3329 - py{3.9,3.10,3.11,3.12,3.13},py3.10-jax-default: neural + py{3.9,3.10,3.11,3.12},py3.10-jax-default: neural pass_env = CUDA_*,PYTEST_*,CI commands_pre = gpu: python -I -m pip install "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html From b4f72459bc6bbde0b29d34c4ca3cb15ef1a23305 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Fri, 29 Nov 2024 19:15:08 +0100 Subject: [PATCH 02/13] Fix epsilon scheduler --- src/ott/geometry/epsilon_scheduler.py | 85 ++++++----------- src/ott/geometry/geometry.py | 94 ++++++++----------- src/ott/geometry/grid.py | 11 +-- src/ott/geometry/low_rank.py | 3 +- src/ott/geometry/pointcloud.py | 3 +- .../initializers/quadratic/initializers.py | 4 +- src/ott/neural/methods/monge_gap.py | 4 +- .../problems/quadratic/quadratic_problem.py | 6 +- src/ott/solvers/linear/discrete_barycenter.py | 2 +- .../solvers/quadratic/gromov_wasserstein.py | 3 +- tests/solvers/linear/sinkhorn_diff_test.py | 4 +- tests/solvers/linear/sinkhorn_test.py | 15 +-- tests/solvers/quadratic/gw_test.py | 4 +- tests/tools/sinkhorn_divergence_test.py | 6 +- 14 files changed, 106 insertions(+), 138 deletions(-) diff --git a/src/ott/geometry/epsilon_scheduler.py b/src/ott/geometry/epsilon_scheduler.py index 209c45c37..61b38b48c 100644 --- a/src/ott/geometry/epsilon_scheduler.py +++ b/src/ott/geometry/epsilon_scheduler.py @@ -11,10 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional +from typing import Optional -import jax import jax.numpy as jnp +import jax.tree_util as jtu __all__ = ["Epsilon", "DEFAULT_SCALE"] @@ -22,9 +22,9 @@ DEFAULT_SCALE = 0.05 -@jax.tree_util.register_pytree_node_class +@jtu.register_pytree_node_class class Epsilon: - """Scheduler class for the regularization parameter epsilon. + r"""Scheduler class for the regularization parameter epsilon. An epsilon scheduler outputs a regularization strength, to be used by in a Sinkhorn-type algorithm, at any iteration count. That value is either the @@ -36,70 +36,43 @@ class Epsilon: multiply the max computed previously by ``scale_epsilon``. Args: - target: the epsilon regularizer that is targeted. If :obj:`None`, - use :obj:`DEFAULT_SCALE`, currently set at :math:`0.05`. - scale_epsilon: if passed, used to multiply the regularizer, to rescale it. - If :obj:`None`, use :math:`1`. - init: initial value when using epsilon scheduling, understood as multiple + target: The epsilon regularizer that is targeted. + init: Initial value when using epsilon scheduling, understood as multiple of target value. if passed, ``int * decay ** iteration`` will be used to rescale target. - decay: geometric decay factor, :math:`<1`. + decay: Geometric decay factor, :math:`\leq 1`. """ - def __init__( - self, - target: Optional[float] = None, - scale_epsilon: Optional[float] = None, - init: float = 1.0, - decay: float = 1.0 - ): - self._target_init = target - self._scale_epsilon = scale_epsilon - self._init = init - self._decay = decay + def __init__(self, target: jnp.array, init: float = 1.0, decay: float = 1.0): + assert decay <= 1.0, f"Decay must be <= 1, found {decay}." + self.target = target + self.init = init + self.decay = decay - @property - def target(self) -> float: - """Return the final regularizer value of scheduler.""" - target = DEFAULT_SCALE if self._target_init is None else self._target_init - scale = 1.0 if self._scale_epsilon is None else self._scale_epsilon - return scale * target + def __call__(self, it: Optional[int]) -> jnp.array: + """Return (intermediate) regularizer value at a given iteration. - def at(self, iteration: Optional[int] = 1) -> float: - """Return (intermediate) regularizer value at a given iteration.""" - if iteration is None: + Args: + it: Current iteration. If :obj:`None`, return :attr:`target`. + + Returns: + The epsilon value at the iteration. + """ + if it is None: return self.target - # check the decay is smaller than 1.0. - decay = jnp.minimum(self._decay, 1.0) # the multiple is either 1.0 or a larger init value that is decayed. - multiple = jnp.maximum(self._init * (decay ** iteration), 1.0) + multiple = jnp.maximum(self.init * (self.decay ** it), 1.0) return multiple * self.target - def done(self, eps: float) -> bool: - """Return whether the scheduler is done at a given value.""" - return eps == self.target - - def done_at(self, iteration: Optional[int]) -> bool: - """Return whether the scheduler is done at a given iteration.""" - return self.done(self.at(iteration)) - - def set(self, **kwargs: Any) -> "Epsilon": - """Return a copy of self, with potential overwrites.""" - kwargs = { - "target": self._target_init, - "scale_epsilon": self._scale_epsilon, - "init": self._init, - "decay": self._decay, - **kwargs - } - return Epsilon(**kwargs) + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(target={self.target:.4f}, " + f"init={self.init:.4f}, decay={self.decay:.4f})" + ) def tree_flatten(self): # noqa: D102 - return ( - self._target_init, self._scale_epsilon, self._init, self._decay - ), None + return (self.target,), {"init": self.init, "decay": self.decay} @classmethod def tree_unflatten(cls, aux_data, children): # noqa: D102 - del aux_data - return cls(*children) + return cls(*children, **aux_data) diff --git a/src/ott/geometry/geometry.py b/src/ott/geometry/geometry.py index ab527722e..a2df1d5d3 100644 --- a/src/ott/geometry/geometry.py +++ b/src/ott/geometry/geometry.py @@ -20,16 +20,17 @@ import jax import jax.numpy as jnp import jax.scipy as jsp +import jax.tree_util as jtu import numpy as np from ott import utils -from ott.geometry import epsilon_scheduler +from ott.geometry import epsilon_scheduler as eps_scheduler from ott.math import utils as mu __all__ = ["Geometry"] -@jax.tree_util.register_pytree_node_class +@jtu.register_pytree_node_class class Geometry: r"""Base class to define ground costs/kernels used in optimal transport. @@ -47,8 +48,8 @@ class Geometry: Args: cost_matrix: Cost matrix of shape ``[n, m]``. kernel_matrix: Kernel matrix of shape ``[n, m]``. - epsilon: Regularization parameter. If ``None`` and either - ``relative_epsilon = True`` or ``relative_epsilon = None`` or + epsilon: Regularization parameter. If ``None`` and + ``relative_epsilon = None`` ``relative_epsilon = str`` where ``str`` can be either ``mean`` or ``std`` , this value defaults to a multiple of :attr:`std_cost_matrix` (or :attr:`mean_cost_matrix` if ``str`` is ``mean``), where that multiple @@ -87,8 +88,8 @@ def __init__( self, cost_matrix: Optional[jnp.ndarray] = None, kernel_matrix: Optional[jnp.ndarray] = None, - epsilon: Optional[Union[float, epsilon_scheduler.Epsilon]] = None, - relative_epsilon: Optional[Union[bool, Literal["mean", "std"]]] = None, + epsilon: Optional[Union[float, eps_scheduler.Epsilon]] = None, + relative_epsilon: Optional[Literal["mean", "std"]] = None, scale_cost: Union[float, Literal["mean", "max_cost", "median", "std"]] = 1.0, src_mask: Optional[jnp.ndarray] = None, @@ -96,13 +97,8 @@ def __init__( ): self._cost_matrix = cost_matrix self._kernel_matrix = kernel_matrix - - # needed for `copy_epsilon`, because of the `isinstance` check - self._epsilon_init = epsilon if isinstance( - epsilon, epsilon_scheduler.Epsilon - ) else epsilon_scheduler.Epsilon(epsilon) + self._epsilon_init = epsilon self._relative_epsilon = relative_epsilon - self._scale_cost = scale_cost self._src_mask = src_mask @@ -150,7 +146,7 @@ def std_cost_matrix(self) -> float: to output :math:`\sigma`. """ tmp = self._masked_geom().apply_square_cost(self._n_normed_ones).squeeze() - tmp = jnp.sum(tmp * self._m_normed_ones) - (self.mean_cost_matrix) ** 2 + tmp = jnp.sum(tmp * self._m_normed_ones) - (self.mean_cost_matrix ** 2) return jnp.sqrt(jax.nn.relu(tmp)) @property @@ -164,35 +160,36 @@ def kernel_matrix(self) -> jnp.ndarray: return self._kernel_matrix ** self.inv_scale_cost @property - def _epsilon(self) -> epsilon_scheduler.Epsilon: - (target, scale_eps, _, _), _ = self._epsilon_init.tree_flatten() - rel = self._relative_epsilon - - # If nothing passed, default to STD - if rel is None and target is None and scale_eps is None: - scale_eps = jax.lax.stop_gradient(self.std_cost_matrix) - # If instructions passed change, otherwise (notably if False) skip. - elif rel is not None: - if rel == "mean" or rel is True: # Legacy option. - scale_eps = jax.lax.stop_gradient(self.mean_cost_matrix) - elif rel == "std": - scale_eps = jax.lax.stop_gradient(self.std_cost_matrix) - # Avoid 0 std, since this would set epsilon to 0.0 and result in - # a division by 0. - scale_eps = jnp.where(scale_eps <= 0.0, 1.0, scale_eps) - - if isinstance(self._epsilon_init, epsilon_scheduler.Epsilon): - return self._epsilon_init.set(scale_epsilon=scale_eps) - - return epsilon_scheduler.Epsilon( - target=epsilon_scheduler.DEFAULT_SCALE if target is None else target, - scale_epsilon=scale_eps + def epsilon_scheduler(self) -> eps_scheduler.Epsilon: + """TODO.""" + if isinstance(self._epsilon_init, eps_scheduler.Epsilon): + return self._epsilon_init + + # no relative epsilon + if self._relative_epsilon is None: + if self._epsilon_init is not None: + return eps_scheduler.Epsilon(self._epsilon_init) + multiplier = eps_scheduler.DEFAULT_SCALE + scale = jax.lax.stop_gradient(self.std_cost_matrix) + return eps_scheduler.Epsilon(target=multiplier * scale) + + if self._relative_epsilon == "std": + scale = jax.lax.stop_gradient(self.std_cost_matrix) + elif self._relative_epsilon == "mean": + scale = jax.lax.stop_gradient(self.mean_cost_matrix) + else: + raise ValueError(f"Invalid relative epsilon: {self._relative_epsilon}.") + + multiplier = ( + eps_scheduler.DEFAULT_SCALE + if self._epsilon_init is None else self._epsilon_init ) + return eps_scheduler.Epsilon(target=multiplier * scale) @property def epsilon(self) -> float: """Epsilon regularization value.""" - return self._epsilon.target + return self.epsilon_scheduler.target @property def shape(self) -> Tuple[int, int]: @@ -257,20 +254,11 @@ def set_scale_cost(self, scale_cost: Union[float, str]) -> "Geometry": def copy_epsilon(self, other: "Geometry") -> "Geometry": """Copy the epsilon parameters from another geometry.""" - other_epsilon = other._epsilon children, aux_data = self.tree_flatten() - - new_children = [] - for child in children: - if isinstance(child, epsilon_scheduler.Epsilon): - child = child.set( - target=other_epsilon._target_init, - scale_epsilon=other_epsilon._scale_epsilon - ) - new_children.append(child) - - aux_data["relative_epsilon"] = False - return type(self).tree_unflatten(aux_data, new_children) + new_geom = type(self).tree_unflatten(aux_data, children) + new_geom._epsilon_init = other.epsilon_scheduler + new_geom._relative_epsilon = other._relative_epsilon # has no effect + return new_geom # The functions below are at the core of Sinkhorn iterations, they # are implemented here in their default form, either in lse (using directly @@ -412,7 +400,7 @@ def update_potential( Returns: new potential value, g if axis=0, f if axis is 1. """ - eps = self._epsilon.at(iteration) + eps = self.epsilon_scheduler(iteration) app_lse = self.apply_lse_kernel(f, g, eps, axis=axis)[0] return eps * log_marginal - jnp.where(jnp.isfinite(app_lse), app_lse, 0) @@ -434,7 +422,7 @@ def update_scaling( Returns: new scaling vector, of size num_b if axis=0, num_a if axis is 1. """ - eps = self._epsilon.at(iteration) + eps = self.epsilon_scheduler(iteration) app_kernel = self.apply_kernel(scaling, eps, axis=axis) return marginal / jnp.where(app_kernel > 0, app_kernel, 1.0) @@ -931,7 +919,7 @@ def tree_flatten(self): # noqa: D102 self._src_mask, self._tgt_mask ), { "scale_cost": self._scale_cost, - "relative_epsilon": self._relative_epsilon + "relative_epsilon": self._relative_epsilon, } @classmethod diff --git a/src/ott/geometry/grid.py b/src/ott/geometry/grid.py index 6df230046..1755b72f6 100644 --- a/src/ott/geometry/grid.py +++ b/src/ott/geometry/grid.py @@ -78,6 +78,7 @@ def __init__( grid_dimension: Optional[int] = None, **kwargs: Any, ): + super().__init__(**kwargs) if ( grid_size is not None and x is not None and num_a is not None and grid_dimension is not None @@ -105,11 +106,10 @@ def __init__( self.kwargs = { "num_a": self.num_a, "grid_size": self.grid_size, - "grid_dimension": self.grid_dimension + "grid_dimension": self.grid_dimension, + "relative_epsilon": self._relative_epsilon, } - super().__init__(**kwargs) - @property def geometries(self) -> List[geometry.Geometry]: """Cost matrices along each dimension of the grid.""" @@ -365,9 +365,8 @@ def tree_flatten(self): # noqa: D102 @classmethod def tree_unflatten(cls, aux_data, children): # noqa: D102 - return cls( - x=children[0], cost_fns=children[1], epsilon=children[2], **aux_data - ) + x, cost_fns, epsilon = children + return cls(x, cost_fns=cost_fns, epsilon=epsilon, **aux_data) def to_LRCGeometry( self, diff --git a/src/ott/geometry/low_rank.py b/src/ott/geometry/low_rank.py index 7a776e17f..ec0d7d7b9 100644 --- a/src/ott/geometry/low_rank.py +++ b/src/ott/geometry/low_rank.py @@ -280,6 +280,7 @@ def tree_flatten(self): # noqa: D102 self._scale_factor, ), { "scale_cost": self._scale_cost, + "relative_epsilon": self._relative_epsilon, } @classmethod @@ -319,7 +320,7 @@ def __init__( epsilon: Optional[float] = None, **kwargs: Any ): - super().__init__(epsilon=epsilon, relative_epsilon=False, **kwargs) + super().__init__(epsilon=epsilon, relative_epsilon=None, **kwargs) self.k1 = k1 self.k2 = k2 diff --git a/src/ott/geometry/pointcloud.py b/src/ott/geometry/pointcloud.py index 7ac9f4425..af43e3307 100644 --- a/src/ott/geometry/pointcloud.py +++ b/src/ott/geometry/pointcloud.py @@ -255,7 +255,8 @@ def tree_flatten(self): # noqa: D102 self.cost_fn, ), { "batch_size": self._batch_size, - "scale_cost": self._scale_cost + "scale_cost": self._scale_cost, + "relative_epsilon": self._relative_epsilon, } @classmethod diff --git a/src/ott/initializers/quadratic/initializers.py b/src/ott/initializers/quadratic/initializers.py index f640921d1..eeaf2d97a 100644 --- a/src/ott/initializers/quadratic/initializers.py +++ b/src/ott/initializers/quadratic/initializers.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import abc -from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple +from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence, Tuple import jax.numpy as jnp import jax.tree_util as jtu @@ -128,7 +128,7 @@ def _create_geometry( quad_prob: "quadratic_problem.QuadraticProblem", *, epsilon: float, - relative_epsilon: Optional[bool] = None, + relative_epsilon: Optional[Literal["mean", "std"]] = None, **kwargs: Any, ) -> geometry.Geometry: """Compute initial geometry for linearization. diff --git a/src/ott/neural/methods/monge_gap.py b/src/ott/neural/methods/monge_gap.py index 6dbf7c0d9..31da10f55 100644 --- a/src/ott/neural/methods/monge_gap.py +++ b/src/ott/neural/methods/monge_gap.py @@ -46,7 +46,7 @@ def monge_gap( reference_points: jnp.ndarray, cost_fn: Optional[costs.CostFn] = None, epsilon: Optional[float] = None, - relative_epsilon: Optional[bool] = None, + relative_epsilon: Optional[Literal["mean", "std"]] = None, scale_cost: Union[float, Literal["mean", "max_cost", "median"]] = 1.0, return_output: bool = False, **kwargs: Any @@ -111,7 +111,7 @@ def monge_gap_from_samples( target: jnp.ndarray, cost_fn: Optional[costs.CostFn] = None, epsilon: Optional[float] = None, - relative_epsilon: Optional[bool] = None, + relative_epsilon: Optional[Literal["mean", "std"]] = None, scale_cost: Union[float, Literal["mean", "max_cost", "median"]] = 1.0, return_output: bool = False, **kwargs: Any diff --git a/src/ott/problems/quadratic/quadratic_problem.py b/src/ott/problems/quadratic/quadratic_problem.py index b16970aa6..a96f4e09c 100644 --- a/src/ott/problems/quadratic/quadratic_problem.py +++ b/src/ott/problems/quadratic/quadratic_problem.py @@ -234,7 +234,7 @@ def init_transport_mass(self) -> float: def update_lr_geom( self, lr_sink: "sinkhorn_lr.LRSinkhornOutput", - relative_epsilon: Optional[bool] = None, + relative_epsilon: Optional[Literal["mean", "std"]] = None, ) -> geometry.Geometry: """Recompute (possibly LRC) linearization using LR Sinkhorn output.""" marginal_1 = lr_sink.marginal(1) @@ -270,7 +270,7 @@ def update_linearization( transport: Transport, epsilon: Optional[float] = None, old_transport_mass: float = 1.0, - relative_epsilon: Optional[bool] = None, + relative_epsilon: Optional[Literal["mean", "std"]] = None, ) -> linear_problem.LinearProblem: """Update linearization of GW problem by updating cost matrix. @@ -337,7 +337,7 @@ def update_lr_linearization( self, lr_sink: "sinkhorn_lr.LRSinkhornOutput", *, - relative_epsilon: Optional[bool] = None, + relative_epsilon: Optional[Literal["mean", "std"]] = None, ) -> linear_problem.LinearProblem: """Update a Quad problem linearization using a LR Sinkhorn.""" return linear_problem.LinearProblem( diff --git a/src/ott/solvers/linear/discrete_barycenter.py b/src/ott/solvers/linear/discrete_barycenter.py index 2c1277aa4..513147af7 100644 --- a/src/ott/solvers/linear/discrete_barycenter.py +++ b/src/ott/solvers/linear/discrete_barycenter.py @@ -191,7 +191,7 @@ def body_fn(iteration, const, state, compute_error): geom, a, weights = const errors, d, f_u, g_v = state - eps = geom._epsilon.at(iteration) # pylint: disable=protected-access + eps = geom.epsilon_scheduler(iteration) # pylint: disable=protected-access f_u = parallel_update(f_u, g_v, a, iteration) # kernel_f_u stands for K times potential u if running in scaling mode, # eps log K exp f / eps in lse mode. diff --git a/src/ott/solvers/quadratic/gromov_wasserstein.py b/src/ott/solvers/quadratic/gromov_wasserstein.py index 5d6b7abe5..6affee0db 100644 --- a/src/ott/solvers/quadratic/gromov_wasserstein.py +++ b/src/ott/solvers/quadratic/gromov_wasserstein.py @@ -15,6 +15,7 @@ Any, Callable, Dict, + Literal, NamedTuple, Optional, Sequence, @@ -181,7 +182,7 @@ def __init__( self, linear_solver: sinkhorn.Sinkhorn, epsilon: float = 1.0, - relative_epsilon: Optional[bool] = None, + relative_epsilon: Optional[Literal["mean", "std"]] = None, initializer: Optional[quad_initializers.BaseQuadraticInitializer] = None, warm_start: bool = False, progress_fn: Optional[ProgressCallbackFn] = None, diff --git a/tests/solvers/linear/sinkhorn_diff_test.py b/tests/solvers/linear/sinkhorn_diff_test.py index d96025b58..aa5a25432 100644 --- a/tests/solvers/linear/sinkhorn_diff_test.py +++ b/tests/solvers/linear/sinkhorn_diff_test.py @@ -788,9 +788,7 @@ def loss(a: jnp.ndarray, x: jnp.ndarray, implicit: bool = True): delta_a = delta_a - jnp.mean(delta_a) delta_x = jax.random.uniform(rngs[5], (n, dim)) - hess_loss_imp = jax.jit( - jax.hessian(lambda a, x: loss(a, x, True), argnums=arg) - ) + hess_loss_imp = (jax.hessian(lambda a, x: loss(a, x, True), argnums=arg)) hess_imp = hess_loss_imp(a, x) # Test that Hessians produced with either backprop or implicit do match. diff --git a/tests/solvers/linear/sinkhorn_test.py b/tests/solvers/linear/sinkhorn_test.py index a90ef078b..b2e320d62 100644 --- a/tests/solvers/linear/sinkhorn_test.py +++ b/tests/solvers/linear/sinkhorn_test.py @@ -146,8 +146,8 @@ def test_autoepsilon(self): ) np.testing.assert_allclose( - geom_1._epsilon.at(2) * scale ** 2, - geom_2._epsilon.at(2), + geom_1.epsilon_scheduler(2) * scale ** 2, + geom_2.epsilon_scheduler(2), rtol=1e-3, atol=1e-3 ) @@ -167,9 +167,10 @@ def test_autoepsilon_with_decay( tau_b: float ): """Check that variations in init/decay work, and result in same solution.""" - epsilon = epsilon_scheduler.Epsilon(init=init, decay=decay) - geom1 = pointcloud.PointCloud(self.x, self.y, epsilon=epsilon) - geom2 = pointcloud.PointCloud(self.x, self.y) + geom = pointcloud.PointCloud(self.x, self.y) + target = epsilon_scheduler.DEFAULT_SCALE * geom.std_cost_matrix + epsilon = epsilon_scheduler.Epsilon(target, init=init, decay=decay) + geom_eps = pointcloud.PointCloud(self.x, self.y, epsilon=epsilon) run_fn = jax.jit( linear.solve, static_argnames=[ @@ -178,7 +179,7 @@ def test_autoepsilon_with_decay( ) out_1 = run_fn( - geom1, + geom_eps, self.a, self.b, tau_a=tau_a, @@ -188,7 +189,7 @@ def test_autoepsilon_with_decay( recenter_potentials=True ) out_2 = run_fn( - geom2, + geom, self.a, self.b, tau_a=tau_a, diff --git a/tests/solvers/quadratic/gw_test.py b/tests/solvers/quadratic/gw_test.py index d7dd5a9c6..d84a2aa85 100644 --- a/tests/solvers/quadratic/gw_test.py +++ b/tests/solvers/quadratic/gw_test.py @@ -433,7 +433,9 @@ def test_relative_epsilon( linear_solver = sinkhorn.Sinkhorn() solver = gromov_wasserstein.GromovWasserstein( - linear_solver, epsilon=eps, relative_epsilon=True + linear_solver, + epsilon=eps, + relative_epsilon="std", ) out = solver(prob) diff --git a/tests/tools/sinkhorn_divergence_test.py b/tests/tools/sinkhorn_divergence_test.py index f5995e21e..f893c70c3 100644 --- a/tests/tools/sinkhorn_divergence_test.py +++ b/tests/tools/sinkhorn_divergence_test.py @@ -150,7 +150,11 @@ def test_euclidean_autoepsilon(self): assert div > 0.0 assert len(out.potentials) == 3 assert len(out.geoms) == 3 - np.testing.assert_allclose(out.geoms[0].epsilon, out.geoms[1].epsilon) + + geom0, geom1, *_ = out.geoms + + assert geom0.epsilon_scheduler is not geom1.epsilon_scheduler + assert geom0.epsilon == geom1.epsilon def test_euclidean_autoepsilon_not_share_epsilon(self): rngs = jax.random.split(self.rng, 2) From 28da1be0336fc544cb64b2f31566c282e5014872 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Fri, 29 Nov 2024 19:21:44 +0100 Subject: [PATCH 03/13] Undo changes in the test --- tests/solvers/linear/sinkhorn_diff_test.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/solvers/linear/sinkhorn_diff_test.py b/tests/solvers/linear/sinkhorn_diff_test.py index aa5a25432..d96025b58 100644 --- a/tests/solvers/linear/sinkhorn_diff_test.py +++ b/tests/solvers/linear/sinkhorn_diff_test.py @@ -788,7 +788,9 @@ def loss(a: jnp.ndarray, x: jnp.ndarray, implicit: bool = True): delta_a = delta_a - jnp.mean(delta_a) delta_x = jax.random.uniform(rngs[5], (n, dim)) - hess_loss_imp = (jax.hessian(lambda a, x: loss(a, x, True), argnums=arg)) + hess_loss_imp = jax.jit( + jax.hessian(lambda a, x: loss(a, x, True), argnums=arg) + ) hess_imp = hess_loss_imp(a, x) # Test that Hessians produced with either backprop or implicit do match. From ade9aea644c2aa29e8abe97501aec06945a5f392 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Fri, 29 Nov 2024 19:54:39 +0100 Subject: [PATCH 04/13] Fix eps sched docs --- src/ott/geometry/epsilon_scheduler.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/ott/geometry/epsilon_scheduler.py b/src/ott/geometry/epsilon_scheduler.py index 61b38b48c..163a036e5 100644 --- a/src/ott/geometry/epsilon_scheduler.py +++ b/src/ott/geometry/epsilon_scheduler.py @@ -37,9 +37,8 @@ class Epsilon: Args: target: The epsilon regularizer that is targeted. - init: Initial value when using epsilon scheduling, understood as multiple - of target value. if passed, ``int * decay ** iteration`` will be used - to rescale target. + init: Initial value when using epsilon scheduling, understood as a multiple + of the ``target``, following :math:`\text{init} \text{decay}^{\text{it}}`. decay: Geometric decay factor, :math:`\leq 1`. """ From 1da51f6004456434b057614026575cd910a35f80 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Fri, 29 Nov 2024 19:56:49 +0100 Subject: [PATCH 05/13] Remove mention of `scale_epsilon` --- src/ott/geometry/epsilon_scheduler.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/ott/geometry/epsilon_scheduler.py b/src/ott/geometry/epsilon_scheduler.py index 163a036e5..f5d2cc3f4 100644 --- a/src/ott/geometry/epsilon_scheduler.py +++ b/src/ott/geometry/epsilon_scheduler.py @@ -26,14 +26,10 @@ class Epsilon: r"""Scheduler class for the regularization parameter epsilon. - An epsilon scheduler outputs a regularization strength, to be used by in a + An epsilon scheduler outputs a regularization strength, to be used by a Sinkhorn-type algorithm, at any iteration count. That value is either the final, targeted regularization, or one that is larger, obtained by geometric decay of an initial value that is larger than the intended target. - Concretely, the value returned by such a scheduler will consider first - the max between ``target`` and ``init * target * decay ** iteration``. - If the ``scale_epsilon`` parameter is provided, that value is used to - multiply the max computed previously by ``scale_epsilon``. Args: target: The epsilon regularizer that is targeted. From 3ee8bc085c95a041ed75bb84e557034560d64178 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Fri, 29 Nov 2024 19:58:57 +0100 Subject: [PATCH 06/13] Remove mention of norms in point cloud --- src/ott/geometry/pointcloud.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/ott/geometry/pointcloud.py b/src/ott/geometry/pointcloud.py index af43e3307..14d950089 100644 --- a/src/ott/geometry/pointcloud.py +++ b/src/ott/geometry/pointcloud.py @@ -27,13 +27,9 @@ class PointCloud(geometry.Geometry): """Defines geometry for 2 point clouds (possibly 1 vs itself). - Creates a geometry, specifying a cost function passed as CostFn type object. - When the number of points is large, setting the ``batch_size`` flag implies - that cost and kernel matrices used to update potentials or scalings - will be recomputed on the fly, rather than stored in memory. More precisely, - when setting ``batch_size``, the cost function will be partially cached by - storing norm values for each point in both point clouds, but the pairwise cost - function evaluations won't be. + When the number of points is large, setting the :attr:`batch_size` flag + implies that cost and kernel matrices used to update potentials or scalings + will be recomputed on the fly, rather than stored in memory. Args: x : n x d array of n d-dimensional vectors From 2b1edc351041c6b713b8cefb8d253a94aeb04f5a Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Fri, 29 Nov 2024 20:01:06 +0100 Subject: [PATCH 07/13] Nicer pointcloud docs --- src/ott/geometry/pointcloud.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/ott/geometry/pointcloud.py b/src/ott/geometry/pointcloud.py index 14d950089..16083e7e5 100644 --- a/src/ott/geometry/pointcloud.py +++ b/src/ott/geometry/pointcloud.py @@ -15,6 +15,7 @@ import jax import jax.numpy as jnp +import jax.tree_util as jtu from ott import utils from ott.geometry import costs, geometry, low_rank @@ -23,7 +24,7 @@ __all__ = ["PointCloud"] -@jax.tree_util.register_pytree_node_class +@jtu.register_pytree_node_class class PointCloud(geometry.Geometry): """Defines geometry for 2 point clouds (possibly 1 vs itself). @@ -32,9 +33,9 @@ class PointCloud(geometry.Geometry): will be recomputed on the fly, rather than stored in memory. Args: - x : n x d array of n d-dimensional vectors - y : m x d array of m d-dimensional vectors. If `None`, use ``x``. - cost_fn: a CostFn function between two points in dimension d. + x: Array of shape ``[n, d]``. + y: Array of shape ``[m, d]``. If :obj:`None`, use ``x``. + cost_fn: Cost function between two points in dimension :math:`d`. batch_size: When ``None``, the cost matrix corresponding to that point cloud is computed, stored and later re-used at each application of :meth:`apply_lse_kernel`. When ``batch_size`` is a positive integer, From 723f6eb00eb7c072850263983db5c218250be28f Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Fri, 29 Nov 2024 20:03:43 +0100 Subject: [PATCH 08/13] Update docs of relative epsilon --- src/ott/geometry/geometry.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/ott/geometry/geometry.py b/src/ott/geometry/geometry.py index a2df1d5d3..b8893993f 100644 --- a/src/ott/geometry/geometry.py +++ b/src/ott/geometry/geometry.py @@ -62,12 +62,8 @@ class Geometry: :attr:`mean_cost_matrix` (if ``relative_epsilon = `mean```). Look for :class:`~ott.geometry.epsilon_scheduler.Epsilon` when passed as a scheduler. - relative_epsilon: when :obj:`False`, the parameter ``epsilon`` specifies the - value of the entropic regularization parameter. When :obj:`True` or set - to a string, ``epsilon`` refers to a fraction of the - :attr:`std_cost_matrix` or :attr:`mean_cost_matrix`, which is computed - adaptively from data, depending on whether it is set to ``mean`` or - ``std``. + relative_epsilon: Whether ``epsilon`` refers to a fraction of the + :attr:`mean_cost_matrix` or :attr:`std_cost_matrix`. scale_cost: option to rescale the cost matrix. Implemented scalings are 'median', 'mean', 'std' and 'max_cost'. Alternatively, a float factor can be given to rescale the cost such that ``cost_matrix /= scale_cost``. From 21ce601f3e4505a1d5105d85da1c50ae0aeba63d Mon Sep 17 00:00:00 2001 From: Marco Cuturi Date: Fri, 29 Nov 2024 21:56:32 +0100 Subject: [PATCH 09/13] Update geometry.py change pydocs --- src/ott/geometry/geometry.py | 40 +++++++++++++++++------------------- 1 file changed, 19 insertions(+), 21 deletions(-) diff --git a/src/ott/geometry/geometry.py b/src/ott/geometry/geometry.py index b8893993f..bb199109f 100644 --- a/src/ott/geometry/geometry.py +++ b/src/ott/geometry/geometry.py @@ -36,32 +36,30 @@ class Geometry: Optimal transport problems are intrinsically geometric: they compute an optimal way to transport mass from one configuration onto another. To define - what is meant by optimality of transport requires defining a cost, of moving - mass from one among several sources, towards one out of multiple targets. - These sources and targets can be provided as points in vectors spaces, grids, - or more generally exclusively described through a (dissimilarity) cost matrix, - or almost equivalently, a (similarity) kernel matrix. - - Once that cost or kernel matrix is set, the ``Geometry`` class provides a - basic operations to be run with the Sinkhorn algorithm. + what is meant by optimality of transport requires defining a + :term:`ground cost`, which quantifies how costly it is to move mass from + one among several source locations, towards one out of multiple + target locations. These source and target locations can be described as + points in vectors spaces, grids, or more generally described + through a (dissimilarity) cost matrix, or almost equivalently, a + (similarity) kernel matrix. This class describes such a + geometry and several useful methods to exploit it. Args: cost_matrix: Cost matrix of shape ``[n, m]``. kernel_matrix: Kernel matrix of shape ``[n, m]``. - epsilon: Regularization parameter. If ``None`` and - ``relative_epsilon = None`` - ``relative_epsilon = str`` where ``str`` can be either ``mean`` or ``std`` - , this value defaults to a multiple of :attr:`std_cost_matrix` - (or :attr:`mean_cost_matrix` if ``str`` is ``mean``), where that multiple - is set as ``DEFAULT_SCALE`` in ``epsilon_scheduler.py```. - If passed as a - ``float``, then the regularizer that is ultimately used is either that - ``float`` value (if ``relative_epsilon = False`` or ``None``) or that - ``float`` times the :attr:`std_cost_matrix` (if - ``relative_epsilon = True`` or ``relative_epsilon = `std```) or - :attr:`mean_cost_matrix` (if ``relative_epsilon = `mean```). Look for + epsilon: Regularization parameter or scheduler. Look for :class:`~ott.geometry.epsilon_scheduler.Epsilon` when passed as a - scheduler. + scheduler directly. Otherwise, if :obj:`None` and + ``relative_epsilon`` is :obj:`None` the regularizer value + defaults to a multiple of :attr:`std_cost_matrix`, that multiple + is set as :obj:`~ott.geometry.epsilon_scheduler.DEFAULT_SCALE`, + currently equal to `0.05`. If passed as + a ``float``, then the regularizer that is ultimately used is either + that ``float`` value (if ``relative_epsilon`` is :obj:`None`) or that + ``float`` times the :attr:`std_cost_matrix` (if + ``relative_epsilon`` is ``"std"``) or + :attr:`mean_cost_matrix` (if ``relative_epsilon`` is ``"mean"``). relative_epsilon: Whether ``epsilon`` refers to a fraction of the :attr:`mean_cost_matrix` or :attr:`std_cost_matrix`. scale_cost: option to rescale the cost matrix. Implemented scalings are From 86aa841cd622b9e367fdcb31b49f9fd10031a836 Mon Sep 17 00:00:00 2001 From: Marco Cuturi Date: Fri, 29 Nov 2024 22:26:44 +0100 Subject: [PATCH 10/13] Update epsilon_scheduler.py pydoc --- src/ott/geometry/epsilon_scheduler.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/ott/geometry/epsilon_scheduler.py b/src/ott/geometry/epsilon_scheduler.py index f5d2cc3f4..379bd2c62 100644 --- a/src/ott/geometry/epsilon_scheduler.py +++ b/src/ott/geometry/epsilon_scheduler.py @@ -26,10 +26,10 @@ class Epsilon: r"""Scheduler class for the regularization parameter epsilon. - An epsilon scheduler outputs a regularization strength, to be used by a - Sinkhorn-type algorithm, at any iteration count. That value is either the - final, targeted regularization, or one that is larger, obtained by - geometric decay of an initial value that is larger than the intended target. + An epsilon scheduler outputs a regularization strength, to be used by the + :term:`Sinkhorn algorithm` or variant, at any iteration count. That value is + either the final, targeted regularization, or one that is larger, obtained by + geometric decay of an initial multiplier. Args: target: The epsilon regularizer that is targeted. @@ -45,7 +45,7 @@ def __init__(self, target: jnp.array, init: float = 1.0, decay: float = 1.0): self.decay = decay def __call__(self, it: Optional[int]) -> jnp.array: - """Return (intermediate) regularizer value at a given iteration. + """Intermediate regularizer value at a given iteration number. Args: it: Current iteration. If :obj:`None`, return :attr:`target`. From ae1be757e040bd2ae44b90d96466957212b1162f Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Mon, 2 Dec 2024 10:41:54 +0100 Subject: [PATCH 11/13] Remove TODOs --- src/ott/geometry/geometry.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/ott/geometry/geometry.py b/src/ott/geometry/geometry.py index bb199109f..b1d66c0fd 100644 --- a/src/ott/geometry/geometry.py +++ b/src/ott/geometry/geometry.py @@ -155,7 +155,7 @@ def kernel_matrix(self) -> jnp.ndarray: @property def epsilon_scheduler(self) -> eps_scheduler.Epsilon: - """TODO.""" + """Epsilon scheduler.""" if isinstance(self._epsilon_init, eps_scheduler.Epsilon): return self._epsilon_init @@ -595,7 +595,6 @@ def apply_cost( app = functools.partial( self._apply_cost_to_vec, axis=axis, fn=fn, is_linear=is_linear ) - # TODO(michalk8): vmap over multiple dims? return jax.vmap(app, in_axes=1, out_axes=1)(arr) def _apply_cost_to_vec( From e142beced5f3c58a894095f4c0900138a721f57e Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Mon, 2 Dec 2024 11:42:49 +0100 Subject: [PATCH 12/13] Polish docs, rename to `DEFAULT_EPSILON_SCALE --- docs/geometry.rst | 2 +- src/ott/experimental/mmsinkhorn.py | 2 +- src/ott/geometry/epsilon_scheduler.py | 4 ++-- src/ott/geometry/geometry.py | 30 +++++++++++++-------------- tests/geometry/geometry_test.py | 2 +- tests/solvers/linear/sinkhorn_test.py | 2 +- 6 files changed, 21 insertions(+), 21 deletions(-) diff --git a/docs/geometry.rst b/docs/geometry.rst index 407d8fca0..932278147 100644 --- a/docs/geometry.rst +++ b/docs/geometry.rst @@ -49,7 +49,7 @@ Geometries low_rank.LRCGeometry low_rank.LRKGeometry epsilon_scheduler.Epsilon - epsilon_scheduler.DEFAULT_SCALE + epsilon_scheduler.DEFAULT_EPSILON_SCALE Cost Functions -------------- diff --git a/src/ott/experimental/mmsinkhorn.py b/src/ott/experimental/mmsinkhorn.py index 39a30eef3..ed9c27d77 100644 --- a/src/ott/experimental/mmsinkhorn.py +++ b/src/ott/experimental/mmsinkhorn.py @@ -305,7 +305,7 @@ def __call__( cost_t = cost_tensor(x_s, cost_fns) state = self.init_state(n_s) if epsilon is None: - epsilon = epsilon_scheduler.DEFAULT_SCALE * jnp.std(cost_t) + epsilon = epsilon_scheduler.DEFAULT_EPSILON_SCALE * jnp.std(cost_t) const = cost_t, a_s, epsilon out = run(const, self, state) return out.set(x_s=x_s, a_s=a_s, cost_fns=cost_fns, epsilon=epsilon) diff --git a/src/ott/geometry/epsilon_scheduler.py b/src/ott/geometry/epsilon_scheduler.py index 379bd2c62..1b63fd4a8 100644 --- a/src/ott/geometry/epsilon_scheduler.py +++ b/src/ott/geometry/epsilon_scheduler.py @@ -16,10 +16,10 @@ import jax.numpy as jnp import jax.tree_util as jtu -__all__ = ["Epsilon", "DEFAULT_SCALE"] +__all__ = ["Epsilon", "DEFAULT_EPSILON_SCALE"] #: Scaling applied to statistic (mean/std) of cost to compute default epsilon. -DEFAULT_SCALE = 0.05 +DEFAULT_EPSILON_SCALE = 0.05 @jtu.register_pytree_node_class diff --git a/src/ott/geometry/geometry.py b/src/ott/geometry/geometry.py index b1d66c0fd..96b62f798 100644 --- a/src/ott/geometry/geometry.py +++ b/src/ott/geometry/geometry.py @@ -48,18 +48,18 @@ class Geometry: Args: cost_matrix: Cost matrix of shape ``[n, m]``. kernel_matrix: Kernel matrix of shape ``[n, m]``. - epsilon: Regularization parameter or scheduler. Look for - :class:`~ott.geometry.epsilon_scheduler.Epsilon` when passed as a - scheduler directly. Otherwise, if :obj:`None` and - ``relative_epsilon`` is :obj:`None` the regularizer value - defaults to a multiple of :attr:`std_cost_matrix`, that multiple - is set as :obj:`~ott.geometry.epsilon_scheduler.DEFAULT_SCALE`, - currently equal to `0.05`. If passed as - a ``float``, then the regularizer that is ultimately used is either - that ``float`` value (if ``relative_epsilon`` is :obj:`None`) or that - ``float`` times the :attr:`std_cost_matrix` (if - ``relative_epsilon`` is ``"std"``) or - :attr:`mean_cost_matrix` (if ``relative_epsilon`` is ``"mean"``). + epsilon: Regularization parameter or a scheduler: + + - ``epsilon = None`` and ``relative_epsilon = None``, use + :math:`0.05 * \text{stddev(cost_matrix)}`. + - if ``epsilon`` is a :class:`float` and ``relative_epsilon = None``, + it directly corresponds to the regularization strength. + - otherwise, ``epsilon`` multiplies the :attr:`mean_cost_matrix` or + :attr:`std_cost_matrix`, depending on the value of ``relative_epsilon``. + + If ``epsilon = None``, the value of + :obj:`DEFAULT_EPSILON_SCALE = 0.05 `. + will be used. relative_epsilon: Whether ``epsilon`` refers to a fraction of the :attr:`mean_cost_matrix` or :attr:`std_cost_matrix`. scale_cost: option to rescale the cost matrix. Implemented scalings are @@ -76,7 +76,7 @@ class Geometry: parameter that is meaningful. That parameter can be provided by the user, or assigned a default value through a simple rule, using for instance the :attr:`mean_cost_matrix` or the :attr:`std_cost_matrix`. - """ + """ # noqa: E501 def __init__( self, @@ -163,7 +163,7 @@ def epsilon_scheduler(self) -> eps_scheduler.Epsilon: if self._relative_epsilon is None: if self._epsilon_init is not None: return eps_scheduler.Epsilon(self._epsilon_init) - multiplier = eps_scheduler.DEFAULT_SCALE + multiplier = eps_scheduler.DEFAULT_EPSILON_SCALE scale = jax.lax.stop_gradient(self.std_cost_matrix) return eps_scheduler.Epsilon(target=multiplier * scale) @@ -175,7 +175,7 @@ def epsilon_scheduler(self) -> eps_scheduler.Epsilon: raise ValueError(f"Invalid relative epsilon: {self._relative_epsilon}.") multiplier = ( - eps_scheduler.DEFAULT_SCALE + eps_scheduler.DEFAULT_EPSILON_SCALE if self._epsilon_init is None else self._epsilon_init ) return eps_scheduler.Epsilon(target=multiplier * scale) diff --git a/tests/geometry/geometry_test.py b/tests/geometry/geometry_test.py index ebb5703f4..06eff9fdd 100644 --- a/tests/geometry/geometry_test.py +++ b/tests/geometry/geometry_test.py @@ -28,7 +28,7 @@ class TestCostMeanStd: def test_cost_stdmean(self, rng: jax.Array, geom_type: str): """Test consistency of std evaluation.""" n, m, d = 5, 18, 10 - default_scale = epsilon_scheduler.DEFAULT_SCALE + default_scale = epsilon_scheduler.DEFAULT_EPSILON_SCALE rngs = jax.random.split(rng, 5) x = jax.random.normal(rngs[0], (n, d)) y = jax.random.normal(rngs[1], (m, d)) + 1 diff --git a/tests/solvers/linear/sinkhorn_test.py b/tests/solvers/linear/sinkhorn_test.py index b2e320d62..87df4ddfb 100644 --- a/tests/solvers/linear/sinkhorn_test.py +++ b/tests/solvers/linear/sinkhorn_test.py @@ -168,7 +168,7 @@ def test_autoepsilon_with_decay( ): """Check that variations in init/decay work, and result in same solution.""" geom = pointcloud.PointCloud(self.x, self.y) - target = epsilon_scheduler.DEFAULT_SCALE * geom.std_cost_matrix + target = epsilon_scheduler.DEFAULT_EPSILON_SCALE * geom.std_cost_matrix epsilon = epsilon_scheduler.Epsilon(target, init=init, decay=decay) geom_eps = pointcloud.PointCloud(self.x, self.y, epsilon=epsilon) run_fn = jax.jit( From d0e250faa87b03313886290540dd5ef70cd4f82f Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Mon, 2 Dec 2024 11:44:38 +0100 Subject: [PATCH 13/13] Update PC docs --- src/ott/geometry/pointcloud.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/ott/geometry/pointcloud.py b/src/ott/geometry/pointcloud.py index 16083e7e5..fafc39b5b 100644 --- a/src/ott/geometry/pointcloud.py +++ b/src/ott/geometry/pointcloud.py @@ -36,14 +36,14 @@ class PointCloud(geometry.Geometry): x: Array of shape ``[n, d]``. y: Array of shape ``[m, d]``. If :obj:`None`, use ``x``. cost_fn: Cost function between two points in dimension :math:`d`. - batch_size: When ``None``, the cost matrix corresponding to that point cloud - is computed, stored and later re-used at each application of - :meth:`apply_lse_kernel`. When ``batch_size`` is a positive integer, - computations are done in an online fashion, namely the cost matrix is - recomputed at each call of the :meth:`apply_lse_kernel` step, - ``batch_size`` lines at a time, used on a vector and discarded. - The online computation is particularly useful for big point clouds - whose cost matrix does not fit in memory. + batch_size: If :obj:`None`, the cost matrix corresponding to that + point cloud is computed, stored and later re-used at each application of + :meth:`apply_lse_kernel`. When ``batch_size`` is a positive integer, + computations are done in an online fashion, namely the cost matrix is + recomputed at each call of the :meth:`apply_lse_kernel` step, + ``batch_size`` lines at a time, used on a vector and discarded. + The online computation is particularly useful for big point clouds + whose cost matrix does not fit in memory. scale_cost: option to rescale the cost matrix. Implemented scalings are 'median', 'mean', 'max_cost', 'max_norm' and 'max_bound'. Alternatively, a float factor can be given to rescale the cost such