Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove source/target masks #603

Merged
merged 5 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
237 changes: 43 additions & 194 deletions src/ott/geometry/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import jax.numpy as jnp
import jax.scipy as jsp
import jax.tree_util as jtu
import numpy as np

from ott import utils
from ott.geometry import epsilon_scheduler as eps_scheduler
Expand Down Expand Up @@ -65,10 +64,6 @@ class Geometry:
scale_cost: option to rescale the cost matrix. Implemented scalings are
'median', 'mean', 'std' and 'max_cost'. Alternatively, a float factor can
be given to rescale the cost such that ``cost_matrix /= scale_cost``.
src_mask: Mask specifying valid rows when computing some statistics of
:attr:`cost_matrix`, see :attr:`src_mask`.
tgt_mask: Mask specifying valid columns when computing some statistics of
:attr:`cost_matrix`, see :attr:`tgt_mask`.

Note:
When defining a :class:`~ott.geometry.geometry.Geometry` through a
Expand All @@ -86,18 +81,13 @@ def __init__(
relative_epsilon: Optional[Literal["mean", "std"]] = None,
scale_cost: Union[float, Literal["mean", "max_cost", "median",
"std"]] = 1.0,
src_mask: Optional[jnp.ndarray] = None,
tgt_mask: Optional[jnp.ndarray] = None,
):
self._cost_matrix = cost_matrix
self._kernel_matrix = kernel_matrix
self._epsilon_init = epsilon
self._relative_epsilon = relative_epsilon
self._scale_cost = scale_cost

self._src_mask = src_mask
self._tgt_mask = tgt_mask

@property
def cost_rank(self) -> Optional[int]:
"""Output rank of cost matrix, if any was provided."""
Expand All @@ -117,14 +107,14 @@ def cost_matrix(self) -> jnp.ndarray:
@property
def median_cost_matrix(self) -> float:
"""Median of the :attr:`cost_matrix`."""
geom = self._masked_geom(mask_value=jnp.nan)
return jnp.nanmedian(geom.cost_matrix) # will fail for online PC
return jnp.median(self.cost_matrix)

@property
def mean_cost_matrix(self) -> float:
"""Mean of the :attr:`cost_matrix`."""
tmp = self._masked_geom().apply_cost(self._n_normed_ones).squeeze()
return jnp.sum(tmp * self._m_normed_ones)
n, m = self.shape
tmp = self.apply_cost(jnp.full((n,), fill_value=1.0 / n))
return jnp.sum((1.0 / m) * tmp)

@property
def std_cost_matrix(self) -> float:
Expand All @@ -139,8 +129,9 @@ def std_cost_matrix(self) -> float:

