The right way to implement rrule(broadcasted, f, args...) #531

Closed
dfdx opened this issue Sep 14, 2021 · 6 comments · Fixed by #644

@dfdx
Contributor

dfdx commented Sep 14, 2021

In another discussion I proposed a generic implementation of rrule for broadcasted. A slightly modified version of it looks like this (using rrule instead of rrule_via_ad for simplicity here):

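# scalar rules such as rrule(sin, x) live in ChainRules
using ChainRules, ChainRulesCore
import ChainRulesCore: rrule

# unzip taken from Zygote:
# https://github.com/FluxML/Zygote.jl/blob/d5be4d5ca80e79278d714eaac15ca71904a262e3/src/lib/array.jl#L177-L185
struct StaticGetter{i} end
(::StaticGetter{i})(v) where {i} = v[i]

# turn an array of N-tuples into an N-tuple of arrays
@generated function _unzip(tuples, ::Val{N}) where {N}
    Expr(:tuple, (:(map($(StaticGetter{i}()), tuples)) for i in 1:N)...)
end

function unzip(tuples)
    N = length(first(tuples))
    _unzip(tuples, Val(N))
end


function rrule(::typeof(Broadcast.broadcasted), f::F, args...) where F
    # element-wise (output, pullback) pairs, split into two arrays
    ys, pbs = unzip(rrule.(f, args...))
    function pullback(Δ)
        # apply each element's pullback to its cotangent, then split the
        # resulting tuples into one array per argument (f included)
        dxs = map((pb, δ) -> pb(δ), pbs, Δ) |> unzip
        # collapse arrays made entirely of NoTangent() into a single NoTangent()
        dxs = [all(dx .== NoTangent()) ? NoTangent() : dx for dx in dxs]
        return NoTangent(), dxs...
    end
    return ys, pullback
end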
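To make the role of `unzip` concrete: it turns an array of tuples (here, the `(output, pullback)` pairs that `rrule` returns for each element) into a tuple of arrays. A simpler, non-`@generated` sketch is below; `unzip_simple` is a hypothetical name. Because `N` is only known at run time, this version is not type-stable, which is presumably why Zygote's version passes `Val(N)` to a generated function.

```julia
# Hypothetical helper: split an array of N-tuples into an N-tuple of arrays.
# N is read from the first element at run time, so the result type is not inferred.
function unzip_simple(tuples)
    N = length(first(tuples))
    return ntuple(i -> map(t -> t[i], tuples), N)
end

tups = [(1, 'a'), (2, 'b'), (3, 'c')]
nums, chars = unzip_simple(tups)   # ([1, 2, 3], ['a', 'b', 'c'])
```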

Empirically, I can see that it works correctly at least in simple cases, e.g.:

f = sin
xs = rand(2)

# manually get pullbacks for each element and apply them to seed 1.0
pbs = [rrule(f, x)[2] for x in xs]
dxs = [pbs[1](1.0)[2], pbs[2](1.0)[2]]

# use rrule for broadcasted
_, bcast_pb = rrule(Broadcast.broadcasted, f, xs)
dxs_bcast = bcast_pb(ones(2))[end]

@assert all(dxs .== dxs_bcast)
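For comparison, ChainRulesTestUtils checks a pullback against a finite-difference vector-Jacobian product (computed with FiniteDifferences.jl). A rough, hand-rolled version of that comparison for `sin.` is sketched below; `fd_vjp` is a hypothetical helper using plain central differences, not the FiniteDifferences.jl API.

```julia
# Hypothetical helper: central-difference estimate of the vector-Jacobian
# product J' * Δ for a function F mapping a vector to a vector.
function fd_vjp(F, x::Vector{Float64}, Δ::Vector{Float64}; h = 1e-6)
    g = similar(x)
    for i in eachindex(x)
        e = zero(x)
        e[i] = h   # perturb only the i-th input
        g[i] = sum(Δ .* (F(x .+ e) .- F(x .- e))) / (2h)
    end
    return g
end

x = [0.3, 0.7]
Δ = [2.0, -1.0]
g = fd_vjp(v -> sin.(v), x, Δ)
# for an element-wise function the VJP reduces to Δ .* f'(x), i.e. Δ .* cos.(x)
```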

But when I run test_rrule(Broadcast.broadcasted, f, xs; check_inferred=false) I get a strange error:

test_rrule: broadcasted on typeof(sin),Vector{Float64}: Test Failed at /home/az/.julia/packages/ChainRulesTestUtils/Rzheq/src/check_result.jl:24
  Expression: isapprox(actual, expected; kwargs...)
  Problem:  Vector{ChainRulesCore.AbstractTangent}[1]
   Evaluated: isapprox(0.0, -8.379999999999999; rtol = 1.0e-9, atol = 1.0e-9)
Stacktrace:
 [1] test_approx(actual::Union{Number, AbstractArray{var"#s79", N} where {var"#s79"<:Number, N}}, expected::Union{Number, AbstractArray{var"#s87", N} where {var"#s87"<:Number, N}}, msg::Any; kwargs::Any)
   @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/Rzheq/src/check_result.jl:24
 [2] test_approx(::ChainRulesCore.AbstractZero, x::Any, msg::Any; kwargs::Any)
   @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/Rzheq/src/check_result.jl:33
 [3] test_approx(actual::AbstractArray, expected::AbstractArray, msg::Any; kwargs::Any)
   @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/Rzheq/src/check_result.jl:80
 [4] macro expansion
   @ ~/.julia/packages/ChainRulesTestUtils/Rzheq/src/testers.jl:238 [inlined]
 [5] macro expansion
   @ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/Test/src/Test.jl:1151 [inlined]
 [6] test_rrule(::ChainRulesCore.RuleConfig, ::Any, ::Any, ::Vararg{Any, N} where N; output_tangent::Any, check_thunked_output_tangent::Any, fdm::Any, rrule_f::Any, check_inferred::Bool, fkwargs::NamedTuple, rtol::Real, atol::Real, kwargs::Any)
   @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/Rzheq/src/testers.jl:194
test_rrule: broadcasted on typeof(sin),Vector{Float64}: Error During Test at /home/az/.julia/packages/ChainRulesTestUtils/Rzheq/src/testers.jl:191
  Got exception outside of a @test
  AssertionError: T <: NamedTuple
  Stacktrace:
    [1] test_approx(actual::ChainRulesCore.Tangent{Tuple{Vector{Float64}}, Tuple{Vector{Float64}}}, expected::Any, msg::Any; kwargs::Any)
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/Rzheq/src/check_result.jl:112
    [2] test_approx(actual::AbstractArray, expected::AbstractArray, msg::Any; kwargs::Any)
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/Rzheq/src/check_result.jl:80
    [3] macro expansion
      @ ~/.julia/packages/ChainRulesTestUtils/Rzheq/src/testers.jl:238 [inlined]
    [4] macro expansion
      @ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/Test/src/Test.jl:1151 [inlined]
    [5] test_rrule(::ChainRulesCore.RuleConfig, ::Any, ::Any, ::Vararg{Any, N} where N; output_tangent::Any, check_thunked_output_tangent::Any, fdm::Any, rrule_f::Any, check_inferred::Bool, fkwargs::NamedTuple, rtol::Real, atol::Real, kwargs::Any)
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/Rzheq/src/testers.jl:194
    [6] test_rrule(::Any, ::Vararg{Any, N} where N; kwargs::Any)
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/Rzheq/src/testers.jl:168
    [7] top-level scope
      @ REPL[20]:1
    [8] eval
      @ ./boot.jl:360 [inlined]
    [9] eval
      @ ./Base.jl:39 [inlined]
   [10] repleval(m::Module, code::Expr, #unused#::String)
      @ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.4.0/scripts/packages/VSCodeServer/src/repl.jl:157
   [11] (::VSCodeServer.var"#69#71"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.Prompt})()
      @ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.4.0/scripts/packages/VSCodeServer/src/repl.jl:123
   [12] with_logstate(f::Function, logstate::Any)
      @ Base.CoreLogging ./logging.jl:491
   [13] with_logger
      @ ./logging.jl:603 [inlined]
   [14] (::VSCodeServer.var"#68#70"{Module, Expr, REPL.LineEditREPL, REPL.LineEdit.Prompt})()
      @ VSCodeServer ~/.vscode/extensions/julialang.language-julia-1.4.0/scripts/packages/VSCodeServer/src/repl.jl:124
   [15] #invokelatest#2
      @ ./essentials.jl:708 [inlined]
   [16] invokelatest(::Any)
      @ Base ./essentials.jl:706
   [17] macro expansion
      @ ~/.vscode/extensions/julialang.language-julia-1.4.0/scripts/packages/VSCodeServer/src/eval.jl:34 [inlined]
   [18] (::VSCodeServer.var"#53#54")()
      @ VSCodeServer ./task.jl:411

Note the two error messages:

Evaluated: isapprox(0.0, -8.379999999999999; rtol = 1.0e-9, atol = 1.0e-9)

and

AssertionError: T <: NamedTuple

Can you see a mistake in this implementation or is it just too complicated for test_rrule() to verify?

@mcabbott
Member

This seems to give sensible answers. But the tester might be confused if it compares to Broadcast.broadcasted(log, [1,2,3]), which isn't collected? (I get different errors, but wasn't careful.)
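To illustrate the point about collection: calling `Broadcast.broadcasted` directly (rather than through dot syntax) returns a lazy `Broadcasted` object, not an array, while the rule above returns a collected array as its primal. This only illustrates the hypothesis; it is not a confirmed diagnosis of the test failure.

```julia
xs = [1.0, 2.0, 3.0]

# Calling broadcasted directly builds a lazy object; nothing is computed yet.
bc = Broadcast.broadcasted(log, xs)
println(bc isa Broadcast.Broadcasted)  # true
println(bc isa AbstractArray)          # false

# Dot syntax log.(xs) lowers to materialize(broadcasted(log, xs)),
# which collects the lazy object into an array.
ys = Broadcast.materialize(bc)
```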

Note that you will need something like Zygote's unbroadcast to deal with arguments of different shapes, e.g. [1,2,3] .+ [4 5]. Note also that this can (often) be made more efficient with the new derivatives_given_output. And of course this kills broadcast fusion, but it isn't clear whether we can do better.
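With mismatched shapes, the gradient for each argument has to be summed over the dimensions along which that argument was expanded. A minimal sketch of that idea using only Base is below; it is not Zygote's actual `unbroadcast`, and `unbroadcast_sketch` is a hypothetical name.

```julia
# Hypothetical helper: sum a gradient dx (shaped like the broadcast result)
# back down to the shape of the original argument x.
unbroadcast_sketch(x::Number, dx::AbstractArray) = sum(dx)

function unbroadcast_sketch(x::AbstractArray, dx::AbstractArray)
    size(x) == size(dx) && return dx
    # dimensions along which x was expanded: extent 1 in x
    # (size(x, d) is 1 for d > ndims(x)) but larger in dx
    dims = Tuple(d for d in 1:ndims(dx) if size(x, d) == 1 && size(dx, d) != 1)
    isempty(dims) && return reshape(dx, size(x))
    return reshape(sum(dx; dims = dims), size(x))
end

x = [1.0, 2.0, 3.0]   # size (3,)
y = [4.0 5.0]         # size (1, 2)
z = x .+ y            # size (3, 2)
dz = ones(size(z))
dx = unbroadcast_sketch(x, dz)   # [2.0, 2.0, 2.0]
dy = unbroadcast_sketch(y, dz)   # [3.0 3.0]
```

Each element of `x` is reused once per column of `y`, so its gradient is 2; each element of `y` is reused once per row, so its gradient is 3.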

@dfdx
Contributor Author

dfdx commented Sep 15, 2021

So I guess the optimal approach would be to avoid rrule altogether and only use derivatives_given_output?

ys = f.(xs...)
...
dxs = derivatives_given_output.(ys, f, xs...)   # this already works for log()
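The appeal of a derivative-given-output approach is that for many scalar functions the derivative can be written in terms of the already computed output, so the broadcast needs no per-element closures. The sketch below hand-writes a few such rules; `deriv_given_output_sketch` is a hypothetical stand-in that returns a plain number, not the ChainRulesCore function itself.

```julia
# Hypothetical per-function rules: derivative of f at x, given y = f(x).
deriv_given_output_sketch(y, ::typeof(tanh), x) = 1 - y^2   # tanh' = 1 - tanh^2
deriv_given_output_sketch(y, ::typeof(exp), x) = y          # exp' = exp
deriv_given_output_sketch(y, ::typeof(sin), x) = cos(x)     # this one needs the input

xs = [0.0, 0.5, 1.0]
ys = tanh.(xs)
# tanh is a Function, so broadcasting treats it as a scalar argument
dxs = deriv_given_output_sketch.(ys, tanh, xs)

# the pullback for a seed Δ is then just an element-wise product
Δ = ones(length(xs))
x_bar = Δ .* dxs
```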

I already have a very similar mechanism in Yota (example), so it would be pretty natural to extend it with derivatives_given_output(), but how stable is it? Is there something like a 50% chance it will disappear in the next release, or should I only expect minor changes to the signature over the next 6 months? Also, is there an API to check whether derivatives_given_output / scalar rules exist for a particular function?

Regarding unbroadcast(), we can always have special rules to handle these special cases. For example, I do this.

@mcabbott
Member

If the link works, this is my attempt at an unfused broadcasting rule.

This uses isconcretetype(Core.Compiler._return_type(derivatives_given_output, ... as a way to check whether there is a method, and falls back to calling AD for functions without one. There is more discussion about possible mechanisms (a bit long!) in JuliaDiff/ChainRulesCore.jl#456 (comment).
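The check described above relies on `Core.Compiler._return_type`, which is compiler-internal. The sketch below shows the same "use a rule if one exists, otherwise fall back" structure, but swaps in the public `hasmethod` for the check and uses a central finite difference as a stand-in for the AD fallback; all function names here are hypothetical.

```julia
# Hypothetical table of scalar rules: derivative given output y and input x.
scalar_deriv(y, ::typeof(tanh), x) = 1 - y^2
scalar_deriv(y, ::typeof(exp), x) = y

# Is there a scalar_deriv method for these argument types?
# (f is called once here only to learn its output type.)
has_scalar_rule(f, x) = hasmethod(scalar_deriv, Tuple{typeof(f(x)), typeof(f), typeof(x)})

# Stand-in for the AD fallback: a central finite difference.
fd_derivative(f, x; h = 1e-6) = (f(x + h) - f(x - h)) / (2h)

function broadcast_with_derivatives(f, xs)
    ys = f.(xs)
    if has_scalar_rule(f, first(xs))
        ds = scalar_deriv.(ys, Ref(f), xs)   # fast path: reuse the outputs
    else
        ds = fd_derivative.(Ref(f), xs)      # fallback for functions without a rule
    end
    return ys, ds
end

xs = [0.1, 0.5]
_, d_tanh = broadcast_with_derivatives(tanh, xs)   # uses the tanh rule
_, d_sin = broadcast_with_derivatives(sin, xs)     # no rule, falls back
```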

No idea how stable you should regard derivatives_given_output as being; it's brand new. But cc @piever, whose idea this was.

I think the question of how much better a fused broadcast can be boils down to whether it can be made fast, and whether it can save memory. Pushing dual numbers through (like Zygote does) equates to quite a few copies, as does storing an array of closures. There might be ways to fuse some simple operations like f.(x.-y) while breaking arbitrary f.(g.(h.(x)))?
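A rough illustration of "pushing dual numbers through" a broadcast: each input element is paired with a derivative seed, the fused expression runs once over the pairs, and derivatives come out alongside values. The `Dual` type below is a hypothetical, minimal stand-in with just enough methods for `sin.(x .- c)`; note the extra array of duals that has to be allocated, which is one of the copies mentioned above.

```julia
# Minimal hypothetical dual number: a value plus a derivative.
struct Dual <: Number
    val::Float64
    der::Float64
end

# Just enough rules to push duals through sin.(x .- c)
Base.:-(a::Dual, c::Real) = Dual(a.val - c, a.der)
Base.sin(a::Dual) = Dual(sin(a.val), cos(a.val) * a.der)

x = [0.1, 0.2, 0.3]
c = 0.05

duals = Dual.(x, 1.0)         # extra array: each x paired with seed 1.0
out = sin.(duals .- c)        # one fused loop over the duals
vals = [d.val for d in out]   # primal result, sin.(x .- c)
ders = [d.der for d in out]   # element-wise derivative, cos.(x .- c)
```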

@oxinabox
Member

> No idea how stable you should regard derivatives_given_output as being

Not at all stable.

@dfdx
Contributor Author

dfdx commented Sep 16, 2021

> If the link works

The link opens a new PR page and lets me pretend to be you 😄 But yes, I can see the code and all the different paths you've considered 👍

Given my simple use case of handling functions like sin, tanh, etc., I infer the following from the discussion:

  • the best way to handle such functions is to use derivatives_given_output() or similar API when it becomes stable
  • until then it's easier to have rrule(broadcasted, f, args...) for each needed f separately (5-6 rules should cover 90% of cases for me)
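A per-function rule of the kind the second bullet describes might look like the sketch below, written for `tanh`. It assumes ChainRulesCore is installed and returns a collected array as the primal (not a lazy `Broadcasted`); with a single array argument, no `unbroadcast`-style shape handling is needed.

```julia
using ChainRulesCore
import ChainRulesCore: rrule

# Hand-written rule for tanh.(x). Since tanh' = 1 - tanh^2, the pullback
# reuses the output y instead of storing one closure per element.
function rrule(::typeof(Broadcast.broadcasted), ::typeof(tanh), x::AbstractArray{<:Real})
    y = tanh.(x)
    function broadcasted_tanh_pullback(dy)
        dx = unthunk(dy) .* (1 .- y .^ 2)
        # tangents for: broadcasted itself, the function tanh, the array x
        return NoTangent(), NoTangent(), dx
    end
    return y, broadcasted_tanh_pullback
end

x = [0.0, 0.5, 1.0]
y, pb = rrule(Broadcast.broadcasted, tanh, x)
grads = pb(ones(length(x)))
```

If ChainRules.jl's own broadcasting rules are loaded as well, a method like this may overlap with them, which is the interference concern raised in the last comment of the thread.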

I believe there's nothing more to do in this issue, so closing it. Thanks for sharing all this great stuff with me!

@dfdx dfdx closed this as completed Sep 16, 2021
@mcabbott
Member

If you open the PR then you get to fix the bugs!

One more question is whether rules for broadcasting should live here in ChainRules. If they are defined elsewhere through rrule then they will tend to interfere with each other. But it's not so clear we know the optimal thing to do yet, so perhaps it's better to have different experiments. My link above is trying to be careful about memory but hasn't thought at all about higher order derivatives.

Labels
None yet
Projects
None yet

3 participants