
AD support via chainrules #177

Merged Sep 16, 2022 (9 commits)

Conversation

@mloubout (Contributor)

@mloubout mloubout commented Apr 6, 2022

Supports AD via ChainRules for multiplication; closes #176.

Currently, mul! does not support AD because I'm not sure how to define a perturbation on the LinearMap itself for the forward rule.

Added a test as well.
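For readers unfamiliar with ChainRules, the reverse rule this PR adds is, in spirit, the following. This is a hedged sketch with a plain stdlib matrix and an illustrative function name (`mul_rrule`); the actual implementation lives in `src/chainrules.jl` and returns ChainRulesCore tangent types such as `NoTangent()` where `nothing` appears below:

```julia
using LinearAlgebra

# Sketch of the reverse rule for y = A * x when the linear map A is
# treated as non-differentiable: the pullback of a cotangent dy is A' * dy.
function mul_rrule(A, x)
    y = A * x
    pullback(dy) = (nothing, A' * dy)  # (tangent w.r.t. A, tangent w.r.t. x)
    return y, pullback
end

# cumsum on length-3 vectors, in matrix form
A = LowerTriangular(ones(3, 3))
x = [1.0, 2.0, 3.0]
y, pb = mul_rrule(A, x)  # y ≈ cumsum(x)
```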

Project.toml (review thread, resolved)
@mloubout mloubout force-pushed the master branch 2 times, most recently from 3867248 to 4b90548 Compare April 6, 2022 15:36
codecov bot commented Apr 6, 2022

Codecov Report

Base: 99.58% // Head: 99.58% // Increases project coverage by +0.00% 🎉

Coverage data is based on head (4128bbf) compared to base (e19a7cc).
Patch coverage: 100.00% of modified lines in pull request are covered.

Additional details and impacted files
@@           Coverage Diff           @@
##           master     #177   +/-   ##
=======================================
  Coverage   99.58%   99.58%           
=======================================
  Files          17       18    +1     
  Lines        1448     1460   +12     
=======================================
+ Hits         1442     1454   +12     
  Misses          6        6           
Impacted Files Coverage Δ
src/LinearMaps.jl 100.00% <ø> (ø)
src/chainrules.jl 100.00% <100.00%> (ø)


@mloubout mloubout force-pushed the master branch 2 times, most recently from f301a22 to 3e249e4 Compare April 6, 2022 16:54
@dkarrasch (Member) left a comment

Thanks for the clear contribution. I just have a few style remarks. As this is unfamiliar ground for me, is there any chance to have special differentiation rules for FunctionMaps that would "expose" the code to Flux? I have no idea what people need or use, but I imagine that would be awesome. Our usual toy example is cumsum, which in terms of a matrix is just LowerTriangular(ones(m,m)). That doesn't have parameters in it so the derivative should be zero. But one could imagine a "weighted cumulative sum" with some weight vector input. Since for any weight we obtain a linear map, taking derivatives shouldn't lead us out of the linear map realms, naively speaking? Or the other way around, forgetting about LinearMaps.jl and just considering "weighted cumsum" as a piece of code, that should be accessible to AD, shouldn't it? Anyway, I'm just thinking out loud without any concrete idea.
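The toy example above can be made concrete. The weighted variant below is hypothetical (not part of LinearMaps.jl), just illustrating that for any fixed weight vector the "weighted cumsum" is still a linear map:

```julia
using LinearAlgebra

m = 4
x = [1.0, 2.0, 3.0, 4.0]

# cumsum acts like multiplication by LowerTriangular(ones(m, m))
@assert cumsum(x) ≈ LowerTriangular(ones(m, m)) * x

# hypothetical "weighted cumsum": linear in x for any fixed weight w,
# with matrix representation LowerTriangular(ones) * Diagonal(w)
w = [0.5, 1.0, 2.0, 4.0]
wcumsum(v) = cumsum(w .* v)
@assert wcumsum(x) ≈ LowerTriangular(ones(m, m)) * Diagonal(w) * x
```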

src/LinearMaps.jl (review thread, resolved)
src/chainrules.jl (review thread, resolved)
@mloubout (Contributor, Author) commented Apr 7, 2022

Thanks for the quick review.

Concerning FunctionMaps, this is not really trivial. For some special cases you can write a specific rule; for example, there are rules defined for Diagonal in ChainRules.jl, but that is done on a case-by-case basis and cannot really be generalized to an anonymous function.
So conceptually, yes, that would be great, but it's a lot more work than this PR.

@Jutho (Collaborator) commented Apr 8, 2022

What are the problems with AD not handling this (not sure what "this" is) automatically?

I am currently experimenting with AD in the context of KrylovKit.jl, for the solution of linear problems or eigenvalue problems. In KrylovKit, the linear operator is generically specified as a function (not a LinearMap). It seems possible to define the rules even for the case where these are functions with internal parameters, so assuming A to be non-differentiable is not even necessary. But there the hard part is the rule for the linear problem/eigenvalue problem itself; the rule for the matrix vector product should be easy enough that AD can handle it by itself, no?

@dkarrasch (Member)

Thanks for chiming in, @Jutho. And indeed, the provided tests almost work, they fail with

Can't differentiate gc_preserve_end expression
  Stacktrace:
    [1] error(s::String)
      @ Base ./error.jl:35
    [2] Pullback
      @ ~/Documents/julia/usr/share/julia/stdlib/v1.9/LinearAlgebra/src/blas.jl:644 [inlined]

from within BLAS. But with A = LinearMap(cumsum, reverse∘cumsum∘reverse, 10) instead, all tests just pass even without this PR. 🎉 OTOH, I quickly tested with a BlockMap, and then it fails again with

Mutating arrays is not supported -- called copyto!(::SubArray{Float64, 1, Vector{Float64}, Tuple{UnitRange{Int64}}, true}, _...)

So maybe this PR is necessary to cover up internal "issues"?

@mloubout (Contributor, Author)

There are a lot of subtleties to it, yes.

the linear operator is generically specified as a function

In this case, most of the rules for standard functions are already defined in ChainRules.jl, so AD will automatically know what to do. The issue here (IMO, and it can be discussed) is that the function is represented by a matrix. If you look at the rule definitions, most rules look like

function rrule(func, args...)
    y = func(args...)
    pullback(dy) = NoTangent(), args_derivs...
    return y, pullback
end

That first NoTangent() is there because the function itself is not differentiable (i.e., there is no diff(func)). With LinearMap, you are replacing func by an abstract matrix. In some cases, yes, that matrix can be explicit and a derivative can be defined, but conceptually it is not differentiable, as it is a representation of a function.

For the example A = LinearMap(cumsum, reverse∘cumsum∘reverse, 10), AD will be smart enough to only follow the function call and ignore the Map, but in the generic case some internals of the Map itself may complicate things (such as mutating arrays, which are usually a problem for AD).

@mloubout (Contributor, Author)

(sorry for the second message)

Finally, defining the rule internally that d(Fx)/dx = F' (since we know it is a linear operator with a defined adjoint) may lead to more straightforward AD. For example, reverse∘cumsum∘reverse would otherwise be treated as three different functions, which may degrade performance (compute and memory).
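To make the point concrete, here is a stdlib-only sketch (not the PR's code): AD through the composition would chain three pullbacks, whereas the single rule d(Fx)/dx = F' applies the adjoint operator once:

```julia
using LinearAlgebra

# F = reverse∘cumsum∘reverse as a function and as a matrix
F(x) = reverse(cumsum(reverse(x)))
n = 5
Fmat = UpperTriangular(ones(n, n))  # matrix representation of F

x  = randn(n)
dy = randn(n)

@assert F(x) ≈ Fmat * x
# single-rule pullback: one adjoint application; F' here is just cumsum
# (the lower-triangular ones matrix)
@assert Fmat' * dy ≈ cumsum(dy)
```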

@dkarrasch (Member) left a comment

Since this is a new feature, please bump the minor version, and please write a short "news" entry in the documentation (history.md), then we're ready to go.

Project.toml (review thread, resolved)
@dkarrasch (Member)

For reasons of performance and functionality (handling of inplace operations etc.), this seems like a necessary step. Shall we merge this? @Jutho Were you afraid that this PR may prevent getting derivatives of the linear operator working later? Perhaps, if we have concrete use cases, we may add more specific rules, like directing dA/dp where A is the linear map and p is some set of parameters to the wrapped map?

@mloubout (Contributor, Author)

Sorry for the delay, I am on parental leave will get back to it asap.

IMO there are two main cases:

  • A standard linear function (e.g., cumsum). In that case the function is fixed, so its linear representation is as well, and dA/dx doesn't exist.
  • A parametric function. In that case it may not be linear anymore, depending on the parameterization, so it would probably be handled on a case-by-case basis, with a specific rule for a given parametric map.

Happy to iterate on the PR to get it to its most usable state.

@Jutho (Collaborator) commented May 18, 2022

As I am no longer actively using this package, I don't think my opinion should be valued highly.

That being said, I would indeed like to see use cases where this rule really helps, where there is no custom rule higher up that states that you need to do A' * x, and where you are not interested in varying parameters in A. For example, the case I am most familiar with where LinearMaps are used is in Krylov methods. However, I think it would be a bad idea to AD through Krylov methods, i.e. you will want a custom rrule for the Krylov method itself, which will be calling the action of A' . But there are likely many use cases that I am not at all familiar with.

The assumption that additional parameters by themselves make a function nonlinear in x is not really a good argument. It would be rather evil to insert a nonlinear function into a LinearMap, and the consequences of that would be the user's own responsibility.

@mloubout (Contributor, Author)

Sorry for the delay; I will get to it next week.

src/chainrules.jl (review threads, resolved)
@JeffFessler (Member)

I'd like to add my general voice of support for building in support for autodiff here.
I believe ChainRulesCore is indeed the appropriate mechanism, but am not completely sure about that because I've never tried it myself.
My group has a LinearMapAO for SPECT that is pretty efficient by itself, but autodiff of A*x was very slow until we wrote a custom @adjoint using Flux
https://github.com/JeffFessler/SPECTrecon.jl/blob/7ae2b750fafdd8baa4a7ad0d5399cbe9f62e31ee/docs/lit/examples/6-dl.jl#L167
It would be convenient to have this capability built in so that we don't need to repeat that for every new LinearMap.

src/chainrules.jl (review thread, resolved)
@JeffFessler (Member)

Bump? I think there will be increasing interest in automating AD for linear maps:
JuliaMath/NFFT.jl#50

@tknopp (Contributor) commented Sep 6, 2022

Out of curiosity: Is there a reason to use Flux in the tests instead of Zygote? If I see it right, only the gradient function from Zygote is used. (There is also https://juliadiff.org/ChainRulesTestUtils.jl/stable/)

@mloubout (Contributor, Author) commented Sep 6, 2022

I am quite late on my TODOs, I'm really sorry this is taking a while.

Is there a reason to use Flux in the tests instead of Zygote?

Just the habit, no particular reason and I'm fine switching to Zygote.

This will work only for maps that have an adjoint.

Yes, that is the assumption for a linear operator when used as part of "backpropagated" code. If the adjoint is not defined, an error should occur, but I am not sure what the cleanest way to raise it is, to avoid having it mixed into the middle of Zygote error messages.
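One option (a sketch, not what the PR implements; `checked_adjoint_mul` is a hypothetical helper name) is to check for an adjoint method up front and raise a readable error before the AD machinery gets involved:

```julia
using LinearAlgebra

# Hypothetical helper: fail early with a clear message if adjoint(A) has no
# method, rather than letting a MethodError surface deep inside the AD stack.
function checked_adjoint_mul(A, dy)
    hasmethod(adjoint, Tuple{typeof(A)}) ||
        error("pullback for A * x requires adjoint(A) to be defined")
    return adjoint(A) * dy
end

checked_adjoint_mul([1.0 0.0; 1.0 1.0], [1.0, 2.0])  # applies A' to dy
```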

@dkarrasch (Member)

IIUC, then we technically could add differentiation w.r.t. the operator later, right? We would need to replace the NoTangent() output by something. So, we wouldn't break anything, just add a new feature. I guess one would need to have something like the update_coefficients function from the SciML ecosystem, but somebody is already re-implementing LinearMaps.jl with caches and update functions over there, so I'm not sure anybody will ever work on it here.

@mloubout I think all is set up right now, so maybe you could write a little announcement in the documentation, perhaps with a little "two-liner" as an example. You could, for instance, use one of your earlier test examples that I have replaced with a plain test_rrule to avoid depending on heavy packages, and also only minimally test what we provide here and not depend on possible breakage downstream.

@mloubout (Contributor, Author) commented Sep 6, 2022

could add differentiation w.r.t. the operator later, right?

Yes. I think it would go the Tangent{LinearMap}(; ...) way, which would follow the design of ChainRules (with the update rule defined accordingly). It would require some more careful case-by-case design, I think, but could maybe be generalized. But as you said, the current state allows easy integration of "static" operators within ML code, which is a step in the right direction IMO.

I'll work on the tiny doc update, hopefully by the end of the week at the latest. Thanks for the patience.

src/LinearMaps.jl (review thread, resolved)
@JeffFessler (Member)

somebody is already re-implementing LinearMaps.jl
sigh ☹️
Otherwise, LGTM other than one small code suggestion to aid future readers.

docs/src/history.md (review thread, resolved)
@mloubout (Contributor, Author)

Thank you for the patience, I think all comments are answered now.

@dkarrasch dkarrasch merged commit bd10c6c into JuliaLinearAlgebra:master Sep 16, 2022