Support AD engines #50
I'm not quite sure if this makes sense, but this is doing something, haha:

```julia
using Zygote: @adjoint

Zygote.@adjoint function nfft(x, image)
    return nfft(x, image), function (Δ)
        return (nothing, nfft_adjoint(x, size(image), Δ))
    end
end
```
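As a sanity check of the pattern above (a sketch of my own, not from NFFT.jl: a plain matrix stands in for the transform), for any linear map `y = A*x` the reverse-mode pullback of `x` is `A' * Δ`, which is exactly what the custom adjoint returns via `nfft_adjoint`:

```julia
using Zygote, LinearAlgebra

# For a linear map y = A*x, reverse mode propagates the cotangent through A'.
# Here f(x) = |A*x|^2, whose gradient is 2 * A' * (A * x).
A = randn(5, 3)
f(x) = sum(abs2, A * x)
x = randn(3)

g = Zygote.gradient(f, x)[1]
@assert g ≈ 2 * A' * (A * x)   # pullback of x is A' applied to the cotangent
```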
@kiranshila: Thanks for your feedback. It is much appreciated to hear that the software we developed has some real use cases. Regarding the adjoint: could you please have a look at this: https://github.com/MagneticResonanceImaging/MRIReco.jl/blob/master/src/Operators/NFFTOp.jl
Will do! I'll keep you posted, and if it works, I'll submit a PR to add Zygote support.
Cool!
Somewhat related to this, and to our discussion in the other issue on the existence of the inverse: is it a side effect of this algorithm that the magnitudes of the Fourier components in a "there and back" calculation are orders of magnitude bigger? As in, in the example in your documentation, I would expect the round trip to come back at roughly the original scale.

The reason this is important to me is that I am writing an optimization to match the Fourier components of data using a learned image with some regularization. However, the starting image for me is just the adjoint of the data.

It seems as though handing the adjoint to Zygote does indeed work; now the problem is this scaling.
Yes, you need to take scalings into account, and it often makes sense to properly "normalize" an operator. Just to give you an example: here we are doing exactly that in order to make the FFT unitary, which it is not in the standard definition: https://github.com/tknopp/SparsityOperators.jl/blob/master/src/FFTOp.jl#L41

For the NFFT this is more complicated, as it actually involves what is usually called density compensation. Let `A` be the NFFT matrix. Then `A^H A` is in general not the identity matrix. Instead, what you want/need is `A^H W A`, where `W` is a diagonal matrix with the squared density weights in it. It holds that `A^H W A ≈ I` if the weights are appropriately chosen. Then I recommend that you use `A^H W^{1/2}` and `W^{1/2} A` as your transformation pair. We actually also have a method in NFFT.jl to automatically calculate the density weights.
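To make this concrete with a self-contained toy (my own illustration, not NFFT.jl's implementation: the matrix construction, the warped nodes, and the Voronoi-style weights below are all assumptions chosen for demonstration), here is a small nonuniform DFT where `A^H A` is far from the identity but `A^H W A` with density weights comes much closer:

```julia
using LinearAlgebra

# Toy nonuniform DFT matrix: A[j, k] = exp(-2πi * x_j * k) for nodes
# x ⊂ [-1/2, 1/2) and frequencies k = -N÷2 : N÷2-1.
ndft(x, N) = [exp(-2π * im * xj * k) for xj in x, k in -N÷2:N÷2-1]

N, M = 8, 64                          # 8 frequencies, 64 sampling nodes
t = range(-1, 1; length = M + 1)[1:M]
x = t .* abs.(t) ./ 2                 # nodes clustered near 0: nonuniform density
A = ndft(x, N)

# Without compensation, AᴴA is not close to (a multiple of) the identity:
err_plain = norm(A' * A / M - I)

# Voronoi-style density weights: each node gets half of the gap to each
# neighbour, with periodic wrap-around. Then AᴴWA becomes a Riemann sum
# approximating ∫ exp(-2πi x (k' - k)) dx = δ_{k,k'}.
gaps = diff([x; x[1] + 1])
w = (gaps .+ circshift(gaps, 1)) ./ 2
W = Diagonal(w)
err_comp = norm(A' * W * A - I)

@show err_plain err_comp              # err_comp should be much smaller
@assert err_comp < err_plain
```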
Oh perfect! Thank you! That makes sense. |
Please have a look at this article: https://downloads.hindawi.com/journals/ijbi/2007/024727.pdf It is on MRI reconstruction, but the initial formulas touch on exactly what I described.
Yeah it seems that MRI reconstruction is very similar to radio telescope imaging |
Hi @roflmaostc, I would like to make NFFT.jl AD friendly, but unfortunately I don't yet have a deep understanding of how to write custom chain rules. I have seen that you were involved in similar chain rules before. If I get it right, this probably could be done at the level of the plan.
I don't know about ChainRulesCore, but it is about one line of code with Zygote. Now that I RTFM, I see that they recommend using ChainRulesCore: https://fluxml.ai/Zygote.jl/dev/adjoints/

I usually wrap my NFFT calls in a LinearMap, so I think I will look into a chain rule for that.
Yes, I read that manual as well. If I look at the LinearMaps example, it could look something like this:

```julia
using ChainRulesCore

function ChainRulesCore.rrule(::typeof(*), A::AbstractNFFTPlan, x::AbstractArray)
    y = A * x
    function pullback(dy)
        DY = unthunk(dy)
        # Because A is an abstract map, the product is only differentiable w.r.t. the input
        return NoTangent(), NoTangent(), @thunk(A' * DY)
    end
    return y, pullback
end
```

I am not sure if the same works for the in-place version. I like that the LinearMaps PR has a test for the chain rules. Flux is a pretty heavy dependency; I'm not sure if that is really needed, or if there are easier ways to test the chain rule.
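One lighter-weight option than Flux (a suggestion of mine, not something from this thread) is ChainRulesTestUtils, which checks an `rrule` against finite differences. A sketch, with a plain matrix standing in for the plan:

```julia
using ChainRulesTestUtils

# test_rrule compares an rrule's primal output and pullback against finite
# differences. A plain matrix stands in for an AbstractNFFTPlan here, so this
# exercises the existing rrule for `*`; the same call would exercise a custom
# rrule once one is defined for the plan type.
A = randn(4, 3)
x = randn(3)
test_rrule(*, A, x)   # throws if the rrule disagrees with finite differences
```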
This issue should be fixed by commit 8552b28, where we implemented the corresponding chain rules.
First off, thank you for this excellent library! It has been invaluable in my work on radio astronomy imaging.
I am working on some imaging methods that are iterative, and I would like to use the rest of the Julia ecosystem to support gradient-based optimization methods. I have tried ForwardDiff, ReverseDiff, and Zygote (seemingly the three most popular AD engines), and none of them seem to work with this library. I think this might come from the requirement that the type of the variables in the plan must be the same as that of the data: if I'm ADing w.r.t. the data, the dual-number wrapper forms a different type, and I'm not quite sure how to rectify that. I think Zygote should get around that problem, but I'm not as familiar with the workings of the source-to-source methods.
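To make the type mismatch concrete, here is a minimal sketch (the `Plan` type and `apply` function are made up for illustration; this is not NFFT.jl's actual API): a container whose element type is fixed at construction rejects ForwardDiff's `Dual` numbers.

```julia
using ForwardDiff

# Hypothetical plan-like container with its element type fixed at construction.
struct Plan{T}
    weights::Vector{T}
end

# The method constrains the input to the *same* element type as the plan.
apply(p::Plan{T}, x::Vector{T}) where {T} = p.weights .* x

p = Plan([1.0, 2.0])
apply(p, [3.0, 4.0])            # fine: both sides are Float64

f(x) = sum(apply(p, x))
# ForwardDiff.gradient(f, [3.0, 4.0])
# ^ MethodError: x becomes a Vector{<:ForwardDiff.Dual}, which no longer
#   matches Vector{Float64}, so the type-constrained method does not apply.
```

Relaxing the signature (e.g. `x::AbstractVector`) or parameterizing the plan over the dual type would let forward-mode AD through, at the cost of losing the strict type guarantee.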