Cannot compute `cost_matrix` or `kernel_matrix` from `geometry.grid.Grid` if both set to `None` #561
Comments
Thanks @tlacombe for this question! Indeed, when we wrote the […] In most relevant applications, creating the […] If you do need these matrices, a possible option could be to instantiate them as […]
Again, this won't scale. Maybe we can think of a […]
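To put "this won't scale" in numbers: a grid with sizes (n_1, ..., n_d) has N = n_1 · n_2 · ... · n_d points, so a dense cost or kernel matrix has N² entries, whereas the d per-axis matrices only need n_1² + ... + n_d² entries. A small sketch of that arithmetic, assuming 4-byte (float32) entries; the helper names are hypothetical:

```python
import math


def dense_matrix_bytes(grid_size, bytes_per_entry=4):
    """Bytes for one dense N x N matrix on a grid, with N = prod(grid_size)."""
    num_points = math.prod(grid_size)
    return num_points ** 2 * bytes_per_entry


def per_axis_bytes(grid_size, bytes_per_entry=4):
    """Bytes for the d per-axis matrices, each of size n_k x n_k."""
    return sum(n ** 2 for n in grid_size) * bytes_per_entry


for size in [(1000,), (100, 100), (64, 64, 64)]:
    dense_gib = dense_matrix_bytes(size) / 2 ** 30
    axes_kib = per_axis_bytes(size) / 2 ** 10
    print(f"grid {size}: dense {dense_gib:.3f} GiB, per-axis {axes_kib:.1f} KiB")
```

For a 64 × 64 × 64 grid, one dense matrix needs 256 GiB, while the three per-axis matrices need 48 KiB in total.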
Hello @marcocuturi, yes, I think I understand the "philosophy" of the […] I believe that the issue simply lies in the documentation, which (as far as I understand it) suggests that the attributes […]
Agreed; let's prevent instantiating the cost/kernel matrices for grids by providing a better error message.
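One way to get the clearer error message suggested above is to make the dense properties of a grid-like geometry fail with an explicit exception that points to the per-axis data, instead of failing later on a `None`. The sketch below uses a hypothetical `SeparableGridGeometry` class holding one cost matrix per axis; it is not OTT's `Grid` implementation, and the attribute name `axis_costs` is an assumption:

```python
import math

import numpy as np


class SeparableGridGeometry:
    """Hypothetical grid geometry that stores only per-axis cost matrices.

    The dense cost/kernel matrices over all grid points are never built;
    accessing them raises an explicit error instead of failing on a None.
    """

    def __init__(self, axis_costs, epsilon=1e-2):
        self.axis_costs = [np.asarray(c, dtype=float) for c in axis_costs]
        self.epsilon = epsilon

    @property
    def grid_size(self):
        return tuple(c.shape[0] for c in self.axis_costs)

    @property
    def shape(self):
        n = math.prod(self.grid_size)
        return (n, n)

    def _dense_unavailable(self, name):
        # Build (but do not raise) an exception explaining what to use instead.
        return NotImplementedError(
            f"`{name}` is not instantiated for a grid geometry: it would be a "
            f"dense {self.shape[0]} x {self.shape[1]} matrix. Use the per-axis "
            "matrices in `axis_costs` instead."
        )

    @property
    def cost_matrix(self):
        raise self._dense_unavailable("cost_matrix")

    @property
    def kernel_matrix(self):
        raise self._dense_unavailable("kernel_matrix")


geom = SeparableGridGeometry([np.zeros((3, 3)), np.zeros((4, 4))])
try:
    geom.cost_matrix
except NotImplementedError as err:
    print(err)
```

Raising at the property, rather than returning `None`, surfaces the problem at the point of access with a message that names the alternative.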
Is your feature request related to a problem? Please describe.
`Grid` objects (as `Geometry` objects) have `cost_matrix` and `kernel_matrix` attributes, but it seems that `Grid` initializes both of these fields to `None`. When calling `cost_matrix`, an error is raised because `self._kernel_matrix` is `None`, and vice versa. See this MWE:
The error is raised because `self._kernel_matrix` is `None`.

This is in contrast with the `PointCloud` geometry, where accessing `cost_matrix` when both the cost and kernel matrices are `None` calls `self._compute_cost_matrix` and returns the computed matrix. MWE:
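The contrast between the two geometries can be reproduced with a small sketch. It assumes, as the description above suggests, that a base geometry recovers a missing cost matrix from its kernel via C = -ε log K (and a missing kernel via K = exp(-C/ε)), while a point-cloud geometry can fall back to computing the cost from its points. `BaseGeometry` and `PointCloudLike` are hypothetical classes that only mimic this logic, not OTT's code, and the squared Euclidean cost is an assumed choice:

```python
import numpy as np


class BaseGeometry:
    """Hypothetical minimal geometry: stores a cost and/or a kernel matrix."""

    def __init__(self, cost_matrix=None, kernel_matrix=None, epsilon=1e-2):
        self._cost_matrix = cost_matrix
        self._kernel_matrix = kernel_matrix
        self.epsilon = epsilon

    @property
    def cost_matrix(self):
        if self._cost_matrix is not None:
            return self._cost_matrix
        if self._kernel_matrix is None:
            # Nothing to recover the cost from: the situation of a grid
            # geometry whose dense matrices were both left as None.
            raise ValueError("Neither a cost nor a kernel matrix is stored.")
        # Recover the cost from the kernel: C = -epsilon * log(K).
        return -self.epsilon * np.log(self._kernel_matrix)

    @property
    def kernel_matrix(self):
        if self._kernel_matrix is not None:
            return self._kernel_matrix
        # K = exp(-C / epsilon); raises if no cost matrix is available either.
        return np.exp(-self.cost_matrix / self.epsilon)


class PointCloudLike(BaseGeometry):
    """Hypothetical point-cloud geometry that can build its cost on demand."""

    def __init__(self, x, y, epsilon=1e-2):
        super().__init__(epsilon=epsilon)
        self.x = np.asarray(x, dtype=float)
        self.y = np.asarray(y, dtype=float)

    def _compute_cost_matrix(self):
        # Squared Euclidean cost between every x_i and y_j (assumed cost).
        diff = self.x[:, None, :] - self.y[None, :, :]
        return np.sum(diff ** 2, axis=-1)

    @property
    def cost_matrix(self):
        # Unlike the base class, fall back to the points when nothing is stored.
        if self._cost_matrix is None and self._kernel_matrix is None:
            return self._compute_cost_matrix()
        return super().cost_matrix
```

With both matrices set to `None`, `BaseGeometry().cost_matrix` has nothing to fall back on, whereas `PointCloudLike` computes the cost from `x` and `y`, and its `kernel_matrix` follows from that cost.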
Describe the solution you'd like
I don't know what's the best option between […] `cost_matrix` or `kernel_matrix` are not available with `Grid` as is, and/or raise an error with an explicit exception (as, for instance, when calling `out.matrix` after solving a transportation problem with a `Grid` geometry (`grid.py`, line 306)).

I'd be fine with the first solution; my problem is just that, from the current documentation, one may think that the `cost_matrix` and `kernel_matrix` attributes can be used to access the matrices of interest, but the error they return is not very clear, so I had to search for what was going wrong.

The second option seems unnecessary, since the cost matrices along each axis are accessible through `geometries(self)`, which may be adapted to most uses. So I believe it is just a matter of documenting this behavior.

Describe alternatives you've considered
I found a quick hack in this issue, but, as is, it only handles 1D grids, though it may be adapted to nD grids.
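The hack itself is not reproduced here, but the idea can be sketched for nD grids. Assuming the cost is separable across axes, c(x, y) = Σ_k c_k(x_k, y_k), the dense cost matrix is a Kronecker sum of the per-axis matrices C_k, and the dense kernel exp(-C/ε) is the Kronecker product of the per-axis kernels K_k = exp(-C_k/ε). The hypothetical `dense_grid_cost` builds the dense matrix (paying the N² memory cost), while the hypothetical `apply_grid_kernel` computes K v one axis at a time using only the K_k, which is how per-axis matrices can stand in for the dense ones in most uses:

```python
import math

import numpy as np


def dense_grid_cost(axis_costs):
    """Dense N x N cost on a grid, for a cost separable across axes.

    Entry [(i_1, ..., i_d), (j_1, ..., j_d)] is sum_k C_k[i_k, j_k]; grid
    points are flattened in C (row-major) order. Memory grows as N**2.
    """
    d = len(axis_costs)
    sizes = [np.shape(c)[0] for c in axis_costs]
    total = np.zeros(sizes + sizes)  # axes: (i_1, ..., i_d, j_1, ..., j_d)
    for k, c in enumerate(axis_costs):
        # View C_k so that it varies only along axes i_k and j_k, then broadcast.
        view = [1] * (2 * d)
        view[k] = view[d + k] = sizes[k]
        total = total + np.reshape(c, view)
    n = math.prod(sizes)
    return total.reshape(n, n)


def apply_grid_kernel(axis_kernels, v):
    """Compute K @ v for K = kron(K_1, ..., K_d) without forming K."""
    t = np.reshape(v, [np.shape(k)[1] for k in axis_kernels])
    for axis, k in enumerate(axis_kernels):
        # Contract K_k with the tensor's `axis`, then move the result back there.
        t = np.moveaxis(np.tensordot(k, t, axes=([1], [axis])), 0, axis)
    return t.reshape(-1)


# 2D example with a squared Euclidean cost on each axis.
x1, x2 = np.linspace(0.0, 1.0, 3), np.linspace(0.0, 2.0, 4)
axis_costs = [(x[:, None] - x[None, :]) ** 2 for x in (x1, x2)]
eps = 0.5
axis_kernels = [np.exp(-c / eps) for c in axis_costs]
C = dense_grid_cost(axis_costs)  # shape (12, 12)
v = np.ones(C.shape[0])
assert np.allclose(np.exp(-C / eps) @ v, apply_grid_kernel(axis_kernels, v))
```

The dense route is only practical for small grids; the axis-by-axis product needs just the d per-axis matrices.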