Hybrid overloading + SCT; not pure SCT? #24
Comments
Feel free to close this if the hybrid approach is indeed intentional, and the long term plan. |
@YingboMa the Cassette PR doesn't address Lyndon's question. It would be helpful to understand whether the long-term goal is to take the kind of SCT approach that Lyndon discusses above, or to continue using various dual types. |
That PR makes ForwardDiff2 on |
Sure. So is the plan to have a single |
|
Pure source to source is also prone to missing rule errors, so I don't think that is a good idea for the forward mode. |
Well, yes, kind of, but it only makes sense with subtypes of
Is it? Could you provide an example please? |
With methods, code_typed One can solve this problem with a dependency analysis in source to source AD, but that is just doing more work. Won't the Cassette PR make FD2 as powerful as what pure source to source could be? |
They could potentially become the same type, yes. What we do is a hybrid, but we'd love to move to use Tagging in Cassette which would make it source-to-source. But tagging is pretty unusable at the moment, so it will need more work in the Cassette and compiler side. Cassette's tagging does the same thing as DualArrays for arrays.
Almost. We still won't support structs that have fields restricted to subtypes of
this problem will show up from time to time when Dual tries to fit into the type hierarchy. |
I see. There's only a finite number of such functions though, and you can just define Moreover, you should generally be able to use Also, assuming that we do wind up in a world where
struct Foo
    a::Float64
end
(foo::Foo)(x) = 5x * foo.a
foo = Dual(Foo(5.0), Composite{Foo}(;a=1.0))
methods(foo)
An instance of
Understood. Yeah, the current state of the Cassette's argument-local metadata propagation makes me sad.
Agreed. But I assume that this problem will go away once you've got a generic |
It's super nice when you don't need performance though :) |
Note the source transform I propose in the top post doesn't need Tagging. I think hybrid SCT + overloading has some advantages: in particular, it's less likely to hit compiler edge cases, and one has to write fewer really annoying IR transforms. |
One thing about source to source is that it supports mutation basically as a side effect of everything you have to do anyway *. But you do run into issues relating to overloading (in either direction). Pure source code transformation does not use overloading, so it doesn't run into these. @AntoineLevitt posted this example which currently fails
In this case making it work for overloading (and arguably making it more "correct" Julia code) is just a matter of changing line 2 to be But let's think what its source transform looks like: We probably do want a
and we want a frule for
so now lets look at the whole program: repeating it again:
The compiler will optimize out the checks to
Which is pretty slick. Things needed from the description language (ChainRulesCore)
|
That's just the easy case though, and FD2 can fix that. But if the buffers come from the user and the user cannot define the internal weird type that FD2 wants, then it won't be fixed. The real fix is #8 . It's the same reason why constant global mutable buffers work in ForwardDiff, and don't work in Tracker/Zygote. Your solution is cute for the easy case that is still actually allocating, but doesn't work for non-allocating optimized code. |
I can imagine that it doesn't work on that, sure.
Yes it does. AFAICT. Also that case is great because if it is not allocating I don't need to worry about constructors, so many of the issues relating to them are solved for me. Let's look at:
I am going to write the transformed code in the abridged version assuming that
So just like before. So it all just works the same. We get only as much mutation again as we had for the primal problem
Not sure what that means. If the user wants to provide a preallocated array to contain the derivatives, they had better know what type of derivatives they are getting back. Just like for normal preallocation, but I guess a little harder as they have to know a bit about differentials. There is a weirder case where you have preallocated partials but not preallocated primal values. |
But that's incorrect because then you'd have to use a Union for ForwardDiff, but you don't. You can just make a DualCache that just needs to know the chunk size and can reinterpret. With a type based version this is fairly trivial to expose: you just say preallocations should be done with this type. If things like chunksize and such are never exposed, how do you make this helper function. What type does it output? |
@oxinabox let's make it concrete. https://tutorials.juliadiffeq.org/html/introduction/03-optimizing_diffeq_code.html
Case 1:
using LinearAlgebra, OrdinaryDiffEq, BenchmarkTools
# Preallocated global buffers; N, Ax, Ay, r0 and p are assumed to be defined as in the linked tutorial
Ayu = zeros(N,N)
uAx = zeros(N,N)
Du = zeros(N,N)
Ayv = zeros(N,N)
vAx = zeros(N,N)
Dv = zeros(N,N)
function gm3!(dr,r,p,t)
    a,α,ubar,β,D1,D2 = p
    u = @view r[:,:,1]
    v = @view r[:,:,2]
    du = @view dr[:,:,1]
    dv = @view dr[:,:,2]
    # In-place matrix products written into the global buffers
    mul!(Ayu,Ay,u)
    mul!(uAx,u,Ax)
    mul!(Ayv,Ay,v)
    mul!(vAx,v,Ax)
    @. Du = D1*(Ayu + uAx)
    @. Dv = D2*(Ayv + vAx)
    @. du = Du + a*u*u./v + ubar - α*u
    @. dv = Dv + a*u*u - β*v
end
prob = ODEProblem(gm3!,r0,(0.0,0.1),p)
@benchmark solve(prob,Tsit5())
Case 2:
# Same buffers, but passed in through the parameter tuple instead of as globals
p = (1.0,1.0,1.0,10.0,0.001,100.0,Ayu,uAx,Du,Ayv,vAx,Dv) # a,α,ubar,β,D1,D2
function gm4!(dr,r,p,t)
    a,α,ubar,β,D1,D2,Ayu,uAx,Du,Ayv,vAx,Dv = p
    u = @view r[:,:,1]
    v = @view r[:,:,2]
    du = @view dr[:,:,1]
    dv = @view dr[:,:,2]
    mul!(Ayu,Ay,u)
    mul!(uAx,u,Ax)
    mul!(Ayv,Ay,v)
    mul!(vAx,v,Ax)
    @. Du = D1*(Ayu + uAx)
    @. Dv = D2*(Ayv + vAx)
    @. du = Du + a*u*u./v + ubar - α*u
    @. dv = Dv + a*u*u - β*v
end
prob = ODEProblem(gm4!,r0,(0.0,0.1),p)
@benchmark solve(prob,Tsit5())
The standard way to handle this with ForwardDiff would be to just do
Ayu = dualcache(u, N=ForwardDiff.pickchunksize(length(u0)))
for all of the buffers, and then https://github.com/JuliaDiffEq/DiffEqBase.jl/blob/master/src/init.jl#L49-L61 How does your setup handle that? |
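For readers who have not met the dual cache idea, here is a minimal sketch of it: one flat `Float64` buffer that can be viewed either as plain floats or, via `reinterpret`, as dual numbers with a fixed chunk size. The names `ToyDual`, `ToyDualCache` and `view_as` are hypothetical and exist only for this illustration; this is not the DiffEqBase or ForwardDiff implementation.

```julia
# Toy stand-in for a ForwardDiff-style dual with N partials (hypothetical type).
struct ToyDual{N}
    value::Float64
    partials::NTuple{N,Float64}
end

# Hypothetical cache: a flat buffer large enough to hold n duals with N partials each.
struct ToyDualCache{N}
    buffer::Vector{Float64}
    n::Int
end
ToyDualCache(n::Int, ::Val{N}) where {N} = ToyDualCache{N}(zeros(n * (N + 1)), n)

# View the cache as n plain Float64 values.
view_as(c::ToyDualCache, ::Type{Float64}) = view(c.buffer, 1:c.n)

# View the same memory as n duals; no new allocation of the underlying storage.
view_as(c::ToyDualCache{N}, ::Type{ToyDual{N}}) where {N} = reinterpret(ToyDual{N}, c.buffer)

cache = ToyDualCache(4, Val(2))
plain = view_as(cache, Float64)      # used when solving without AD
duals = view_as(cache, ToyDual{2})   # used when differentiating with chunk size 2
duals[1] = ToyDual{2}(1.0, (2.0, 3.0))
```

The point of the design is that the only thing the user has to know when preallocating is the chunk size, which is exactly the piece of information the question above asks a pure SCT design to expose somehow.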
Oh, Dual Cache is cute, took me a while to get it. I think we can get this behaviour by defining a
Could even get rid of the From there everything just works I believe. (* Technically speaking one with
Those |
Yup exactly.
Yeah, that's cool, I guess you can do it through |
Concrete examples FTW. I am sure there is a lot of fiddliness in writing the pure SCT forward mode.
The harder bit is it is important to handle chunking. I think it has all the power we need and is more flexible to user code. I don't think it's super high priority, ForwardDiff2 is pretty great as is for many use cases. And we would need a ton of |
In general, it seems extremely unlikely to me that there's something hybrid/OO can do that pure SCT can't. Mutations, including global ones, caches etc. are all fine; you can also do things like adding rules to zero-argument functions, functions that take structs, etc. If there's scepticism on that still, we can always work through more specific cases. Until then you're going to have problems with duals leaking out (#31), things that won't compile (or worse, run incorrectly) due to perturbation confusion, errors on structs with type requirements or code that doesn't use |
I think it's best to have something that works first, then think about pure SCT. Right now we haven't had a pure SCT AD perform well or have a good feature set, which is a good reason to get something good in the next month and then consider something possibly better on a multi-year time frame. |
If you can get the current design working fast you can certainly get an SCT version to work fast too. I suspect it's significantly easier than putting effort to working around all those issues with Cassette tricks, if effort is going into that. |
As I understand it, Cassette Tagging, with global handling turned off, Also it's worth remembering: |
I think you're technically right that you could twist Cassette's metadata system to behaving like (forward mode) SCT AD with some effort, but this would look very different from using it as a tagging system (and would make a lot of the features that make it slow unnecessary). One way to look at the difference, which might be helpful, is that SCT AD works only with variables whereas OO AD (dual types or compiler-level tagging) works with values (much like static/dynamic types). A big difference is whether you can dynamically decide whether something is tagged or untagged. For example: y = x > 0 ? x : 0 In OO AD At first, having a variable for SCT AD is not strictly better than OO AD, especially in forward mode; it's a tradeoff (I leave the benefits of OO as an exercise to the reader). From my perspective it's mostly just important to be clear that these are not actually the same thing, which is a confusion that's come up on this thread and elsewhere a couple of times. |
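A minimal sketch of the value-versus-variable distinction for the `x > 0 ? x : 0` example. `MiniDual`, `clamp_pos` and `clamp_pos_forward` are hypothetical names made up for this illustration, and the hand-written "transformed" function only stands in for what an SCT tool might generate.

```julia
# Toy dual number (hypothetical; deliberately not a subtype of Number).
struct MiniDual
    value::Float64
    partial::Float64
end

# `x > 0` falls back to `0 < x`, so this one method is enough for the comparison.
Base.:<(a::Real, d::MiniDual) = a < d.value

clamp_pos(x) = x > 0 ? x : 0

# Overloading (OO) style: whether the result is tagged depends on the runtime value.
tagged   = clamp_pos(MiniDual(2.0, 1.0))    # a MiniDual
untagged = clamp_pos(MiniDual(-2.0, 1.0))   # the plain Int 0

# SCT style: every variable gets a tangent slot, decided when the code is transformed.
clamp_pos_forward(x::Float64, dx::Float64) = x > 0 ? (x, dx) : (0, zero(dx))

pos = clamp_pos_forward(2.0, 1.0)
neg = clamp_pos_forward(-2.0, 1.0)
```

In the overloading version the negative branch returns a value with no derivative information attached at all, while the transformed version always returns a primal and tangent pair.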
It's actually difficult for both. Then you need to know what operations are going to be done in the future. But for reverse mode it's theoretically possible, just very hard. |
That's true, though only if you choose to represent Zygote might end up getting a similar stricter mapping from primal to adjoint types for somewhat similar reasons. But it's a bigger sacrifice for Zygote because we might ideally want to promote based on future operations, whereas in forward mode only the primal type itself matters. |
@ChrisRackauckas assured me that ForwardDiff2 was a pure source code transformation based approach, rather than an overloading based approach.
My thoughts were that it was primarily an overloading based approach that uses limited source code transformation to make it possible to extend the overloads with `frule`s.
Now it's a blurry line between the two.
But the place I would cut it is that a pure source code transform approach never calls the original function passing in a special overloaded type. Not even as a fallback.
It always calls a function that it created.
This means it never has method errors, since it created the function it's going to call.
To describe the rewritten function:
I am going to write it without `Dual` types, but you can do it with `Dual` types; it just means allocating an extra slot in the transformed code and putting stuff into that slot.
We started with:
sq2(x::Float64) = x^2
The new code is:
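A minimal sketch of what such generated code could look like for `sq2`, written by hand. The name `do_forward_mode` follows the post; `toy_frule` is a hypothetical stand-in with an assumed signature (function, primal arguments, then tangents), not ChainRulesCore's actual `frule`.

```julia
# Toy rule table (hypothetical signature): returns (primal result, tangent of result).
toy_frule(::typeof(^), x::Float64, n::Int, dx::Float64) = (x^n, n * x^(n - 1) * dx)

sq2(x::Float64) = x^2

# Hand-written version of what the transform might emit for sq2.
# It never calls sq2 itself and never passes a special overloaded number type;
# it only calls the rule for the operation found in sq2's body.
function do_forward_mode(::typeof(sq2), x::Float64, dx::Float64)
    y, dy = toy_frule(^, x, 2, dx)
    return y, dy
end

y, dy = do_forward_mode(sq2, 3.0, 1.0)
```

Because the argument annotation `x::Float64` is only ever matched against a real `Float64`, the type restriction on `sq2` cannot cause a method error here.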
Doing a more complicated one:
So one can see that this kind of transform is complete.
We never call the original function.
We never perform any overloaded operations.
We only call `frule` and our transform generating function: `do_forward_mode`.
I believe ForwardDiff2 is meant to be using this approach, and that bugs have slipped in and caused it not to.
To make such bugs harder to slip through tests, I suggest that `Dual` stop subtyping `Number` and stop defining operations like `+` and `*`.
That way we know that it is correctly doing the rewrites.
Alternatively, this may be intentional.
The hybrid overloading + source code transform ForwardDiff2 is using right now works pretty well if everything is either a `Number` or an `Array`, but we might need another forward mode AD package that is pure source to source to easily handle concretely typed structs, and code that is otherwise hostile to the overloading approach.
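A minimal sketch of why concretely typed structs are hostile to overloading, and how keeping primal and tangent separate sidesteps it. `MiniDual` and `call_forward` are hypothetical names for this illustration, and the `NamedTuple` tangent is only an assumed stand-in for a structural differential.

```julia
# Toy dual number (hypothetical).
struct MiniDual
    value::Float64
    partial::Float64
end

struct Foo
    a::Float64   # concretely typed field
end
(foo::Foo)(x) = 5x * foo.a

# Overloading style: a dual cannot be stored in the Float64 field.
dual_fits = try
    Foo(MiniDual(5.0, 1.0))
    true
catch err
    err isa MethodError ? false : rethrow()
end

# SCT style: the primal Foo stays a Foo; its tangent travels next to it as a
# NamedTuple with the same field names (an assumption made for this sketch).
function call_forward(foo::Foo, dfoo::NamedTuple{(:a,)}, x::Float64, dx::Float64)
    y  = 5 * x * foo.a
    dy = 5 * dx * foo.a + 5 * x * dfoo.a   # product rule over x and foo.a
    return y, dy
end

y, dy = call_forward(Foo(5.0), (a = 1.0,), 2.0, 0.0)
```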