
scatter_logsumexp: NaNs on untouched indices #368

Closed
aabbas90 opened this issue Apr 13, 2023 · 5 comments · Fixed by #369

Comments


aabbas90 commented Apr 13, 2023

Hi,

I am trying to perform scatter_logsumexp on a strict subset of the indices of the out tensor, but I am getting NaNs at the indices where out should be left untouched. Example:

import torch
import torch_scatter
from torch_scatter import scatter_logsumexp

src = torch.Tensor([0.0, 1.0, 4.0])
index = torch.tensor([1, 1, 4])
out = torch.zeros((6, ), dtype=torch.float32)

scatter_logsumexp(src, index, out=out)
print(out)
# tensor([   nan, 1.5514,    nan,    nan, 4.0181,    nan])  <- only indices 1 and 4 should be changed
print(torch_scatter.__version__)
# '2.1.1+pt20cu118'
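For reference, the two finite values suggest that the initial zeros of out participate in the reduction when out is passed in:

import torch
# index 1 collects src[0]=0.0 and src[1]=1.0 plus the initial out[1]=0.0
print(torch.tensor([0.0, 1.0, 0.0]).logsumexp(dim=0))  # tensor(1.5514)
# index 4 collects src[2]=4.0 plus the initial out[4]=0.0
print(torch.tensor([4.0, 0.0]).logsumexp(dim=0))       # tensor(4.0181)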

A separate issue, even once the NaN problem is resolved, is efficiency: ideally we would only operate on the locations of out that are referenced in index. Otherwise, for a very large out, we do redundant computation.
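For concreteness, here is a sketch of the behavior I have in mind: reduce into a compact buffer and write only the touched slots back. This is not existing torch_scatter API, and the logaddexp merge assumes the passed-in out values should keep participating in the reduction as above:

import torch
from torch_scatter import scatter_logsumexp

src = torch.tensor([0.0, 1.0, 4.0])
index = torch.tensor([1, 1, 4])
out = torch.zeros(6)

# Reduce only over the slots that actually occur in `index`.
uniq, compact = torch.unique(index, return_inverse=True)        # uniq = tensor([1, 4])
compact_res = scatter_logsumexp(src, compact, dim_size=uniq.numel())

# Fold the previous values of `out` back in and write only the touched slots.
out[uniq] = torch.logaddexp(out[uniq], compact_res)
print(out)  # tensor([0.0000, 1.5514, 0.0000, 0.0000, 4.0181, 0.0000])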

Thanks,
Ahmed


rusty1s commented Apr 14, 2023

Thanks for reporting. I fixed this in #369.

rusty1s linked a pull request Apr 14, 2023 that will close this issue
aabbas90 commented:

Thanks for the quick fix. But there is an issue with backpropagation now:

import torch
from torch_scatter import scatter_logsumexp

src = torch.Tensor([0.0, 1.0, 4.0])
src.requires_grad = True
index = torch.tensor([1, 1, 4])
out = torch.zeros((6, ), dtype=torch.float32)

scatter_logsumexp(src, index, out=out)

loss = torch.square(out - torch.ones_like(out)).sum()
loss.backward()
# RuntimeError: one of the variables needed for gradient computation has been
# modified by an inplace operation: [torch.FloatTensor [6]], which is output 0 of
# ExpBackward0, is at version 7; expected version 2 instead. Hint: enable anomaly
# detection to find the operation that failed to compute its gradient, with
# torch.autograd.set_detect_anomaly(True).
print(src.grad)  # never reached, backward() raises above


rusty1s commented Apr 14, 2023

Yes, that's because we need to write in-place to out.

out = scatter_logsumexp(src, index)

should fix this.
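A minimal sketch of that workaround with gradients; restricting the loss to the touched indices is an addition here, so that whatever fill value the untouched entries receive never enters the graph:

import torch
from torch_scatter import scatter_logsumexp

src = torch.tensor([0.0, 1.0, 4.0], requires_grad=True)
index = torch.tensor([1, 1, 4])

res = scatter_logsumexp(src, index, dim_size=6)   # no out= argument, no in-place write
touched = torch.unique(index)                     # tensor([1, 4])
loss = torch.square(res[touched] - 1.0).sum()
loss.backward()
print(src.grad)                                   # finite gradients, no autograd error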


yzhangcs commented Jun 1, 2023

@rusty1s Hi, is there any progress on fixing the backpropagation problem?

I also wonder whether there are any plans to optimize scatter_logsumexp with a dedicated CUDA kernel? Computing exp, then sum, then log directly can cause many overflow/underflow issues.


rusty1s commented Jun 2, 2023

I think this issue is only present if you pass in out, and there is not much I can do about that. It should work without specifying out. AFAIK, there are no numerical issues in our implementation, since we compute exp after first subtracting the maximum element.
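For reference, the max-subtraction trick in isolation (plain torch, not the torch_scatter internals):

import torch

x = torch.tensor([1000.0, 1001.0, 1002.0])

naive = torch.log(torch.exp(x).sum())             # inf: exp(1000.) overflows in float32
m = x.max()
stable = m + torch.log(torch.exp(x - m).sum())    # tensor(1002.4076)

print(naive, stable, torch.logsumexp(x, dim=0))   # stable matches torch.logsumexp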
