always unthunk results #79
Conversation
This is intended to only affect the outermost level, the thing returned to the user. Should the same be done to gradient?
This should be done to gradient. Diffractor would actually prefer that thunks didn't exist in the first place, so this is just a fallback to prevent the user from seeing them. Our ideal scenario would be a way to get derivatives from ChainRules that don't have thunks in the first place, but that is a lot harder than this PR.
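For what it's worth, a minimal sketch of the idea (not Diffractor's actual code; the helper name unthunk_results is made up for illustration): force any top-level thunks just before handing the result back, relying on ChainRulesCore.unthunk being the identity on anything that isn't a thunk.

```julia
using ChainRulesCore

# Hypothetical helper illustrating "always unthunk results": force thunks at the
# outermost level only, just before returning to the user.
unthunk_results(res) = unthunk(res)
unthunk_results(res::Tuple) = map(unthunk, res)

t = @thunk 2.0f0
unthunk_results((t,))   # (2.0f0,) rather than (Thunk(...),)
```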
Some way of not computing things which will be discarded does seem desirable. With the example of JuliaDiff/ChainRulesCore.jl#558 (a rough reconstruction of the toy setup is sketched after the Yota/Zygote/ReverseDiff comparison below):

```julia
julia> Diffractor.PrimeDerivativeBack(x -> f(x, b) + 10f(2x, b))(a)
rrule is called
rrule is called
∇a is called
∇a is called
42.0f0  # with this PR, seems ideal? ∇b never run despite accumulation

julia> gradient(x -> f(x, b), a)  # should unthunk answer, but not run ∇b
rrule is called
(Thunk(var"#28#31"{Float32, Float32, Float32}(1.0f0, 1.0f0, 2.0f0)),)
```

Surely a sufficiently smart compiler could notice and eliminate the unused computation?
The goal for Diffractor is to use escape analysis to remove the computation entirely, which is made easier with simpler types. We aren't there yet, but once stage 2 is integrated, we'll be close-ish.
This would be great. If it works, it's possible all thunks could be stripped out of ChainRules, since (IIRC) nothing else uses them anyway? Although, besides delayed calculation, they are intended one day to save memory too... xref #69 I guess.
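As a quick illustration of the delayed-calculation point (plain ChainRulesCore, nothing specific to this PR): the body of a @thunk doesn't run until something calls unthunk on it, which is exactly why a derivative that ends up discarded need never be computed.

```julia
using ChainRulesCore

t = @thunk begin
    println("computing a gradient piece")   # not printed until the thunk is forced
    sum(abs2, rand(Float32, 3))
end

t isa ChainRulesCore.AbstractThunk   # true; nothing has run yet
unthunk(t)                           # now the body executes and the value comes back
```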
As I understand it, ReverseDiff.jl likes thunks, but I'm not sure.
Yes, I suppose, although you have to opt in. More directly, I forgot that Yota doesn't un-thunk internally, only the final result:

```julia
julia> using Yota

julia> grad(x -> f(x, b), a)  # unthunks final result
rrule is called
∇a is called
(2.0f0, (ZeroTangent(), 2.0f0))

julia> using Zygote

julia> Zygote.gradient(x -> f(x, b), a)  # unthunks all
rrule is called
∇a is called
∇b is called
(2.0f0,)

julia> using ReverseDiff

julia> ReverseDiff.@grad_from_chainrules f(x::TrackedReal, y::Real)

julia> ReverseDiff.gradient(x -> f(x[1], b), [a])  # maybe this relies on thunks inside?
rrule is called
∇a is called
1-element Vector{Float32}:
 2.0
```
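For reference, the REPL snippets in this thread picture something like the toy definitions below. This is a rough reconstruction (the actual example in JuliaDiff/ChainRulesCore.jl#558 may differ in detail); the point is that the printed lines show exactly which thunks get forced.

```julia
using ChainRulesCore

f(x, y) = x * y   # toy function whose rrule announces when each piece runs

function ChainRulesCore.rrule(::typeof(f), x, y)
    println("rrule is called")
    function f_pullback(dz)
        dx = @thunk begin
            println("∇a is called")   # derivative w.r.t. the first argument
            dz * y
        end
        dy = @thunk begin
            println("∇b is called")   # derivative w.r.t. the second argument
            dz * x
        end
        return (NoTangent(), dx, dy)
    end
    return f(x, y), f_pullback
end

a, b = 1.0f0, 2.0f0
```

With these definitions, whether "∇b is called" appears is a direct check of whether a tangent that will be discarded was actually computed.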
Codecov Report

```diff
@@            Coverage Diff             @@
##             main      #79      +/-   ##
==========================================
- Coverage   52.62%   51.41%    -1.21%
==========================================
  Files          21       21
  Lines        2172     2118       -54
==========================================
- Hits         1143     1089       -54
  Misses       1029     1029
```

Continue to review full report at Codecov.
I think this is good to merge. (Jacobians have been extricated.)
Can the unthunk in this branch apply to gradient too, and perhaps within Tangent? Examples are above.
I think this should be merged as is. I haven't found a great way to
Seems ok by me. Would be nice if tests passed, though...
For disabling thunks, xref JuliaDiff/ChainRulesCore.jl#568 (inspired by difficulties with Zygote over Zygote).
Master has the same errors now, on Julia nightly:
In theory this shouldn't be necessary, but it is a good fallback to make sure we don't return thunks to the user. Also, it's zero overhead as long as the preceding calculation is correctly inferred.
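To spell out the zero-overhead claim (an aside, not part of this diff): unthunk has a trivial fallback for non-thunks, so once inference proves the value reaching the fallback is already a plain number or array, the extra call is the identity and compiles away; a cost only shows up when inference fails and the call has to dispatch at runtime.

```julia
using ChainRulesCore, InteractiveUtils

t = @thunk 1.0f0

# unthunk(::AbstractThunk) forces the stored computation; for anything else it
# simply returns its argument, which is what makes the fallback free when the
# preceding calculation is well inferred.
@code_typed unthunk(1.0f0)   # trivial body, just returns the argument
@code_typed unthunk(t)       # calls through to the stored closure
```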