Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Broadcasting #644

Merged
merged 20 commits into from
Aug 9, 2022
Merged

Broadcasting #644

merged 20 commits into from
Aug 9, 2022

Conversation

mcabbott
Copy link
Member

@mcabbott mcabbott commented Jul 12, 2022

This largely adapts JuliaDiff/Diffractor.jl#68 to ChainRules. Comments there still apply; in particular the split path uses derivatives_given_output when available, and StructArrays to avoid making temporary arrays of tuples. And there are fused rules for some cheap functions like +, -, etc.

What's new is that there's a path calling frule_via_ad. At the moment this is only for f.(x), but it could be extended, see comments. See also FluxML/Zygote.jl#1222 about this.

  • Zygote has had some failures. Its generic rule is @adjoint function broadcasted(::AbstractArrayStyle, f::F, args...) where {F} which is not called when rrule(config, ::typeof(broadcasted), f::F, args...) is defined, as the style is added later. Which means that this PR breaks Zygote right now. Opt out of CR broadcasting FluxML/Zygote.jl#1263 opts out of these new rules. Resolved by changing the signature.

  • Yota's tests pass with its own broadcast rules removed. Or they did, on one version, before changing the signature.

  • GPU arrays: With RFC: broadcasting Diffractor.jl#68 I believe simple ones worked. All paths except the most generic reverse mode save-the-pullbacks one seem to work, or will if tuplecast can be disabled when it sees AbstractGPUArrayStyle, which needs Move AbstractGPUArrayStyle to GPUArraysCore? JuliaGPU/GPUArrays.jl#417.