to output :math:`\sigma`.
"""
tmp = self._masked_geom().apply_square_cost(self._n_normed_ones).squeeze()
tmp = jnp.sum(tmp * self._m_normed_ones) - (self.mean_cost_matrix ** 2)
n, m = self.shape
tmp = self.apply_square_cost(jnp.full((n,), fill_value=1.0 / n))
tmp = jnp.sum((1.0 / m) * tmp) - (self.mean_cost_matrix ** 2)
return jnp.sqrt(jax.nn.relu(tmp))

@property
Expand All @@ -158,7 +149,6 @@ def epsilon_scheduler(self) -> eps_scheduler.Epsilon:
"""Epsilon scheduler."""
if isinstance(self._epsilon_init, eps_scheduler.Epsilon):
return self._epsilon_init

# no relative epsilon
if self._relative_epsilon is None:
if self._epsilon_init is not None:
Expand Down Expand Up @@ -217,23 +207,21 @@ def is_online(self) -> bool:
@property
def is_symmetric(self) -> bool:
"""Whether geometry cost/kernel is a symmetric matrix."""
n, m = self.shape
mat = self.kernel_matrix if self.cost_matrix is None else self.cost_matrix
return (
mat.shape[0] == mat.shape[1] and jnp.all(mat == mat.T)
) if mat is not None else False
return (n == m) and jnp.all(mat == mat.T)

@property
def inv_scale_cost(self) -> float:
def inv_scale_cost(self) -> jnp.ndarray:
"""Compute and return inverse of scaling factor for cost matrix."""
if isinstance(self._scale_cost, (int, float, np.number, jax.Array)):
return 1.0 / self._scale_cost
self = self._masked_geom(mask_value=jnp.nan)
if self._scale_cost == "max_cost":
return 1.0 / jnp.nanmax(self._cost_matrix)
return 1.0 / jnp.max(self._cost_matrix)
if self._scale_cost == "mean":
return 1.0 / jnp.nanmean(self._cost_matrix)
return 1.0 / jnp.mean(self._cost_matrix)
if self._scale_cost == "median":
return 1.0 / jnp.nanmedian(self._cost_matrix)
return 1.0 / jnp.median(self._cost_matrix)
if jnp.isscalar(self._scale_cost):
return 1.0 / self._scale_cost
raise ValueError(f"Scaling {self._scale_cost} not implemented.")

def set_scale_cost(self, scale_cost: Union[float, str]) -> "Geometry":
Expand Down Expand Up @@ -692,14 +680,14 @@ def to_LRCGeometry(
i_star = jax.random.randint(rng1, shape=(), minval=0, maxval=n)
j_star = jax.random.randint(rng2, shape=(), minval=0, maxval=m)

ci_star = self.subset([i_star], None).cost_matrix.ravel() ** 2 # (m,)
cj_star = self.subset(None, [j_star]).cost_matrix.ravel() ** 2 # (n,)
ci_star = self.subset(row_ixs=i_star).cost_matrix.ravel() ** 2 # (m,)
cj_star = self.subset(col_ixs=j_star).cost_matrix.ravel() ** 2 # (n,)

p_row = cj_star + ci_star[j_star] + jnp.mean(ci_star) # (n,)
p_row /= jnp.sum(p_row)
row_ixs = jax.random.choice(rng3, n, shape=(n_subset,), p=p_row)
# (n_subset, m)
s = self.subset(row_ixs, None).cost_matrix
s = self.subset(row_ixs=row_ixs).cost_matrix
s /= jnp.sqrt(n_subset * p_row[row_ixs][:, None])

p_col = jnp.sum(s ** 2, axis=0) # (m,)
Expand All @@ -720,7 +708,7 @@ def to_LRCGeometry(
col_ixs = jax.random.choice(rng5, m, shape=(n_subset,)) # (n_subset,)

# (n, n_subset)
A_trans = self.subset(None, col_ixs).cost_matrix * inv_scale
A_trans = self.subset(col_ixs=col_ixs).cost_matrix * inv_scale
B = (U[col_ixs, :] @ v * inv_scale) # (n_subset, k)
M = jnp.linalg.inv(B.T @ B) # (k, k)
V = jnp.linalg.multi_dot([A_trans, B, M.T, v.T]) # (n, k)
Expand All @@ -737,187 +725,48 @@ def to_LRCGeometry(
)

def subset(
self, src_ixs: Optional[jnp.ndarray], tgt_ixs: Optional[jnp.ndarray],
**kwargs: Any
) -> "Geometry":
"""Subset rows or columns of a geometry.

Args:
src_ixs: Row indices. If ``None``, use all rows.
tgt_ixs: Column indices. If ``None``, use all columns.
kwargs: Keyword arguments to override the initialization.

Returns:
The modified geometry.
"""

def subset_fn(
arr: Optional[jnp.ndarray],
src_ixs: Optional[jnp.ndarray],
tgt_ixs: Optional[jnp.ndarray],
) -> Optional[jnp.ndarray]:
if arr is None:
return None
if src_ixs is not None:
arr = arr[src_ixs, ...]
if tgt_ixs is not None:
arr = arr[:, tgt_ixs]
return arr # noqa: RET504

return self._mask_subset_helper(
src_ixs,
tgt_ixs,
fn=subset_fn,
propagate_mask=True,
**kwargs,
)

def mask(
self,
src_mask: Optional[jnp.ndarray],
tgt_mask: Optional[jnp.ndarray],
mask_value: float = 0.0,
row_ixs: Optional[jnp.ndarray] = None,
col_ixs: Optional[jnp.ndarray] = None
) -> "Geometry":
"""Mask rows or columns of a geometry.

