From 049eb228b9cac5504ae71bd0f366ce5949c989f7 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Mon, 29 Apr 2024 16:18:15 +0200 Subject: [PATCH 1/7] Handle `axis=None` in `lse`'s JVP --- src/ott/math/utils.py | 6 +- src/ott/univariate.py | 434 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 439 insertions(+), 1 deletion(-) create mode 100644 src/ott/univariate.py diff --git a/src/ott/math/utils.py b/src/ott/math/utils.py index afad18db1..6e1aeb7fa 100644 --- a/src/ott/math/utils.py +++ b/src/ott/math/utils.py @@ -164,7 +164,11 @@ def logsumexp_jvp(axis, keepdims, return_sign, primals, tangents): if return_sign: lse, sign = lse lse = jnp.where(jnp.isfinite(lse), lse, 0.0) - centered_exp = jnp.exp(mat - jnp.expand_dims(lse, axis=axis)) + + if axis is not None: + centered_exp = jnp.exp(mat - jnp.expand_dims(lse, axis=axis)) + else: + centered_exp = jnp.exp(mat - lse) if b is None: res = jnp.sum(centered_exp * tan_mat, axis=axis, keepdims=keepdims) diff --git a/src/ott/univariate.py b/src/ott/univariate.py new file mode 100644 index 000000000..d3c20f78a --- /dev/null +++ b/src/ott/univariate.py @@ -0,0 +1,434 @@ +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import NamedTuple, Optional, Tuple, Union + +import lineax as lx + +import jax +import jax.numpy as jnp + +from ott import utils +from ott.geometry import costs, pointcloud +from ott.math import utils as mu +from ott.problems.linear import linear_problem + +__all__ = [ + "UnivariateOutput", "UnivariateSolver", "uniform_distance", + "quantile_distance" +] + +Distance_t = Tuple[float, Optional[jnp.ndarray], Optional[jnp.ndarray], + Optional[jnp.ndarray], Optional[jnp.ndarray]] + + +class UnivariateOutput(NamedTuple): # noqa: D101 + """Output of the :class:`~ott.solvers.linear.UnivariateSolver`. + + Objects of this class contain both solutions and problem definition of a + univariate OT problem. + + Args: + prob: OT problem between 2 weighted ``[n, d]`` and ``[m, d]`` point clouds. + ot_costs: ``[d,]`` optimal transport cost values, computed independently + along each of the ``d`` slices. + paired_indices: ``None`` if no transport was computed / recorded (e.g. when + using quantiles or subsampling approximations). Otherwise, output a tensor + of shape ``[d, 2, m+n]``, of ``m+n`` pairs of indices, for which the + optimal transport assigns mass, on each slice of the ``d`` slices + described in the dataset. Namely, for each index ``0<=k jnp.ndarray: + """Outputs a ``[d, n, m]`` tensor of all ``[n, m]`` transport matrices. + + This tensor will be extremely sparse, since it will have at most ``d(n+m)`` + non-zero values, out of ``dnm`` total entries. + """ + assert self.paired_indices is not None, \ + "[d, n, m] tensor of transports cannot be computed, likely because an" \ + " approximate method was used (using either subsampling or quantiles)." + + n, m = self.prob.geom.shape + if self.prob.is_equal_size and self.prob.is_uniform: + transport_matrices_from_indices = jax.vmap( + lambda idx, idy: jnp.eye(n)[idx, :][:, idy].T, in_axes=[0, 0] + ) + return transport_matrices_from_indices( + self.paired_indices[:, 0, :], self.paired_indices[:, 1, :] + ) + + # raveled indexing of entries. + indices = self.paired_indices[:, 0] * m + self.paired_indices[:, 1] + # segment sum is needed to collect several contributions + return jax.vmap( + lambda idx, mass: jax.ops.segment_sum( + mass, idx, indices_are_sorted=True, num_segments=n * m + ).reshape(n, m), + in_axes=[0, 0] + )(indices, self.mass_paired_indices) + + @property + def mean_transport_matrix(self) -> jnp.ndarray: + """Return the mean transport matrix, averaged over slices.""" + return jnp.mean(self.transport_matrices, axis=0) + + +@jax.tree_util.register_pytree_node_class +class UnivariateSolver: + r"""Univariate solver to compute 1D OT distance over slices of data. + + Computes 1-Dimensional optimal transport distance between two $d$-dimensional + point clouds. The total distance is the sum of univariate Wasserstein + distances on the $d$ slices of data: given two weighted point-clouds, stored + as ``[n, d]`` and ``[m, d]`` in a + :class:`~ott.problems.linear.linear_problem.LinearProblem` object, with + respective weights ``a`` and ``b``, the solver + computes ``d`` OT distances between each of these ``[n, 1]`` and ``[m, 1]`` + slices. The distance is computed using the analytical formula by default, + which involves sorting each of the slices independently. The optimal transport + matrices are also outputted when possible (described in sparse form, i.e. + pairs of indices and mass transferred between those indices). + + When weights ``a`` and ``b`` are uniform, and ``n=m``, the computation only + involves comparing sorted entries per slice, and ``d`` assignments are given. + + The user may also supply a ``num_subsamples`` parameter to extract as many + points from the original point cloud, sampled with probability masses ``a`` + and ``b``. This then simply applied the method above to the subsamples, to + output ``d`` costs, but assignments are not provided. + + When the problem is not uniform or not of equal size, the method defaults to + an inversion of the CDF, and outputs both costs and transport matrix in sparse + form. + + When a ``quantiles`` argument is passed, either specifying explicit quantiles + or a grid of quantiles, the distance is evaluated by comparing the quantiles + of the two point clouds on each slice. The OT costs are returned but + assignments are not provided. + + Args: + num_subsamples: Option to reduce the size of inputs by doing random + subsampling, taken into account marginal probabilities. + quantiles: When a vector or a number of quantiles is passed, the distance + is computed by evaluating the cost function on the sectional (one for each + dimension) quantiles of the two point cloud distributions described in the + problem. + """ + + def __init__( + self, + num_subsamples: Optional[int] = None, + quantiles: Optional[Union[int, jnp.ndarray]] = None, + ): + self._quantiles = quantiles + self.num_subsamples = num_subsamples + + @property + def quantiles(self) -> Optional[jnp.ndarray]: + """Quantiles' values used to evaluate OT cost.""" + if self._quantiles is None: + return None + if isinstance(self._quantiles, int): + return jnp.linspace(0.0, 1.0, self._quantiles) + return self._quantiles + + @property + def num_quantiles(self) -> int: + """Number of quantiles used to evaluate OT cost.""" + return 0 if self.quantiles is None else self.quantiles.shape[0] + + def __call__( + self, + prob: linear_problem.LinearProblem, + return_transport: bool = True, + return_dual_vectors: bool = True, + rng: Optional[jax.Array] = None, + ) -> UnivariateOutput: + """Computes Univariate Distance between the ``d`` dimensional slices. + + Args: + prob: Problem with a :attr:`~ott.problems.linear.LinearProblem.geom` + attribute, the two point clouds ``x`` and ``y`` + (of respective sizes ``[n, d]`` and ``[m, d]``) and a ground + `TI cost ` between two scalars. + The ``[n,]`` and ``[m,]`` size probability weights vectors are stored + in attributes `:attr:`~ott.problems.linear.LinearProblem.a` and + :attr:`~ott.problems.linear.LinearProblem.b`. + return_transport: Whether to return pairs of matched indices used to + compute optimal transport matrices. + return_dual_vectors: Whether to return pairs of dual vectors + rng: Used for random downsampling, if specified in the solver. + + Returns: + An output object, that computes ``d`` OT costs, in addition to, possibly, + paired lists of indices and their corresponding masses, on each of the + ``d`` dimensional slices of the input. + """ + geom = prob.geom + assert isinstance(geom, pointcloud.PointCloud), \ + "Geometry object in problem must be a PointCloud." + assert isinstance(geom.cost_fn, costs.TICost), \ + "Geometry's cost must be translation invariant." + + rng = utils.default_prng_key(rng) + return_transp_variables = return_transport or return_dual_vectors + + assert not(self.num_subsamples and return_transp_variables), \ + "Cannot return any transport output variables when subsampling." + + if self.num_subsamples: + x, y = self._subsample(prob, rng) + is_uniform_same_size = True + else: + # check if problem has the property uniform / same number of points + x, y = geom.x, geom.y + is_uniform_same_size = prob.is_uniform and prob.is_equal_size + + if self.quantiles is not None: + assert prob.is_uniform, \ + "The 'quantiles' method can only be used with uniform marginals." + out = _quant_dist(x, y, geom.cost_fn, self.quantiles, self.num_quantiles) + elif is_uniform_same_size and not return_dual_vectors: + return_transport = return_transport and not self.num_subsamples + out = uniform_distance(x, y, geom.cost_fn, return_transport) + else: + if False: + # perturb uniform inputs with small deterministic noise. + perturb_a = jax.random.uniform( + jax.random.PRNGKey(0), (geom.shape[0],) + ) * 1e-7 + perturb_b = jax.random.uniform( + jax.random.PRNGKey(1), (geom.shape[1],) + ) * 1e-7 + perturb_a -= jnp.mean(perturb_a) + perturb_b -= jnp.mean(perturb_b) + a = prob.a + perturb_a + b = prob.b + perturb_b + else: + a, b = prob.a, prob.b + fn = jax.vmap( + quantile_distance, in_axes=[1, 1, None, None, None, None, None] + ) + fn = lambda x, y, *args: quantile_distance(x[:, 0], y[:, 0], *args) + out = fn(x, y, geom.cost_fn, a, b, return_transport, return_dual_vectors) + + return UnivariateOutput(prob, *out) + + def _subsample(self, prob: linear_problem.LinearProblem, + rng: jax.Array) -> Tuple[jnp.ndarray, jnp.ndarray]: + n, m = prob.geom.shape + x, y = prob.geom.x, prob.geom.y + + if prob.is_uniform: + x = x[jnp.linspace(0, n, num=self.num_subsamples).astype(int), :] + y = y[jnp.linspace(0, m, num=self.num_subsamples).astype(int), :] + return x, y + + rng1, rng2 = jax.random.split(rng, 2) + x = jax.random.choice(rng1, x, (self.num_subsamples,), p=prob.a, axis=0) + y = jax.random.choice(rng2, y, (self.num_subsamples,), p=prob.b, axis=0) + return x, y + + def tree_flatten(self): # noqa: D102 + return None, (self.num_subsamples, self._quantiles) + + @classmethod + def tree_unflatten(cls, aux_data, children): # noqa: D102 + del children + return cls(*aux_data) + + +def uniform_distance( + x: jnp.ndarray, + y: jnp.ndarray, + cost_fn: costs.TICost, + return_transport: bool = True +) -> Distance_t: + """Distance between two equal-size families of uniformly weighted values x/y. + + Args: + x: Vector ``[n,]`` of real values. + y: Vector ``[n,]`` of real values. + cost_fn: Translation invariant cost function, i.e. ``c(x, y) = h(x - y)``. + return_transport: whether to return mapped pairs. + + Returns: + optimal transport cost, a list of ``n+m`` paired indices, and their + corresponding transport mass. Note that said mass can be null in some + entries, but sums to 1.0 + """ + n = x.shape[0] + i_x, i_y = jnp.argsort(x, axis=0), jnp.argsort(y, axis=0) + x = jnp.take_along_axis(x, i_x, axis=0) + y = jnp.take_along_axis(y, i_y, axis=0) + ot_costs = jax.vmap(cost_fn.h, in_axes=[0])(x.T - y.T) / n + + if return_transport: + paired_indices = jnp.stack([i_x, i_y]).transpose([2, 0, 1]) + mass_paired_indices = jnp.ones((n,)) / n + return ot_costs, paired_indices, mass_paired_indices + + return ot_costs, None, None, None, None + + +def quantile_distance( + x: jnp.ndarray, + y: jnp.ndarray, + cost_fn: costs.TICost, + a: jnp.ndarray, + b: jnp.ndarray, + return_transport: bool = True, + return_dual_vectors: bool = True, +) -> Distance_t: + """Computes distance between quantile functions of distributions (a,x)/(b,y). + + Args: + x: Vector ``[n,]`` of real values. + y: Vector ``[m,]`` of real values. + cost_fn: Translation invariant cost function, i.e. ``c(x, y) = h(x - y)``. + a: Vector ``[n,]`` of non-negative weights summing to 1. + b: Vector ``[m,]`` of non-negative weights summing to 1. + return_transport: whether to return mapped pairs. + return_dual_vectors: whether to return dual vectors. when set to ``True``, + will turn ``return_transport`` to ``True`` regardless of the user choice. + + Returns: + optimal transport cost. Optionally, a list of ``n + m`` paired indices, and + their corresponding transport mass. Note that said mass can be null in some + entries, but sums to 1.0. Optionally, two dual vectors corresponding to that + transport. + + Notes: + Inspired by :func:`~scipy.stats.wasserstein_distance`, + but can be used with other costs, not just :math:`c(x, y) = |x - y|`. + """ + x_, y_ = x, y + x, i_x = mu.sort_and_argsort(x, argsort=True) + y, i_y = mu.sort_and_argsort(y, argsort=True) + + all_values = jnp.concatenate([x, y]) + all_values_sorted, all_values_sorter = mu.sort_and_argsort( + all_values, argsort=True + ) + + x_pdf = jnp.concatenate([a[i_x], jnp.zeros_like(b)])[all_values_sorter] + y_pdf = jnp.concatenate([jnp.zeros_like(a), b[i_y]])[all_values_sorter] + + x_cdf = jnp.cumsum(x_pdf) + y_cdf = jnp.cumsum(y_pdf) + + x_y_cdfs = jnp.concatenate([x_cdf, y_cdf]) + quantile_levels, _ = mu.sort_and_argsort(x_y_cdfs, argsort=False) + + i_x_cdf_inv = jnp.searchsorted(x_cdf, quantile_levels) + x_cdf_inv = all_values_sorted[i_x_cdf_inv] + i_y_cdf_inv = jnp.searchsorted(y_cdf, quantile_levels) + y_cdf_inv = all_values_sorted[i_y_cdf_inv] + + diff_q = jnp.diff(quantile_levels) + successive_costs = jax.vmap(cost_fn.h)( + x_cdf_inv[1:, None] - y_cdf_inv[1:, None] + ) + cost = jnp.sum(successive_costs * diff_q) + paired_indices, mass_paired_indices, dual_a, dual_b = [None] * 4 + + if return_transport or return_dual_vectors: + n = x.shape[0] + + i_in_sorted_x_of_quantile = all_values_sorter[i_x_cdf_inv] % n + i_in_sorted_y_of_quantile = all_values_sorter[i_y_cdf_inv] - n + + orig_i = i_x[i_in_sorted_x_of_quantile][1:] + orig_j = i_y[i_in_sorted_y_of_quantile][1:] + paired_indices, mass_paired_indices = jnp.stack([orig_i, orig_j]), diff_q + + if return_dual_vectors: + m = y.shape[0] + + cliff = (jnp.abs(jnp.diff(orig_i)) > 0) & (jnp.abs(jnp.diff(orig_j)) > 0) + + #virtual_pairs = jnp.stack((orig_i[:-1], orig_j[1:])) * (2*cliff-1) + virtual_pairs = (1 + jnp.stack((orig_i[:-1], orig_j[1:]))) * (2 * cliff - 1) + + #actual_pairs = jnp.stack((orig_i[diff_q>0], orig_j[diff_q > 0])) + diff_q_sorted, diff_q_idx = mu.sort_and_argsort(diff_q, argsort=True) + actual_pairs = jnp.stack((orig_i, orig_j)) + actual_pairs = (1 + actual_pairs[:, diff_q_idx] + ) * (2 * (diff_q_sorted > 0) - 1) + # diff_q_idx = jnp.argsort(diff_q[diff_q>0]) + # print('actual_pairs', actual_pairs) + # if actual_pairs.shape[1] > n+m-1: + # actual_pairs = actual_pairs[:,diff_q_idx][:,-n-m+1:] + + # print('actual_pairs2', actual_pairs) + + # if virtual_pairs.shape[1] == 0: + # pairs = actual_pairs + # else: + pairs = jnp.concatenate((actual_pairs, virtual_pairs), axis=1) + + idx = jnp.argsort( + jnp.sum(pairs, axis=0) * + jnp.concatenate((diff_q_sorted, jnp.ones((virtual_pairs.shape[1],)))), + stable=True + )[-n - m + 1:] + pairs = pairs[:, idx] - 1 + + cost_vector = jax.vmap(cost_fn.h)(x_[pairs[0]] - y_[pairs[1]]) + + def kkt(dual_ab): + """Eq. 3.6 in :cite:`peyre:19`, with centering constraint on dual_a.""" + dual_a, dual_b = dual_ab[:n], dual_ab[n:] + return jnp.concatenate(( + jnp.array(jnp.sum(dual_a, keepdims=True)), + dual_a[pairs[0, :]] + dual_b[pairs[1, :]] + )) + + z = jnp.concatenate((jnp.zeros((1,)), cost_vector)) + operator = lx.FunctionLinearOperator(kkt, z) + solver = lx.NormalCG(rtol=1e-4, atol=1e-4, max_steps=1000) + sol = lx.linear_solve(operator, z, solver) + sol = sol.value + # split again solution into 2 dual variables + dual_a = sol[:n] + dual_b = sol[n:] + + return cost, paired_indices, mass_paired_indices, dual_a, dual_b + + +def _quant_dist( + x: jnp.ndarray, y: jnp.ndarray, cost_fn: costs.TICost, q: jnp.ndarray, + n_q: int +) -> Tuple[jnp.ndarray, None, None]: + x_q = jnp.quantile(x, q, axis=0) + y_q = jnp.quantile(y, q, axis=0) + ot_costs = jax.vmap(cost_fn.pairwise, in_axes=[1, 1])(x_q, y_q) + + return ot_costs / n_q, None, None, None, None From 888c8b35ad275d3bd42a22916b957fdb0a34962c Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Mon, 29 Apr 2024 17:12:14 +0200 Subject: [PATCH 2/7] Add `h_legendre` test --- tests/geometry/costs_test.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tests/geometry/costs_test.py b/tests/geometry/costs_test.py index 71b826de6..924f993ed 100644 --- a/tests/geometry/costs_test.py +++ b/tests/geometry/costs_test.py @@ -128,6 +128,12 @@ def test_bures(self, rng: jax.Array): np.testing.assert_equal(diffs.shape[0], max_iterations // inner_iterations) +class TestTIRegCost: + + def test_h_legendre(self): + pass + + @pytest.mark.fast() class TestRegTICost: @@ -216,6 +222,25 @@ def test_stronger_regularization_increases_sparsity( for fwd in [False, True]: np.testing.assert_array_equal(np.diff(sparsity[fwd]) > 0.0, True) + @pytest.mark.parametrize("d", [5, 10]) + def test_h_legendre_elastic_l2(self, rng: jax.Array, d: int): + n, d = 13, d + rngs = jax.random.split(rng, 2) + x = jax.random.normal(rngs[0], (n, d)) + u = jax.random.normal(rngs[1], (d,)) + + elastic_l2 = costs.ElasticL2(scaling_reg=0.0) + p_norm_p = costs.PNormP(p=2) + + concave_fn = lambda z: -elastic_l2.h(z) + jnp.dot(z, u) + + p_grad_h = jax.jit(jax.vmap(jax.grad(p_norm_p.h_transform(concave_fn)))) + elastic_grad_h = jax.vmap(jax.grad(elastic_l2.h_transform(concave_fn))) + + np.testing.assert_allclose( + elastic_grad_h(x), p_grad_h(x), rtol=1e-5, atol=1e-5 + ) + @pytest.mark.skipif(ts_metrics is None, reason="Not supported for Python 3.11") @pytest.mark.fast() From b0ad84ee9eb71b1959e81bfdcd066ae727962fc8 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Mon, 29 Apr 2024 17:39:27 +0200 Subject: [PATCH 3/7] Add more tests for `h_transform` --- tests/geometry/costs_test.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/tests/geometry/costs_test.py b/tests/geometry/costs_test.py index 924f993ed..e27262ab0 100644 --- a/tests/geometry/costs_test.py +++ b/tests/geometry/costs_test.py @@ -20,6 +20,7 @@ import numpy as np from ott.geometry import costs, pointcloud +from ott.math import utils as mu from ott.solvers import linear try: @@ -130,7 +131,23 @@ def test_bures(self, rng: jax.Array): class TestTIRegCost: - def test_h_legendre(self): + @pytest.mark.parametrize( + "cost_fn", [ + costs.SqPNorm(p=1.0), + costs.SqPNorm(2.3), + costs.PNormP(p=1.0), + costs.PNormP(1.3), + costs.SqEuclidean() + ] + ) + def test_h_legendre(self, rng: jax.Array, cost_fn: costs.TICost): + x = jax.random.normal(rng, (15, 3)) + h_transform = cost_fn.h_transform(mu.logsumexp) + h_transform = jax.jit(jax.vmap(jax.grad(h_transform))) + + np.testing.assert_array_equal(jnp.isfinite(h_transform(x)), True) + + def test_h_legendre_sqeucl(self): pass From 35c21139bacaf9d552829919cd9d9a8b7085ffe5 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Mon, 29 Apr 2024 17:59:50 +0200 Subject: [PATCH 4/7] Add more h_legendre tests --- tests/geometry/costs_test.py | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/tests/geometry/costs_test.py b/tests/geometry/costs_test.py index e27262ab0..581246938 100644 --- a/tests/geometry/costs_test.py +++ b/tests/geometry/costs_test.py @@ -134,8 +134,8 @@ class TestTIRegCost: @pytest.mark.parametrize( "cost_fn", [ costs.SqPNorm(p=1.0), - costs.SqPNorm(2.3), - costs.PNormP(p=1.0), + costs.SqPNorm(2.4), + costs.PNormP(p=1.1), costs.PNormP(1.3), costs.SqEuclidean() ] @@ -147,8 +147,25 @@ def test_h_legendre(self, rng: jax.Array, cost_fn: costs.TICost): np.testing.assert_array_equal(jnp.isfinite(h_transform(x)), True) - def test_h_legendre_sqeucl(self): - pass + @pytest.mark.parametrize("ridge", [1e-12, 1e-6]) + def test_h_legendre_sqeucl(self, rng: jax.Array, ridge: float): + n, d = 12, 4 + rngs = jax.random.split(rng, 2) + u = jnp.abs(jax.random.uniform(rngs[0], (d,))) + x = jax.random.normal(rngs[1], (n, d)) + + sqeucl = costs.SqEuclidean() + el_l2 = costs.ElasticL2(scaling_reg=0.0) + + h_concave = lambda z: 0.5 * (-sqeucl.h(z) + jnp.dot(z, u)) + h_concave_half = lambda z: -sqeucl.h(z) + jnp.dot(z, u) + + pred = jax.jit( + jax.vmap(jax.grad(sqeucl.h_transform(h_concave, ridge=ridge))) + ) + gt = jax.jit(jax.vmap(jax.grad(el_l2.h_transform(h_concave_half)))) + + np.testing.assert_allclose(pred(x), gt(x), rtol=1e-5, atol=1e-5) @pytest.mark.fast() From 263e3dfe5e79610027dc3027db84e3a9f4d008cc Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Mon, 29 Apr 2024 18:59:15 +0200 Subject: [PATCH 5/7] Increase tolerance for `h_transform` --- tests/geometry/costs_test.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/geometry/costs_test.py b/tests/geometry/costs_test.py index 581246938..176d01d93 100644 --- a/tests/geometry/costs_test.py +++ b/tests/geometry/costs_test.py @@ -268,8 +268,12 @@ def test_h_legendre_elastic_l2(self, rng: jax.Array, d: int): concave_fn = lambda z: -elastic_l2.h(z) + jnp.dot(z, u) - p_grad_h = jax.jit(jax.vmap(jax.grad(p_norm_p.h_transform(concave_fn)))) - elastic_grad_h = jax.vmap(jax.grad(elastic_l2.h_transform(concave_fn))) + p_grad_h = jax.jit( + jax.vmap(jax.grad(p_norm_p.h_transform(concave_fn, tol=1e-5))) + ) + elastic_grad_h = jax.vmap( + jax.grad(elastic_l2.h_transform(concave_fn, tol=1e-5)) + ) np.testing.assert_allclose( elastic_grad_h(x), p_grad_h(x), rtol=1e-5, atol=1e-5 From 02b68223dea04ba17521a4f8bfbb85710da97729 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Mon, 29 Apr 2024 19:17:59 +0200 Subject: [PATCH 6/7] Increase rtol/atol --- tests/geometry/costs_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/geometry/costs_test.py b/tests/geometry/costs_test.py index 176d01d93..7e82fecff 100644 --- a/tests/geometry/costs_test.py +++ b/tests/geometry/costs_test.py @@ -276,7 +276,7 @@ def test_h_legendre_elastic_l2(self, rng: jax.Array, d: int): ) np.testing.assert_allclose( - elastic_grad_h(x), p_grad_h(x), rtol=1e-5, atol=1e-5 + elastic_grad_h(x), p_grad_h(x), rtol=1e-4, atol=1e-4 ) From 3e4ead385c8401451bef431e3f993b28af99db62 Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Tue, 30 Apr 2024 09:31:24 +0200 Subject: [PATCH 7/7] Remove old file --- src/ott/univariate.py | 434 ------------------------------------------ 1 file changed, 434 deletions(-) delete mode 100644 src/ott/univariate.py diff --git a/src/ott/univariate.py b/src/ott/univariate.py deleted file mode 100644 index d3c20f78a..000000000 --- a/src/ott/univariate.py +++ /dev/null @@ -1,434 +0,0 @@ -# Copyright OTT-JAX -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import NamedTuple, Optional, Tuple, Union - -import lineax as lx - -import jax -import jax.numpy as jnp - -from ott import utils -from ott.geometry import costs, pointcloud -from ott.math import utils as mu -from ott.problems.linear import linear_problem - -__all__ = [ - "UnivariateOutput", "UnivariateSolver", "uniform_distance", - "quantile_distance" -] - -Distance_t = Tuple[float, Optional[jnp.ndarray], Optional[jnp.ndarray], - Optional[jnp.ndarray], Optional[jnp.ndarray]] - - -class UnivariateOutput(NamedTuple): # noqa: D101 - """Output of the :class:`~ott.solvers.linear.UnivariateSolver`. - - Objects of this class contain both solutions and problem definition of a - univariate OT problem. - - Args: - prob: OT problem between 2 weighted ``[n, d]`` and ``[m, d]`` point clouds. - ot_costs: ``[d,]`` optimal transport cost values, computed independently - along each of the ``d`` slices. - paired_indices: ``None`` if no transport was computed / recorded (e.g. when - using quantiles or subsampling approximations). Otherwise, output a tensor - of shape ``[d, 2, m+n]``, of ``m+n`` pairs of indices, for which the - optimal transport assigns mass, on each slice of the ``d`` slices - described in the dataset. Namely, for each index ``0<=k jnp.ndarray: - """Outputs a ``[d, n, m]`` tensor of all ``[n, m]`` transport matrices. - - This tensor will be extremely sparse, since it will have at most ``d(n+m)`` - non-zero values, out of ``dnm`` total entries. - """ - assert self.paired_indices is not None, \ - "[d, n, m] tensor of transports cannot be computed, likely because an" \ - " approximate method was used (using either subsampling or quantiles)." - - n, m = self.prob.geom.shape - if self.prob.is_equal_size and self.prob.is_uniform: - transport_matrices_from_indices = jax.vmap( - lambda idx, idy: jnp.eye(n)[idx, :][:, idy].T, in_axes=[0, 0] - ) - return transport_matrices_from_indices( - self.paired_indices[:, 0, :], self.paired_indices[:, 1, :] - ) - - # raveled indexing of entries. - indices = self.paired_indices[:, 0] * m + self.paired_indices[:, 1] - # segment sum is needed to collect several contributions - return jax.vmap( - lambda idx, mass: jax.ops.segment_sum( - mass, idx, indices_are_sorted=True, num_segments=n * m - ).reshape(n, m), - in_axes=[0, 0] - )(indices, self.mass_paired_indices) - - @property - def mean_transport_matrix(self) -> jnp.ndarray: - """Return the mean transport matrix, averaged over slices.""" - return jnp.mean(self.transport_matrices, axis=0) - - -@jax.tree_util.register_pytree_node_class -class UnivariateSolver: - r"""Univariate solver to compute 1D OT distance over slices of data. - - Computes 1-Dimensional optimal transport distance between two $d$-dimensional - point clouds. The total distance is the sum of univariate Wasserstein - distances on the $d$ slices of data: given two weighted point-clouds, stored - as ``[n, d]`` and ``[m, d]`` in a - :class:`~ott.problems.linear.linear_problem.LinearProblem` object, with - respective weights ``a`` and ``b``, the solver - computes ``d`` OT distances between each of these ``[n, 1]`` and ``[m, 1]`` - slices. The distance is computed using the analytical formula by default, - which involves sorting each of the slices independently. The optimal transport - matrices are also outputted when possible (described in sparse form, i.e. - pairs of indices and mass transferred between those indices). - - When weights ``a`` and ``b`` are uniform, and ``n=m``, the computation only - involves comparing sorted entries per slice, and ``d`` assignments are given. - - The user may also supply a ``num_subsamples`` parameter to extract as many - points from the original point cloud, sampled with probability masses ``a`` - and ``b``. This then simply applied the method above to the subsamples, to - output ``d`` costs, but assignments are not provided. - - When the problem is not uniform or not of equal size, the method defaults to - an inversion of the CDF, and outputs both costs and transport matrix in sparse - form. - - When a ``quantiles`` argument is passed, either specifying explicit quantiles - or a grid of quantiles, the distance is evaluated by comparing the quantiles - of the two point clouds on each slice. The OT costs are returned but - assignments are not provided. - - Args: - num_subsamples: Option to reduce the size of inputs by doing random - subsampling, taken into account marginal probabilities. - quantiles: When a vector or a number of quantiles is passed, the distance - is computed by evaluating the cost function on the sectional (one for each - dimension) quantiles of the two point cloud distributions described in the - problem. - """ - - def __init__( - self, - num_subsamples: Optional[int] = None, - quantiles: Optional[Union[int, jnp.ndarray]] = None, - ): - self._quantiles = quantiles - self.num_subsamples = num_subsamples - - @property - def quantiles(self) -> Optional[jnp.ndarray]: - """Quantiles' values used to evaluate OT cost.""" - if self._quantiles is None: - return None - if isinstance(self._quantiles, int): - return jnp.linspace(0.0, 1.0, self._quantiles) - return self._quantiles - - @property - def num_quantiles(self) -> int: - """Number of quantiles used to evaluate OT cost.""" - return 0 if self.quantiles is None else self.quantiles.shape[0] - - def __call__( - self, - prob: linear_problem.LinearProblem, - return_transport: bool = True, - return_dual_vectors: bool = True, - rng: Optional[jax.Array] = None, - ) -> UnivariateOutput: - """Computes Univariate Distance between the ``d`` dimensional slices. - - Args: - prob: Problem with a :attr:`~ott.problems.linear.LinearProblem.geom` - attribute, the two point clouds ``x`` and ``y`` - (of respective sizes ``[n, d]`` and ``[m, d]``) and a ground - `TI cost ` between two scalars. - The ``[n,]`` and ``[m,]`` size probability weights vectors are stored - in attributes `:attr:`~ott.problems.linear.LinearProblem.a` and - :attr:`~ott.problems.linear.LinearProblem.b`. - return_transport: Whether to return pairs of matched indices used to - compute optimal transport matrices. - return_dual_vectors: Whether to return pairs of dual vectors - rng: Used for random downsampling, if specified in the solver. - - Returns: - An output object, that computes ``d`` OT costs, in addition to, possibly, - paired lists of indices and their corresponding masses, on each of the - ``d`` dimensional slices of the input. - """ - geom = prob.geom - assert isinstance(geom, pointcloud.PointCloud), \ - "Geometry object in problem must be a PointCloud." - assert isinstance(geom.cost_fn, costs.TICost), \ - "Geometry's cost must be translation invariant." - - rng = utils.default_prng_key(rng) - return_transp_variables = return_transport or return_dual_vectors - - assert not(self.num_subsamples and return_transp_variables), \ - "Cannot return any transport output variables when subsampling." - - if self.num_subsamples: - x, y = self._subsample(prob, rng) - is_uniform_same_size = True - else: - # check if problem has the property uniform / same number of points - x, y = geom.x, geom.y - is_uniform_same_size = prob.is_uniform and prob.is_equal_size - - if self.quantiles is not None: - assert prob.is_uniform, \ - "The 'quantiles' method can only be used with uniform marginals." - out = _quant_dist(x, y, geom.cost_fn, self.quantiles, self.num_quantiles) - elif is_uniform_same_size and not return_dual_vectors: - return_transport = return_transport and not self.num_subsamples - out = uniform_distance(x, y, geom.cost_fn, return_transport) - else: - if False: - # perturb uniform inputs with small deterministic noise. - perturb_a = jax.random.uniform( - jax.random.PRNGKey(0), (geom.shape[0],) - ) * 1e-7 - perturb_b = jax.random.uniform( - jax.random.PRNGKey(1), (geom.shape[1],) - ) * 1e-7 - perturb_a -= jnp.mean(perturb_a) - perturb_b -= jnp.mean(perturb_b) - a = prob.a + perturb_a - b = prob.b + perturb_b - else: - a, b = prob.a, prob.b - fn = jax.vmap( - quantile_distance, in_axes=[1, 1, None, None, None, None, None] - ) - fn = lambda x, y, *args: quantile_distance(x[:, 0], y[:, 0], *args) - out = fn(x, y, geom.cost_fn, a, b, return_transport, return_dual_vectors) - - return UnivariateOutput(prob, *out) - - def _subsample(self, prob: linear_problem.LinearProblem, - rng: jax.Array) -> Tuple[jnp.ndarray, jnp.ndarray]: - n, m = prob.geom.shape - x, y = prob.geom.x, prob.geom.y - - if prob.is_uniform: - x = x[jnp.linspace(0, n, num=self.num_subsamples).astype(int), :] - y = y[jnp.linspace(0, m, num=self.num_subsamples).astype(int), :] - return x, y - - rng1, rng2 = jax.random.split(rng, 2) - x = jax.random.choice(rng1, x, (self.num_subsamples,), p=prob.a, axis=0) - y = jax.random.choice(rng2, y, (self.num_subsamples,), p=prob.b, axis=0) - return x, y - - def tree_flatten(self): # noqa: D102 - return None, (self.num_subsamples, self._quantiles) - - @classmethod - def tree_unflatten(cls, aux_data, children): # noqa: D102 - del children - return cls(*aux_data) - - -def uniform_distance( - x: jnp.ndarray, - y: jnp.ndarray, - cost_fn: costs.TICost, - return_transport: bool = True -) -> Distance_t: - """Distance between two equal-size families of uniformly weighted values x/y. - - Args: - x: Vector ``[n,]`` of real values. - y: Vector ``[n,]`` of real values. - cost_fn: Translation invariant cost function, i.e. ``c(x, y) = h(x - y)``. - return_transport: whether to return mapped pairs. - - Returns: - optimal transport cost, a list of ``n+m`` paired indices, and their - corresponding transport mass. Note that said mass can be null in some - entries, but sums to 1.0 - """ - n = x.shape[0] - i_x, i_y = jnp.argsort(x, axis=0), jnp.argsort(y, axis=0) - x = jnp.take_along_axis(x, i_x, axis=0) - y = jnp.take_along_axis(y, i_y, axis=0) - ot_costs = jax.vmap(cost_fn.h, in_axes=[0])(x.T - y.T) / n - - if return_transport: - paired_indices = jnp.stack([i_x, i_y]).transpose([2, 0, 1]) - mass_paired_indices = jnp.ones((n,)) / n - return ot_costs, paired_indices, mass_paired_indices - - return ot_costs, None, None, None, None - - -def quantile_distance( - x: jnp.ndarray, - y: jnp.ndarray, - cost_fn: costs.TICost, - a: jnp.ndarray, - b: jnp.ndarray, - return_transport: bool = True, - return_dual_vectors: bool = True, -) -> Distance_t: - """Computes distance between quantile functions of distributions (a,x)/(b,y). - - Args: - x: Vector ``[n,]`` of real values. - y: Vector ``[m,]`` of real values. - cost_fn: Translation invariant cost function, i.e. ``c(x, y) = h(x - y)``. - a: Vector ``[n,]`` of non-negative weights summing to 1. - b: Vector ``[m,]`` of non-negative weights summing to 1. - return_transport: whether to return mapped pairs. - return_dual_vectors: whether to return dual vectors. when set to ``True``, - will turn ``return_transport`` to ``True`` regardless of the user choice. - - Returns: - optimal transport cost. Optionally, a list of ``n + m`` paired indices, and - their corresponding transport mass. Note that said mass can be null in some - entries, but sums to 1.0. Optionally, two dual vectors corresponding to that - transport. - - Notes: - Inspired by :func:`~scipy.stats.wasserstein_distance`, - but can be used with other costs, not just :math:`c(x, y) = |x - y|`. - """ - x_, y_ = x, y - x, i_x = mu.sort_and_argsort(x, argsort=True) - y, i_y = mu.sort_and_argsort(y, argsort=True) - - all_values = jnp.concatenate([x, y]) - all_values_sorted, all_values_sorter = mu.sort_and_argsort( - all_values, argsort=True - ) - - x_pdf = jnp.concatenate([a[i_x], jnp.zeros_like(b)])[all_values_sorter] - y_pdf = jnp.concatenate([jnp.zeros_like(a), b[i_y]])[all_values_sorter] - - x_cdf = jnp.cumsum(x_pdf) - y_cdf = jnp.cumsum(y_pdf) - - x_y_cdfs = jnp.concatenate([x_cdf, y_cdf]) - quantile_levels, _ = mu.sort_and_argsort(x_y_cdfs, argsort=False) - - i_x_cdf_inv = jnp.searchsorted(x_cdf, quantile_levels) - x_cdf_inv = all_values_sorted[i_x_cdf_inv] - i_y_cdf_inv = jnp.searchsorted(y_cdf, quantile_levels) - y_cdf_inv = all_values_sorted[i_y_cdf_inv] - - diff_q = jnp.diff(quantile_levels) - successive_costs = jax.vmap(cost_fn.h)( - x_cdf_inv[1:, None] - y_cdf_inv[1:, None] - ) - cost = jnp.sum(successive_costs * diff_q) - paired_indices, mass_paired_indices, dual_a, dual_b = [None] * 4 - - if return_transport or return_dual_vectors: - n = x.shape[0] - - i_in_sorted_x_of_quantile = all_values_sorter[i_x_cdf_inv] % n - i_in_sorted_y_of_quantile = all_values_sorter[i_y_cdf_inv] - n - - orig_i = i_x[i_in_sorted_x_of_quantile][1:] - orig_j = i_y[i_in_sorted_y_of_quantile][1:] - paired_indices, mass_paired_indices = jnp.stack([orig_i, orig_j]), diff_q - - if return_dual_vectors: - m = y.shape[0] - - cliff = (jnp.abs(jnp.diff(orig_i)) > 0) & (jnp.abs(jnp.diff(orig_j)) > 0) - - #virtual_pairs = jnp.stack((orig_i[:-1], orig_j[1:])) * (2*cliff-1) - virtual_pairs = (1 + jnp.stack((orig_i[:-1], orig_j[1:]))) * (2 * cliff - 1) - - #actual_pairs = jnp.stack((orig_i[diff_q>0], orig_j[diff_q > 0])) - diff_q_sorted, diff_q_idx = mu.sort_and_argsort(diff_q, argsort=True) - actual_pairs = jnp.stack((orig_i, orig_j)) - actual_pairs = (1 + actual_pairs[:, diff_q_idx] - ) * (2 * (diff_q_sorted > 0) - 1) - # diff_q_idx = jnp.argsort(diff_q[diff_q>0]) - # print('actual_pairs', actual_pairs) - # if actual_pairs.shape[1] > n+m-1: - # actual_pairs = actual_pairs[:,diff_q_idx][:,-n-m+1:] - - # print('actual_pairs2', actual_pairs) - - # if virtual_pairs.shape[1] == 0: - # pairs = actual_pairs - # else: - pairs = jnp.concatenate((actual_pairs, virtual_pairs), axis=1) - - idx = jnp.argsort( - jnp.sum(pairs, axis=0) * - jnp.concatenate((diff_q_sorted, jnp.ones((virtual_pairs.shape[1],)))), - stable=True - )[-n - m + 1:] - pairs = pairs[:, idx] - 1 - - cost_vector = jax.vmap(cost_fn.h)(x_[pairs[0]] - y_[pairs[1]]) - - def kkt(dual_ab): - """Eq. 3.6 in :cite:`peyre:19`, with centering constraint on dual_a.""" - dual_a, dual_b = dual_ab[:n], dual_ab[n:] - return jnp.concatenate(( - jnp.array(jnp.sum(dual_a, keepdims=True)), - dual_a[pairs[0, :]] + dual_b[pairs[1, :]] - )) - - z = jnp.concatenate((jnp.zeros((1,)), cost_vector)) - operator = lx.FunctionLinearOperator(kkt, z) - solver = lx.NormalCG(rtol=1e-4, atol=1e-4, max_steps=1000) - sol = lx.linear_solve(operator, z, solver) - sol = sol.value - # split again solution into 2 dual variables - dual_a = sol[:n] - dual_b = sol[n:] - - return cost, paired_indices, mass_paired_indices, dual_a, dual_b - - -def _quant_dist( - x: jnp.ndarray, y: jnp.ndarray, cost_fn: costs.TICost, q: jnp.ndarray, - n_q: int -) -> Tuple[jnp.ndarray, None, None]: - x_q = jnp.quantile(x, q, axis=0) - y_q = jnp.quantile(y, q, axis=0) - ot_costs = jax.vmap(cost_fn.pairwise, in_axes=[1, 1])(x_q, y_q) - - return ot_costs / n_q, None, None, None, None