Cannot compute cost_matrix or kernel_matrix from geometry.grid.Grid if both set to None #561

Closed
tlacombe opened this issue Jul 16, 2024 · 3 comments · Fixed by #565

@tlacombe

Is your feature request related to a problem? Please describe.

Grid objects (like other Geometry objects) have cost_matrix and kernel_matrix attributes, but it seems that

  • a natural use of Grid (e.g. as in the tutorial) initializes both of these fields to None;
  • accessing cost_matrix then raises an error because self._kernel_matrix is None, and vice versa.

See this MWE:

from ott.geometry.grid import Grid

print(Grid(grid_size=(5, 5))._cost_matrix)  # --> None
Grid(grid_size=(5, 5)).cost_matrix  # --> TypeError: "unsupported operand type(s) for *: 'NoneType' and 'float'"

The error is raised because self._kernel_matrix is None.
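The message itself is the generic one Python emits when None is multiplied by a float (Python raises it as a TypeError), which is why it gives no hint about Grid. A minimal sketch in plain Python, independent of ott; the variable names are stand-ins chosen for illustration:

```python
# Stand-ins for an unset private attribute and a regularization value.
kernel_matrix = None
epsilon = 0.05

try:
    kernel_matrix * epsilon
except TypeError as err:
    # Python's generic operand-type error; nothing here mentions grids.
    message = str(err)

print(message)
```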

In contrast, with the PointCloud geometry, accessing cost_matrix when both the cost and kernel matrices are None calls self._compute_cost_matrix and returns something.

MWE:

from jax import random
from ott.geometry.pointcloud import PointCloud

rng = random.PRNGKey(0)
keys = random.split(rng, 2)

x = [
    random.uniform(keys[0], (5,2)),
    random.uniform(keys[1], (5,2)),
]

PointCloud(x[0], x[1]).cost_matrix  # outputs a 5x5 matrix. 

Describe the solution you'd like

I don't know which is the better option:

  • Documenting that cost_matrix and kernel_matrix are not available with Grid as is, and/or raising an explicit exception (as is done, for instance, when calling out.matrix after solving a transportation problem with a Grid geometry; see grid.py line 306).
  • Returning an actual cost matrix, perhaps of the form $C[(i_1,\dots,i_d), (j_1,\dots,j_d)] =$ cost to move from cell $(i_1,\dots,i_d)$ to cell $(j_1,\dots,j_d)$. But it may be extremely large and unwieldy to manipulate, so it may not be a good idea.
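To make the second option concrete, here is a rough sketch of what such a matrix could look like, assuming the cost between two cells is the sum of per-axis costs. It uses NumPy rather than ott, and the helper names, the squared-distance per-axis cost and the evenly spaced coordinates are all assumptions made for illustration:

```python
import numpy as np


def axis_cost(n: int) -> np.ndarray:
    """Squared-distance cost between n evenly spaced points on [0, 1] (assumed)."""
    x = np.linspace(0.0, 1.0, n)
    return (x[:, None] - x[None, :]) ** 2


def dense_grid_cost(grid_size: tuple) -> np.ndarray:
    """Dense C[(i_1..i_d), (j_1..j_d)] = sum_k C_k[i_k, j_k], cells in row-major order."""
    total = int(np.prod(grid_size))
    cost = np.zeros((total, total))
    for k, n in enumerate(grid_size):
        before = int(np.prod(grid_size[:k]))     # number of cells spanned by earlier axes
        after = int(np.prod(grid_size[k + 1:]))  # number of cells spanned by later axes
        # Broadcast the per-axis cost over all other axes with blocks of ones.
        block = np.kron(np.ones((before, before)), axis_cost(n))
        cost += np.kron(block, np.ones((after, after)))
    return cost


cost = dense_grid_cost((5, 5))
print(cost.shape)  # (25, 25)
```

Even for this 5 x 5 grid the matrix already has 625 entries; the number of entries grows as the square of the number of cells.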

I'd be fine with the first solution; my problem is just that, from the current documentation, one may think that the cost_matrix and kernel_matrix attributes give access to the matrices of interest, but the error raised is not very clear, so I had to search for what was going wrong.

The second option seems unnecessary since we can access the cost matrices along each axis through geometries(self), which should be enough for most uses. So I believe it is just a matter of documenting this behavior.

Describe alternatives you've considered

I found a quick hack in this issue but, as is, it only handles 1D grids, though it may be adapted to nD grids.

@marcocuturi
Contributor

marcocuturi commented Jul 18, 2024

Thanks @tlacombe for this question!

Indeed, when we wrote the Grid geometry, our idea was to never instantiate the kernel or the cost matrix, but rather to only implement the apply operators efficiently for both kernel (in lse mode or not) and cost matrices. In fact, this was one of the early motivations for starting ott-jax: we wanted to make sure the apply logic was encapsulated in the geometry, and not handled by the Sinkhorn solvers.
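As a rough illustration of that idea (not ott's actual implementation), assume the kernel factorizes across axes as a Kronecker product of small per-axis kernels, which holds when the cost is a sum of per-axis costs. The kernel can then be applied to a vector one axis at a time, without ever forming the full matrix. The sketch uses NumPy, and apply_grid_kernel, axis_kernel and the squared-distance cost are hypothetical choices:

```python
import numpy as np


def axis_kernel(n: int, epsilon: float) -> np.ndarray:
    """Gibbs kernel exp(-C / epsilon) for an assumed squared-distance cost on [0, 1]."""
    x = np.linspace(0.0, 1.0, n)
    return np.exp(-((x[:, None] - x[None, :]) ** 2) / epsilon)


def apply_grid_kernel(kernels: list, vec: np.ndarray) -> np.ndarray:
    """Apply the Kronecker product of `kernels` to `vec` without building it."""
    shape = tuple(k.shape[0] for k in kernels)
    out = vec.reshape(shape)
    for axis, kernel in enumerate(kernels):
        # Contract the kernel with one grid axis, then move that axis back in place.
        out = np.moveaxis(np.tensordot(kernel, out, axes=(1, axis)), 0, axis)
    return out.reshape(-1)


kernels = [axis_kernel(4, 0.5), axis_kernel(3, 0.5)]
vec = np.arange(12, dtype=float)
fast = apply_grid_kernel(kernels, vec)

# Reference: the dense 12 x 12 kernel, only feasible because the grid is tiny.
dense = np.kron(kernels[0], kernels[1]) @ vec
print(np.allclose(fast, dense))
```

The per-axis version only ever stores the 4 x 4 and 3 x 3 kernels, whereas the dense reference needs the full 12 x 12 matrix.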

In most relevant applications, creating the cost_matrix from a grid would generate matrices far too big to instantiate in memory.
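A back-of-the-envelope estimate shows how quickly this grows. The helper below is hypothetical and assumes 4 bytes per entry (float32):

```python
import math


def dense_cost_bytes(grid_size, bytes_per_entry=4):
    """Memory needed for a dense cost matrix between all pairs of grid cells."""
    num_cells = math.prod(grid_size)
    return num_cells * num_cells * bytes_per_entry


for grid_size in [(5, 5), (100, 100), (100, 100, 100)]:
    gib = dense_cost_bytes(grid_size) / 2**30
    print(grid_size, f"{gib:.3g} GiB")
```

For instance, a 100 x 100 x 100 grid has 10^6 cells, hence 10^12 matrix entries, which is roughly 4 TB at 4 bytes per entry.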

If you do need these matrices, one option is to instantiate the geometry as a PointCloud by explicitly passing all of the points in the grid in the x argument. This is what is done in this test to check that both approaches coincide.

def test_grid_vs_euclidean(self, rng: jax.Array, lse_mode: bool):

Again, this won't scale. Maybe we can think of a to_pointcloud method to do this conversion.
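A possible shape for such a conversion, sketched with NumPy only. The function name grid_to_points and the evenly spaced coordinates are assumptions; this is not an existing ott method:

```python
import numpy as np


def grid_to_points(axes: list) -> np.ndarray:
    """Return all grid cells as rows of shape (prod(n_k), d), in row-major order."""
    mesh = np.meshgrid(*axes, indexing="ij")
    return np.stack([m.reshape(-1) for m in mesh], axis=-1)


# Per-axis coordinates of a 5 x 5 grid (assumed to be evenly spaced on [0, 1]).
axes = [np.linspace(0.0, 1.0, 5), np.linspace(0.0, 1.0, 5)]
points = grid_to_points(axes)
print(points.shape)  # (25, 2)
```

The resulting array, converted to a JAX array if needed, could then be passed as both arguments of PointCloud, in the same way as in the PointCloud MWE earlier in this thread. As noted, this is only practical for small grids.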

@tlacombe
Author

Hello @marcocuturi,

Yes, I think I understand the "philosophy" of the Grid class: instantiating the kernel/cost matrix would be counter-productive and won't scale, so it makes sense that the option is not available for this class.

I believe that the issue simply lies in the documentation, which (as far as I understand it) suggests that the kernel_matrix and cost_matrix attributes are available for the Grid class. Raising an error along the lines of "xxx_matrix not available for Grid class, use PointCloud instead" (as done in other places) may be sufficient, I guess, instead of a raw "unsupported (...) for NoneType and float", which is harder to understand without further investigation.

@michalk8
Collaborator

> In most relevant applications, creating the cost_matrix from a grid would generate matrices far too big to instantiate in memory.

Agree; let's prevent instantiating the cost/kernel for grids by having a better error message.
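One way such a guard could look, as a minimal sketch on a toy class. ToyGrid is a hypothetical stand-in, not the ott Grid class, and the exception type and wording are assumptions rather than what was eventually merged:

```python
class ToyGrid:
    """Toy stand-in for a grid geometry; NOT the ott implementation."""

    def __init__(self, grid_size):
        self.grid_size = tuple(grid_size)

    def _not_available(self, name):
        # Build one explicit, self-explanatory error for both properties.
        return NotImplementedError(
            f"{name} is not available for grid geometries, since it would need "
            "one entry per pair of grid cells. Use a point cloud geometry instead."
        )

    @property
    def cost_matrix(self):
        raise self._not_available("cost_matrix")

    @property
    def kernel_matrix(self):
        raise self._not_available("kernel_matrix")


try:
    ToyGrid((5, 5)).cost_matrix
except NotImplementedError as err:
    print(err)
```

Compared with the raw NoneType message, the error names the attribute and points to an alternative.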
