
Support AD engines #50

Closed
kiranshila opened this issue May 19, 2021 · 13 comments

Comments

@kiranshila

First off, thank you for this excellent library! It has been invaluable in my work on radio astronomy imaging.

I am working on some imaging methods that are iterative, and I would like to use the rest of the Julia ecosystem to support gradient-based optimization methods. I have tried ForwardDiff, ReverseDiff, and Zygote (seemingly the three most popular AD engines), and none of them seem to work with this library. I think this comes from the requirement that the type of the variables in the plan must be the same as that of the data: if I'm ADing w.r.t. the data, the dual-number wrapper forms a different type, and I'm not quite sure how to rectify that. I would think Zygote should get around that problem, but I'm not as familiar with the workings of the source-to-source methods.

@kiranshila
Author

I'm not quite sure if this makes sense, but this is doing something haha

using Zygote

# Custom Zygote adjoint: the pullback of nfft w.r.t. the image is the
# adjoint NFFT applied to the incoming cotangent; the nodes x are treated
# as non-differentiable (hence `nothing`).
Zygote.@adjoint function nfft(x, image)
    return nfft(x, image), function (Δ)
        return (nothing, nfft_adjoint(x, size(image), Δ))
    end
end
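
With that adjoint defined, a quick smoke test might look like the following; the sizes and the least-squares loss are just hypothetical, and this assumes the nfft(x, image) / nfft_adjoint(x, N, Δ) convenience methods used above:

using Zygote, NFFT

x = rand(2, 64) .- 0.5                  # 2D nonuniform nodes in [-1/2, 1/2)
data = rand(ComplexF64, 64)             # "measured" Fourier samples, one per node
loss(image) = sum(abs2, nfft(x, image) .- data)

image0 = rand(ComplexF64, 16, 16)
g = Zygote.gradient(loss, image0)[1]    # hits the custom adjoint above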

@tknopp
Member

tknopp commented May 19, 2021

@kiranshila: Thanks for your feedback. It is much appreciated to hear that the software we developed has real use cases.

Regarding the adjoint: could you please have a look at this: https://github.com/MagneticResonanceImaging/MRIReco.jl/blob/master/src/Operators/NFFTOp.jl
It is basically an operator version of the NFFT, and it already has an adjoint. I would say Zygote should be happy with that.

@kiranshila
Author

Will do! I'll keep you posted, and if it works, I'll submit a PR to add in Zygote support.

@tknopp
Member

tknopp commented May 19, 2021

Cool. NFFTOp currently lives in a different package, but it would actually fit into NFFT.jl itself. We would pull in a dependency on LinearOperators, but that should not be a big deal.

@kiranshila
Author

Somewhat related to this, and to our discussion in the other issue on the existence of the inverse: is it a side effect of this algorithm that the magnitudes of the Fourier components in the "there and back" calculation are orders of magnitude bigger? That is, in the example in your documentation I would expect fHat and g to be identical, but the magnitudes of g are much higher. If I use the adjoint to make an image, the relative magnitudes are identical, but the image is scaled.

The reason this is important to me is that I am writing an optimization that matches the Fourier components of the data using a learned image with some regularization. However, my starting image is just the nfft_adjoint of the data. I would expect the difference between the Fourier components of this image and the data to be zero, but because of this scaling it is not.

It seems as though handing the adjoint to Zygote does indeed work; now the problem is this scaling.

@tknopp
Member

tknopp commented May 20, 2021

Yes, you need to take scalings into account, and it often makes sense to properly "normalize" an operator. To give you an example: here we do exactly that in order to make the FFT unitary, which it is not in its standard definition: https://github.com/tknopp/SparsityOperators.jl/blob/master/src/FFTOp.jl#L41

For the NFFT this is more complicated, as it actually involves what is usually called density compensation. Let A be the NFFT matrix. Then A^H A is in general not the identity matrix. What you want/need instead is A^H W A, where W is a diagonal matrix containing the squared density weights. It holds that A^H W A ≈ I if the weights are chosen appropriately. I then recommend that you use A^H W^{1/2} and W^{1/2} A as your transformation pair.

We actually also have a method in NFFT.jl to automatically calculate the density weights (called sdc I think).
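
A minimal sketch of how that normalization could look, using the operator notation from the later comments (p * x for the forward NFFT, p' for its adjoint); the exact plan_nfft/sdc signatures, and whether sdc returns the diagonal of W directly, are assumptions here:

using NFFT

nodes = rand(2, 1024) .- 0.5      # 2D sampling nodes in [-1/2, 1/2)
p = plan_nfft(nodes, (64, 64))    # plan: 64×64 image -> 1024 nonuniform samples
w = sdc(p)                        # density compensation weights (assumed to be diag(W))

fwd(x) = sqrt.(w) .* (p * x)      # W^{1/2} A
adj(y) = p' * (sqrt.(w) .* y)     # A^H W^{1/2}

img = rand(ComplexF64, 64, 64)
img_rec = adj(fwd(img))           # ≈ img, up to the approximation A^H W A ≈ I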

@kiranshila
Author

Oh perfect! Thank you! That makes sense.

@tknopp
Member

tknopp commented May 20, 2021

Please have a look at this article: https://downloads.hindawi.com/journals/ijbi/2007/024727.pdf It is about MRI reconstruction, but the initial formulas cover exactly what I described.

@kiranshila
Author

Yeah, it seems that MRI reconstruction is very similar to radio telescope imaging.

@tknopp
Member

tknopp commented Sep 4, 2022

Hi @roflmaostc,

I would like to make NFFT.jl AD friendly, but unfortunately I don't yet have a deep understanding of how to write custom chain rules. I have seen that you were involved in the chain rules for AbstractFFTs (JuliaDiff/ChainRules.jl#127). Could you help me with NFFT.jl?

If I understand it correctly, this could probably be done at the level of AbstractNFFTs, allowing all implementations to benefit from it.

@JeffFessler
Collaborator

I don't know about ChainRulesCore, but it is about 1 line of code using @adjoint in Flux / Zygote;
see this example: https://jefffessler.github.io/SPECTrecon.jl/stable/generated/examples/6-dl/#Custom-backpropagation

Now that I RTFM I see that they recommend using ChainRulesCore: https://fluxml.ai/Zygote.jl/dev/adjoints/

I usually wrap my NFFT calls in a LinearMap so I think I will look into a chain rule for that.
(I looked at LinearOperators.jl and didn't see a dependency on ChainRules there BTW.)
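
For what it's worth, the "one line" for a LinearMap-wrapped NFFT could look roughly like this; this is a sketch, not the SPECTrecon code itself, and it assumes such a definition does not clash with rules Zygote already provides:

using Zygote: @adjoint
using LinearMaps

# Pullback of y = A*x for a fixed linear map A: the cotangent of x is A'*Δ,
# and A itself is treated as non-differentiable.
@adjoint Base.:*(A::LinearMap, x::AbstractVector) =
    A * x, Δ -> (nothing, A' * Δ)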

@tknopp
Member

tknopp commented Sep 5, 2022

Yes, I read that manual as well, and with AbstractFFTs they did basically the same: JuliaMath/AbstractFFTs.jl#58

If I look at the LinearMaps example, it could look something like this:

using ChainRulesCore

function ChainRulesCore.rrule(::typeof(*), A::AbstractNFFTPlan, x::AbstractArray)
    y = A * x
    function pullback(dy)
        DY = unthunk(dy)
        # Because A is an abstract map, the product is only differentiable w.r.t. the input
        return NoTangent(), NoTangent(), @thunk(A' * DY)
    end
    return y, pullback
end

I am not sure if the same works for mul!.

I like that the LinearMaps PR has a test for the chain rules. Flux is a pretty heavy dependency, though; I am not sure whether it is really needed, or whether there are easier ways to test the rrule.
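
One lighter-weight option (a suggestion, not something discussed in the thread) would be ChainRulesTestUtils, which checks an rrule against finite differences without pulling in Flux; whether it handles the complex NFFT output and a plan marked as NoTangent out of the box is an assumption:

using ChainRulesTestUtils, NFFT
using ChainRulesCore: NoTangent

nodes = rand(2, 32) .- 0.5
p = plan_nfft(nodes, (8, 8))
x = rand(ComplexF64, 8, 8)

# Compare the rrule for y = p * x against finite differences;
# the plan itself carries no tangent.
test_rrule(*, p ⊢ NoTangent(), x)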

@migrosser
Collaborator

This issue should be fixed by commit 8552b28, where we implemented ChainRulesCore.frule and ChainRulesCore.rrule. Therefore, I will close it for now.
