Here I'm going to track the performance issues with Yota.
Starter code
```julia
using Yota
using Yota.Umlaut
using Metalhead
using Profile
using ProfileView

loss(model, image) = sum(model(image))

function main()
    model = Metalhead.ResNet(18)
    image = rand(Float32, 224, 224, 3, 1)
    @time model(image)
    @time trace(model, image; ctx=GradCtx())
    @profile grad(loss, model, image)
    Profile.print(mincount=100)
end
```
Currently, computing the gradient of ResNet takes forever (or at least > 30 minutes). Function execution takes ~10 seconds and tracing it ~60 seconds, so most of the time is spent in `grad()`.
My current interpretation is as follows:

1. During the initial tracing, Yota simply writes the higher-order functions (the ones dominating the profile) to the tape, which is relatively fast.
2. These functions are then rewritten into the corresponding `rrule`s.
3. The `rrule`s invoke `rrule_via_ad`, which in turn triggers tracing of the argument function (see the sketch below). So the bottleneck is indeed tracing, though not the initial tracer pass.
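To make step 3 concrete, here is a minimal sketch of how an `rrule` for a higher-order function typically hands control back to the calling AD system via `rrule_via_ad`. The function `sum_of` is made up purely for illustration (it is not part of Yota or ChainRules); only the `ChainRulesCore` API calls are real:

```julia
using ChainRulesCore

# Hypothetical higher-order function, for illustration only.
sum_of(f, xs) = sum(f, xs)

function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(sum_of), f, xs)
    # rrule_via_ad asks the calling AD system (here: Yota) to differentiate `f`.
    # For Yota this means tracing and compiling `f` all over again.
    results = [rrule_via_ad(cfg, f, x) for x in xs]
    y = sum(first, results)
    function sum_of_pullback(ȳ)
        # Each saved pullback returns (f̄, x̄); we keep only x̄ for simplicity.
        x̄s = [last(pb(ȳ)) for (_, pb) in results]
        return NoTangent(), NoTangent(), x̄s
    end
    return y, sum_of_pullback
end
```

This is exactly the pattern that makes higher-order functions expensive: every such rule re-enters the tracer once per inner function.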
Tracing is slow because Julia runs type inference on and specializes each call (a lot of the slowness comes from `mkcall()`).
Simply wrapping the tape and the tracing code in `@nospecialize` helped a bit in tests, but it is definitely not a game changer (a toy sketch of the idea follows).
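For reference, the `@nospecialize` experiment was roughly of the following shape. This is a self-contained toy; `record_call!` is a made-up name and not the actual Umlaut helper (`mkcall()` and friends look different), it only shows the technique:

```julia
# Toy tape-recording helper: the bare @nospecialize at the top of the body
# hints the compiler not to specialize this function on the concrete types
# of `fn` and `args`, avoiding one fresh compilation per recorded call.
function record_call!(tape::Vector{Any}, fn, args...)
    @nospecialize
    push!(tape, (fn, args))
    return length(tape)    # index of the newly recorded entry
end
```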
In the original design, Yota wasn't supposed to trigger compilation during backpropagation at all. In fact, the design was very similar to JAX, with the only exception that we used IR-level tracing instead of operator overloading due to issues with multiple dispatch. Just to recap, the first versions of Yota worked like this (see the toy example after the list):
1. Trace a function to turn it into a list of primitives for which we know the differentiation rules (forward pass).
2. Record these differentiation rules to the same tape (backward pass).
3. Compile the tape.
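Here is a tiny runnable toy of that pipeline for `y = exp(sin(x))`. All names (`deriv`, `compile_tape`, the hand-written tapes) are made up for illustration and have nothing to do with Yota's actual API; the point is only that the backward pass appends already-known rules and everything is compiled once:

```julia
# Known derivative rules for two unary primitives.
deriv(::typeof(sin)) = cos
deriv(::typeof(exp)) = exp

# Forward tape for y = exp(sin(x)), as a tracer would have recorded it.
forward_tape = [sin, exp]

# Backward pass: append the known rules in reverse order; no extra tracing.
backward_tape = [deriv(f) for f in reverse(forward_tape)]

# "Compile" the tape into a single closure computing y and dy/dx.
function compile_tape(fwd, bwd)
    return function (x)
        vals   = accumulate((v, f) -> f(v), fwd; init=x)    # forward intermediate values
        y      = vals[end]
        inputs = reverse([x; vals[1:end-1]])                # input of each op, reversed
        dydx   = prod(g(v) for (g, v) in zip(bwd, inputs))  # chain rule as a product
        return y, dydx
    end
end

compiled = compile_tape(forward_tape, backward_tape)
compiled(0.5)   # (exp(sin(0.5)), exp(sin(0.5)) * cos(0.5))
```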
So: exactly one tracing pass and one compilation. However, the prevalence of ChainRules changed the game. The current design looks like this (a similar toy sketch follows the list):
1. Trace a function.
2. Replace all primitive calls `y = f(xs...)` with `y, pb = rrule(f, xs...)`. Save the pullbacks for the next step.
3. Record the pullback invocations.
4. Compile the tape.
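And here is the same `y = exp(sin(x))` toy under the current design, this time using the real ChainRules `rrule` API (the rewrite is done by hand here; in Yota it happens on the tape):

```julia
using ChainRules   # provides rrule methods for sin, exp, etc.

x = 0.5

# Step 2: each primitive call y = f(xs...) becomes y, pb = rrule(f, xs...).
t, pb_sin = rrule(sin, x)   # was: t = sin(x)
y, pb_exp = rrule(exp, t)   # was: y = exp(t)

# Step 3: invoke the saved pullbacks in reverse order.
ȳ = 1.0
_, t̄ = pb_exp(ȳ)
_, x̄ = pb_sin(t̄)
# x̄ ≈ exp(sin(x)) * cos(x)
```

For simple scalar rules like these nothing bad happens; the trouble starts when the primitive is a higher-order function, because its `rrule` calls `rrule_via_ad` (as in the sketch above), and that call is entirely outside Yota's control.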
Now Yota has no control over what happens inside an `rrule` and, in the case of higher-order functions, cannot avoid additional tracing and compilation. Since Flux uses higher-order functions extensively, we get what we get.
As far as I can see, the only way forward is to speed up tracing. However, it requires a really, really good understanding of the Julia compiler, which I don't have. To my knowledge, the only autodiff package that managed to do it is Diffractor, and not too many people understand how it works.
So I'm pretty much puzzled.
Thanks to all the commenters in this thread, tracing now works ~2x faster. I also fixed a performance bug in `todo_list()`, and now the whole `grad(loss, model, image)` compiles and runs in 61 seconds, which is reasonably good.