-
Notifications
You must be signed in to change notification settings - Fork 89
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Broadcasting #644
Broadcasting #644
Conversation
src/rulesets/Base/broadcast.jl
Outdated
# and we don't know whether re-computing `y` is cheap. | ||
# (We could check `f` first like `sum(f, x)` does, but checking whether `g` needs `y` is tricky.) | ||
|
||
function rrule(cfg::RCR, ::typeof(broadcasted), f::F, args::Vararg{Any,N}) where {F,N} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This rule is applied before Zygote's own rule. To avoid that, it could be changed to this:
function rrule(cfg::RCR, ::typeof(broadcasted), f::F, args::Vararg{Any,N}) where {F,N} | |
function rrule(cfg::RCR, ::typeof(broadcasted), ::BroadcastStyle, f::F, args...) where {F} |
This has two problems.
The first is that Yota no longer sees the rule at all. Trying to make it a specific rule which inserts the style is tricky, as this creates method ambiguities:
function ChainRulesCore.rrule(cfg::YotaRuleConfig, ::typeof(Broadcast.broadcasted), f::F, args::Vararg{Any,N}) where {F,N}
bcs = Broadcast.BroadcastStyle(typeof(Broadcast.broadcasted(f, args...)))
y, back = ChainRulesCore.rrule(cfg, Broadcast.broadcasted, bcs, f, args...) # this one is ambiguous
...
@dfdx is there a better way?
The second problem is that this (slightly later) method has its arguments wrapped by broadcastable
, which means many things need Ref
:
julia> using Base.Broadcast: broadcasted, DefaultArrayStyle, broadcastable
julia> copy(broadcasted(|>, [1,2,3], inv))
3-element Vector{Float64}:
1.0
0.5
0.3333333333333333
julia> copy(broadcasted(DefaultArrayStyle{1}(), |>, [1,2,3], Ref(inv))) # error without Ref
3-element Vector{Float64}:
1.0
0.5
0.3333333333333333
julia> broadcastable(inv)
Base.RefValue{typeof(inv)}(inv)
Which at the moment gives test errors like Tangent{Base.RefValue{typeof(sin)}}(x = NoTangent(),) isa NoTangent
, lots of messy Tangents.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps better is for Zygote to be told not to use these rules, via @opt_out
. Old versions of Zygote won't have that, and would perhaps need to get an upper version bound on ChainRules.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should be able to do a configured opt-out for ZygoteConfig
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Which at the moment gives test errors like Tangent{Base.RefValue{typeof(sin)}}(x = NoTangent(),) isa NoTangent, lots of messy Tangents.
Can we add a simplify
call at the end which converts all(isa(NoTangent), x::Tangent)
into NoTangent()
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Have not yet tracked these down. It's possible some projection is missing, or does not simplify things which it should.
Edit: this is now JuliaDiff/ChainRulesCore.jl#565 , but not really needed here if we don't use broadcasted(::BroadcastStyle, ...
methods.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
BTW, why does @opt_out
define no_rrule
to return nothing
, which it already returns? That means you have to check the method table, but it seems you could equally well return true
or NoTangent
or something, which the caller can easily test. Can we fix that?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am going to need to refresh my memory on how and why @opt_out
works the way it does. 😅
I think that possibly rrule_via_ad takes a shortcut which does not check no_rrule.
There is indeed a shortcut that doesn't check no_rule
When it goes through the path of Zygote using Zygote because these is no rule then it hits the normal zygote check.
But there is a fast path that avoids calling back in to zygote if there is a rule right there.
https://github.com/FluxML/Zygote.jl/blob/5c80f55b24d4060f5550b25201b4b288703d34b9/src/compiler/chainrules.jl#L243-L245
But that path does check if rrule
returns nothing
.
Which it should as, the @opt_out
defines a more specific rrule
that returns nothing
for this case.
So there should be no need to check no_rrule
.
But that doesn't hold if you bypass @opt_out
, which is why the docs are specific about not being able to bypass @opt_out
as you can't know if user is using the method table way, or the directly checking rrule
way.
BTW, why does
@opt_out
defineno_rrule
to return nothing, which it already returns?
It doesn't return anything because it is just using the method table for storage.
But the user shouldn't need to check no_rrule
because checking that rrule
returns nothing
rather than a result should be enough.
Using the macro in the obvious way to opt out of the generic rule (by matching its signature) produces ambiguities with every more specific rule
Aren't these real ambiguities?
Can you show some examples?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This PR defines one generic, and many specific rules:
rrule(::RuleConfig{>:HasReverseMode}, ::typeof(broadcasted), f::F, args::Vararg{Any,N}) # generic
rrule(::RuleConfig{>:HasReverseMode}, ::typeof(broadcasted), ::typeof(+), xs::Union{Number, Broadcasted, AbstractArray{<:Number}, Tuple{Vararg{Number, N}}}...) # specific
Attempting to opt out of the generic rule for Zygote, the macro produces this:
rrule(::Zygote.ZygoteRuleConfig, ::typeof(broadcasted), f::F, args::Vararg{Any, N}) where {F, N}
which is more specific on the 1st but less specific on the 3rd argument, sometimes.
The hope was to opt out of only the generic rule, since the rules for +
etc. are similar enough not to matter. But the generic one is quite different, and e.g. skips Zygote's present CuArray handling... I'd rather not touch that yet.
Such an opt-out could be done if you are allowed to define only no_rrule
(and the caller expected to check that). Note that the only caller here is Zygote, but still.
Maybe the Zygote PR does have to opt out of every rule here, I think that would remove ambiguity. The tricky question is what happens to 3rd party packages defining broadcasting rules. NNlib has
rrule(::typeof(broadcasted), ::typeof(relu), x::Numeric)
which won't in fact cause an ambiguity. But it also won't ever be called, with this PR, since the ::RuleConfig
method will match first. It could be given a ::RuleConfig
argument (as the specific rules are in this PR) but then the ambiguity will return.
The alternative suggested above is to move the generic rule one of this PR one step later in the pipeline, something like:
rrule(::RuleConfig{>:HasReverseMode}, ::typeof(broadcasted), ::BroadcastStyle, f::F, args::Vararg{Any,N}) # generic
rrule(::typeof(broadcasted), ::typeof(+), xs::Union{Number, Broadcasted, AbstractArray{<:Number}, Tuple{Vararg{Number, N}}}...) # specific
The ugly feature of this is that the arguments here have already been acted on with broadcastable
, so there is a lot more dealing with Ref
required.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
7e7d105 changes over to use BroadcastStyle.
I now think this is the right thing to do. Besides Zygote, the generic rule needs RuleConfig{>:HasReverseMode}
while the specific rules do not. Naiively that would mean the specific rules never get called, so previously I gave them a RuleConfig argument only for this reason. But that's a bit weird, and e.g. NNlib's rules don't, https://github.com/FluxML/NNlib.jl/blob/master/src/activations.jl#L875-L880 .
Now, instead, when no specific rule matches, AD will continue & hit the rule with both RuleConfig and Broadcasted. (For some reason this didn't work with Yota, earlier, but surely fixable.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Zygote does appear to occasionally call some of these rules, in its tests. Perhaps because of rrule_via_ad
? But they aren't so different, and I think tests pass.
f6e27aa
to
f62a756
Compare
@@ -13,6 +13,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | |||
RealDot = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" | |||
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" | |||
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" | |||
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't forget to add compat for this. and up the version number
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
src/unzipped.jl
Outdated
""" | ||
unzip_map(f, args...) | ||
|
||
For a function `f` which returns a tuple, this is `== unzip(map(f, args...))`, | ||
but performed using `StructArrays` for efficiency. | ||
|
||
Not in use at present, but see `unzip_broadcast`. | ||
""" | ||
function unzip_map(f::F, args...) where {F} | ||
T = Broadcast.combine_eltypes(f, args) | ||
if isconcretetype(T) | ||
T <: Tuple || throw(ArgumentError("""unzip_map(f, args) only works on functions returning a tuple, | ||
but f = $(sprint(show, f)) returns type T = $T""")) | ||
end | ||
# if any(a -> a isa CuArray, args) | ||
# return unzip(map(f, args...)) | ||
# end | ||
return StructArrays.components(StructArray(Iterators.map(f, args...))) | ||
end | ||
|
||
function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(unzip_map), f::F, xs...) where {F} | ||
y, back = rrule_via_ad(cfg, map, f, xs...) | ||
z = unzip(y) | ||
function ununzip_map(dz) | ||
# dy = StructArray(map(unthunk, dz)) # fails for e.g. StructArray(([1,2,3], ZeroTangent())) | ||
dy = broadcast(tuple, map(unthunk, dz)...) | ||
return back(dy) | ||
end | ||
return z, ununzip_map | ||
end | ||
|
||
=# |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
delete this? and move it to an issue maybe?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. Will want something like it for map
etc. at some point. Somewhere I thought I had a better version anyway.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yay, merge when happy
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, will merge when tests pass.
Am sure there will be things to fix, but nothing relies on this just yet. Will revisit when I return to JuliaDiff/Diffractor.jl#73
@@ -13,6 +13,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | |||
RealDot = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" | |||
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" | |||
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" | |||
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
src/rulesets/Base/broadcast.jl
Outdated
# Fallback | ||
|
||
function unbroadcast(x, dx) | ||
@info "last unbroadcast method!" x dx | ||
dx isa AbstractZero && return dx | ||
p = ProjectTo(x) | ||
if p isa ProjectTo{<:AbstractZero} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No tests currently hit this fallback method. I'm removing it as I can't picture how to hit it: Now that the generic rule uses ::BroadcastStyle
, weird arguments are going to be wrapped in Ref
, hence hit another methog.
src/unzipped.jl
Outdated
""" | ||
unzip_map(f, args...) | ||
|
||
For a function `f` which returns a tuple, this is `== unzip(map(f, args...))`, | ||
but performed using `StructArrays` for efficiency. | ||
|
||
Not in use at present, but see `unzip_broadcast`. | ||
""" | ||
function unzip_map(f::F, args...) where {F} | ||
T = Broadcast.combine_eltypes(f, args) | ||
if isconcretetype(T) | ||
T <: Tuple || throw(ArgumentError("""unzip_map(f, args) only works on functions returning a tuple, | ||
but f = $(sprint(show, f)) returns type T = $T""")) | ||
end | ||
# if any(a -> a isa CuArray, args) | ||
# return unzip(map(f, args...)) | ||
# end | ||
return StructArrays.components(StructArray(Iterators.map(f, args...))) | ||
end | ||
|
||
function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(unzip_map), f::F, xs...) where {F} | ||
y, back = rrule_via_ad(cfg, map, f, xs...) | ||
z = unzip(y) | ||
function ununzip_map(dz) | ||
# dy = StructArray(map(unthunk, dz)) # fails for e.g. StructArray(([1,2,3], ZeroTangent())) | ||
dy = broadcast(tuple, map(unthunk, dz)...) | ||
return back(dy) | ||
end | ||
return z, ununzip_map | ||
end | ||
|
||
=# |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. Will want something like it for map
etc. at some point. Somewhere I thought I had a better version anyway.
This largely adapts JuliaDiff/Diffractor.jl#68 to ChainRules. Comments there still apply; in particular the split path uses
derivatives_given_output
when available, and StructArrays to avoid making temporary arrays of tuples. And there are fused rules for some cheap functions like+
,-
, etc.What's new is that there's a path calling
frule_via_ad
. At the moment this is only forf.(x)
, but it could be extended, see comments. See also FluxML/Zygote.jl#1222 about this.Zygote
hashad some failures. Its generic rule is@adjoint function broadcasted(::AbstractArrayStyle, f::F, args...) where {F}
which is not called whenrrule(config, ::typeof(broadcasted), f::F, args...)
is defined, as the style is added later.Which means that this PR breaks Zygote right now. Opt out of CR broadcasting FluxML/Zygote.jl#1263 opts out of these new rules.Resolved by changing the signature.Yota's tests pass with its own broadcast rules removed.Or they did, on one version, before changing the signature.GPU arrays:
With RFC: broadcasting Diffractor.jl#68 I believe simple ones worked.All paths except the most generic reverse mode save-the-pullbacks one seem to work, or will iftuplecast
can be disabled when it seesAbstractGPUArrayStyle
, which needs MoveAbstractGPUArrayStyle
toGPUArraysCore
? JuliaGPU/GPUArrays.jl#417.(Zygote dispatches earlier on a style to always take a CuArray path, which fails silently FluxML/Zygote.jl#1215.)
Closes JuliaDiff/Diffractor.jl#68 (alternative)
closes #531.
Closes FluxML/Zygote.jl#1263 (not required).