The mask is used only when computing some statistics of the
:attr:`cost_matrix`.

- :attr:`mean_cost_matrix`
- :attr:`median_cost_matrix`
- :attr:`inv_scale_cost`
"""Subset rows or columns of a geometry.

Args:
src_mask: Row mask. Can be specified either as a boolean array of shape
``[num_a,]`` or as an array of indices. If ``None``, no mask is applied.
tgt_mask: Column mask. Can be specified either as a boolean array of shape
``[num_b,]`` or as an array of indices. If ``None``, no mask is applied.
mask_value: Value to use for masking.
row_ixs: Row indices. If :obj:`None`, use all rows.
col_ixs: Column indices. If :obj:`None`, use all columns.

Returns:
The masked geometry.
The subsetted geometry.
"""

def mask_fn(
arr: Optional[jnp.ndarray],
src_mask: Optional[jnp.ndarray],
tgt_mask: Optional[jnp.ndarray],
) -> Optional[jnp.ndarray]:
if arr is None:
return arr
assert arr.ndim == 2, arr.ndim
if src_mask is not None:
arr = jnp.where(src_mask[:, None], arr, mask_value)
if tgt_mask is not None:
arr = jnp.where(tgt_mask[None, :], arr, mask_value)
return arr # noqa: RET504

src_mask = self._normalize_mask(src_mask, self.shape[0])
tgt_mask = self._normalize_mask(tgt_mask, self.shape[1])
return self._mask_subset_helper(
src_mask, tgt_mask, fn=mask_fn, propagate_mask=False
)

def _mask_subset_helper(
self,
src_ixs: Optional[jnp.ndarray],
tgt_ixs: Optional[jnp.ndarray],
*,
fn: Callable[
[Optional[jnp.ndarray], Optional[jnp.ndarray], Optional[jnp.ndarray]],
Optional[jnp.ndarray]],
propagate_mask: bool,
**kwargs: Any,
) -> "Geometry":
(cost, kernel, eps, src_mask, tgt_mask), aux_data = self.tree_flatten()
cost = fn(cost, src_ixs, tgt_ixs)
kernel = fn(kernel, src_ixs, tgt_ixs)
if propagate_mask:
src_mask = self._normalize_mask(src_mask, self.shape[0])
tgt_mask = self._normalize_mask(tgt_mask, self.shape[1])
src_mask = fn(src_mask, src_ixs, None)
tgt_mask = fn(tgt_mask, tgt_ixs, None)

aux_data = {**aux_data, **kwargs}
return type(self).tree_unflatten(
aux_data, [cost, kernel, eps, src_mask, tgt_mask]
)

@property
def src_mask(self) -> Optional[jnp.ndarray]:
"""Mask of shape ``[num_a,]`` to compute :attr:`cost_matrix` statistics.

Specifically, it is used when computing:

- :attr:`mean_cost_matrix`
- :attr:`median_cost_matrix`
- :attr:`inv_scale_cost`
"""
return self._normalize_mask(self._src_mask, self.shape[0])