(Zygote dispatches earlier on a style to always take a CuArray path, which fails silently FluxML/Zygote.jl#1215.)

Closes JuliaDiff/Diffractor.jl#68 (alternative)
closes #531.
Closes FluxML/Zygote.jl#1263 (not required).

# and we don't know whether re-computing `y` is cheap.
# (We could check `f` first like `sum(f, x)` does, but checking whether `g` needs `y` is tricky.)

function rrule(cfg::RCR, ::typeof(broadcasted), f::F, args::Vararg{Any,N}) where {F,N}
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This rule is applied before Zygote's own rule. To avoid that, it could be changed to this:

Suggested change
function rrule(cfg::RCR, ::typeof(broadcasted), f::F, args::Vararg{Any,N}) where {F,N}
function rrule(cfg::RCR, ::typeof(broadcasted), ::BroadcastStyle, f::F, args...) where {F}

This has two problems.

The first is that Yota no longer sees the rule at all. Trying to make it a specific rule which inserts the style is tricky, as this creates method ambiguities:

function ChainRulesCore.rrule(cfg::YotaRuleConfig, ::typeof(Broadcast.broadcasted), f::F, args::Vararg{Any,N}) where {F,N}
    bcs = Broadcast.BroadcastStyle(typeof(Broadcast.broadcasted(f, args...)))
    y, back = ChainRulesCore.rrule(cfg, Broadcast.broadcasted, bcs, f, args...)  # this one is ambiguous
...

@dfdx is there a better way?

The second problem is that this (slightly later) method has its arguments wrapped by broadcastable, which means many things need Ref:

julia> using Base.Broadcast: broadcasted, DefaultArrayStyle, broadcastable

julia> copy(broadcasted(|>, [1,2,3], inv))
3-element Vector{Float64}:
 1.0
 0.5
 0.3333333333333333

julia> copy(broadcasted(DefaultArrayStyle{1}(), |>, [1,2,3], Ref(inv)))  # error without Ref
3-element Vector{Float64}:
 1.0
 0.5
 0.3333333333333333

julia> broadcastable(inv)
Base.RefValue{typeof(inv)}(inv)

Which at the moment gives test errors like Tangent{Base.RefValue{typeof(sin)}}(x = NoTangent(),) isa NoTangent, lots of messy Tangents.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps better is for Zygote to be told not to use these rules, via @opt_out. Old versions of Zygote won't have that, and would perhaps need to get an upper version bound on ChainRules.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should be able to do a configured opt-out for ZygoteConfig

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which at the moment gives test errors like Tangent{Base.RefValue{typeof(sin)}}(x = NoTangent(),) isa NoTangent, lots of messy Tangents.

Can we add a simplify call at the end which converts all(isa(NoTangent), x::Tangent) into NoTangent() ?

Copy link
Member Author

@mcabbott mcabbott Jul 13, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have not yet tracked these down. It's possible some projection is missing, or does not simplify things which it should.

Edit: this is now JuliaDiff/ChainRulesCore.jl#565 , but not really needed here if we don't use broadcasted(::BroadcastStyle, ... methods.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW, why does @opt_out define no_rrule to return nothing, which it already returns? That means you have to check the method table, but it seems you could equally well return true or NoTangent or something, which the caller can easily test. Can we fix that?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am going to need to refresh my memory on how and why @opt_out works the way it does. 😅

I think that possibly rrule_via_ad takes a shortcut which does not check no_rrule.

There is indeed a shortcut that doesn't check no_rule
When it goes through the path of Zygote using Zygote because these is no rule then it hits the normal zygote check.
But there is a fast path that avoids calling back in to zygote if there is a rule right there.
https://github.com/FluxML/Zygote.jl/blob/5c80f55b24d4060f5550b25201b4b288703d34b9/src/compiler/chainrules.jl#L243-L245

But that path does check if rrule returns nothing.
Which it should as, the @opt_out defines a more specific rrule that returns nothing for this case.
So there should be no need to check no_rrule.
But that doesn't hold if you bypass @opt_out, which is why the docs are specific about not being able to bypass @opt_out as you can't know if user is using the method table way, or the directly checking rrule way.

BTW, why does @opt_out define no_rrule to return nothing, which it already returns?

It doesn't return anything because it is just using the method table for storage.
But the user shouldn't need to check no_rrule because checking that rrule returns nothing rather than a result should be enough.

Using the macro in the obvious way to opt out of the generic rule (by matching its signature) produces ambiguities with every more specific rule

Aren't these real ambiguities?
Can you show some examples?

Copy link
Member Author

@mcabbott mcabbott Aug 2, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR defines one generic, and many specific rules:

rrule(::RuleConfig{>:HasReverseMode}, ::typeof(broadcasted), f::F, args::Vararg{Any,N})  # generic
rrule(::RuleConfig{>:HasReverseMode}, ::typeof(broadcasted), ::typeof(+), xs::Union{Number, Broadcasted, AbstractArray{<:Number}, Tuple{Vararg{Number, N}}}...) # specific

Attempting to opt out of the generic rule for Zygote, the macro produces this:

rrule(::Zygote.ZygoteRuleConfig, ::typeof(broadcasted), f::F, args::Vararg{Any, N}) where {F, N}

which is more specific on the 1st but less specific on the 3rd argument, sometimes.

The hope was to opt out of only the generic rule, since the rules for + etc. are similar enough not to matter. But the generic one is quite different, and e.g. skips Zygote's present CuArray handling... I'd rather not touch that yet.

Such an opt-out could be done if you are allowed to define only no_rrule (and the caller expected to check that). Note that the only caller here is Zygote, but still.

Maybe the Zygote PR does have to opt out of every rule here, I think that would remove ambiguity. The tricky question is what happens to 3rd party packages defining broadcasting rules. NNlib has

rrule(::typeof(broadcasted), ::typeof(relu), x::Numeric)

which won't in fact cause an ambiguity. But it also won't ever be called, with this PR, since the ::RuleConfig method will match first. It could be given a ::RuleConfig argument (as the specific rules are in this PR) but then the ambiguity will return.

The alternative suggested above is to move the generic rule one of this PR one step later in the pipeline, something like:

rrule(::RuleConfig{>:HasReverseMode}, ::typeof(broadcasted), ::BroadcastStyle, f::F, args::Vararg{Any,N})  # generic
rrule(::typeof(broadcasted), ::typeof(+), xs::Union{Number, Broadcasted, AbstractArray{<:Number}, Tuple{Vararg{Number, N}}}...) # specific

The ugly feature of this is that the arguments here have already been acted on with broadcastable, so there is a lot more dealing with Ref required.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

7e7d105 changes over to use BroadcastStyle.

I now think this is the right thing to do. Besides Zygote, the generic rule needs RuleConfig{>:HasReverseMode} while the specific rules do not. Naiively that would mean the specific rules never get called, so previously I gave them a RuleConfig argument only for this reason. But that's a bit weird, and e.g. NNlib's rules don't, https://github.com/FluxML/NNlib.jl/blob/master/src/activations.jl#L875-L880 .

Now, instead, when no specific rule matches, AD will continue & hit the rule with both RuleConfig and Broadcasted. (For some reason this didn't work with Yota, earlier, but surely fixable.)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Zygote does appear to occasionally call some of these rules, in its tests. Perhaps because of rrule_via_ad? But they aren't so different, and I think tests pass.

src/tuplecast.jl Outdated Show resolved Hide resolved
@mcabbott mcabbott force-pushed the broadcast branch 2 times, most recently from f6e27aa to f62a756 Compare August 7, 2022 00:30
@@ -13,6 +13,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RealDot = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't forget to add compat for this. and up the version number

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

src/unzipped.jl Outdated
Comment on lines 63 to 94
"""
unzip_map(f, args...)

For a function `f` which returns a tuple, this is `== unzip(map(f, args...))`,
but performed using `StructArrays` for efficiency.

Not in use at present, but see `unzip_broadcast`.
"""
function unzip_map(f::F, args...) where {F}
T = Broadcast.combine_eltypes(f, args)
if isconcretetype(T)
T <: Tuple || throw(ArgumentError("""unzip_map(f, args) only works on functions returning a tuple,
but f = $(sprint(show, f)) returns type T = $T"""))
end
# if any(a -> a isa CuArray, args)
# return unzip(map(f, args...))
# end
return StructArrays.components(StructArray(Iterators.map(f, args...)))
end

function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(unzip_map), f::F, xs...) where {F}
y, back = rrule_via_ad(cfg, map, f, xs...)
z = unzip(y)
function ununzip_map(dz)
# dy = StructArray(map(unthunk, dz)) # fails for e.g. StructArray(([1,2,3], ZeroTangent()))
dy = broadcast(tuple, map(unthunk, dz)...)
return back(dy)
end
return z, ununzip_map
end

=#
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

delete this? and move it to an issue maybe?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. Will want something like it for map etc. at some point. Somewhere I thought I had a better version anyway.

Copy link
Member

@oxinabox oxinabox left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yay, merge when happy

Copy link
Member Author

@mcabbott mcabbott left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, will merge when tests pass.

Am sure there will be things to fix, but nothing relies on this just yet. Will revisit when I return to JuliaDiff/Diffractor.jl#73

@@ -13,6 +13,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RealDot = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Comment on lines 375 to 381
# Fallback

function unbroadcast(x, dx)
@info "last unbroadcast method!" x dx
dx isa AbstractZero && return dx
p = ProjectTo(x)
if p isa ProjectTo{<:AbstractZero}
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No tests currently hit this fallback method. I'm removing it as I can't picture how to hit it: Now that the generic rule uses ::BroadcastStyle, weird arguments are going to be wrapped in Ref, hence hit another methog.

src/unzipped.jl Outdated
Comment on lines 63 to 94
"""
unzip_map(f, args...)

For a function `f` which returns a tuple, this is `== unzip(map(f, args...))`,
but performed using `StructArrays` for efficiency.

Not in use at present, but see `unzip_broadcast`.
"""
function unzip_map(f::F, args...) where {F}
T = Broadcast.combine_eltypes(f, args)
if isconcretetype(T)
T <: Tuple || throw(ArgumentError("""unzip_map(f, args) only works on functions returning a tuple,
but f = $(sprint(show, f)) returns type T = $T"""))
end
# if any(a -> a isa CuArray, args)
# return unzip(map(f, args...))
# end
return StructArrays.components(StructArray(Iterators.map(f, args...)))
end

function ChainRulesCore.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(unzip_map), f::F, xs...) where {F}
y, back = rrule_via_ad(cfg, map, f, xs...)
z = unzip(y)
function ununzip_map(dz)
# dy = StructArray(map(unthunk, dz)) # fails for e.g. StructArray(([1,2,3], ZeroTangent()))
dy = broadcast(tuple, map(unthunk, dz)...)
return back(dy)
end
return z, ununzip_map
end

=#
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. Will want something like it for map etc. at some point. Somewhere I thought I had a better version anyway.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

The right way to implement rrule(broadcasted, f, args...)
3 participants