@property
def tgt_mask(self) -> Optional[jnp.ndarray]:
"""Mask of shape ``[num_b,]`` to compute :attr:`cost_matrix` statistics.

Specifically, it is used when computing:

- :attr:`mean_cost_matrix`
- :attr:`median_cost_matrix`
- :attr:`inv_scale_cost`
"""
return self._normalize_mask(self._tgt_mask, self.shape[1])
(cost, kernel, *rest), aux_data = self.tree_flatten()
row_ixs = row_ixs if row_ixs is None else jnp.atleast_1d(row_ixs)
col_ixs = col_ixs if col_ixs is None else jnp.atleast_1d(col_ixs)
if cost is not None:
cost = cost if row_ixs is None else cost[row_ixs]
cost = cost if col_ixs is None else cost[:, col_ixs]
if kernel is not None:
kernel = kernel if row_ixs is None else kernel[row_ixs]
kernel = kernel if col_ixs is None else kernel[:, col_ixs]
return type(self).tree_unflatten(aux_data, (cost, kernel, *rest))

@property
def dtype(self) -> jnp.dtype:
"""The data type."""
return (
self._kernel_matrix if self._cost_matrix is None else self._cost_matrix
).dtype

def _masked_geom(self, mask_value: float = 0.0) -> "Geometry":
"""Mask geometry based on :attr:`src_mask` and :attr:`tgt_mask`."""
src_mask, tgt_mask = self.src_mask, self.tgt_mask
if src_mask is None and tgt_mask is None:
return self
return self.mask(src_mask, tgt_mask, mask_value=mask_value)

@property
def _n_normed_ones(self) -> jnp.ndarray:
"""Normalized array of shape ``[num_a,]``."""
mask = self.src_mask
arr = jnp.ones(self.shape[0]) if mask is None else mask
return arr / jnp.sum(arr)

@property
def _m_normed_ones(self) -> jnp.ndarray:
"""Normalized array of shape ``[num_b,]``."""
mask = self.tgt_mask
arr = jnp.ones(self.shape[1]) if mask is None else mask
return arr / jnp.sum(arr)

@staticmethod
def _normalize_mask(mask: Optional[jnp.ndarray],
size: int) -> Optional[jnp.ndarray]:
"""Convert array of indices to a boolean mask."""
if mask is None:
return None
if not jnp.issubdtype(mask, (bool, jnp.bool_)):
mask = jnp.isin(jnp.arange(size), mask)
assert mask.shape == (size,)
return mask
if self._cost_matrix is not None:
return self._cost_matrix.dtype
return self._kernel_matrix.dtype

def tree_flatten(self): # noqa: D102
return (
self._cost_matrix, self._kernel_matrix, self._epsilon_init,
self._src_mask, self._tgt_mask
self._cost_matrix,
self._kernel_matrix,
self._epsilon_init,
), {
"scale_cost": self._scale_cost,
"relative_epsilon": self._relative_epsilon,
}

@classmethod
def tree_unflatten(cls, aux_data, children): # noqa: D102
cost, kernel, eps, src_mask, tgt_mask = children
return cls(
cost, kernel, eps, src_mask=src_mask, tgt_mask=tgt_mask, **aux_data
)
cost, kernel, epsilon = children
return cls(cost, kernel_matrix=kernel, epsilon=epsilon, **aux_data)
17 changes: 0 additions & 17 deletions src/ott/geometry/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,21 +315,6 @@ def transport_from_scalings(
"cloud geometry instead."
)

def subset(
self, src_ixs: Optional[jnp.ndarray], tgt_ixs: Optional[jnp.ndarray]
) -> NoReturn:
"""Not implemented."""
raise NotImplementedError("Subsetting is not implemented for grids.")

def mask(
self,
src_mask: Optional[jnp.ndarray],
tgt_mask: Optional[jnp.ndarray],
mask_value: float = 0.0,
) -> NoReturn:
"""Not implemented."""
raise NotImplementedError("Masking is not implemented for grids.")

@property
def cost_matrix(self) -> jnp.ndarray:
"""Not implemented."""
Expand Down Expand Up @@ -425,6 +410,4 @@ def to_LRCGeometry(
epsilon=self._epsilon_init,
relative_epsilon=self._relative_epsilon,
scale_cost=self._scale_cost,
src_mask=self.src_mask,
tgt_mask=self.tgt_mask,
)
Loading
Loading