Skip to content

Commit

Permalink
Merge #721
Browse files Browse the repository at this point in the history
721: Use ChainRules v0.7.0 r=oxinabox a=sethaxen

ChainRules v0.7.0 introduced a new convention for complex numbers. It is no longer necessary for Zygote to conjugate sensitivities sent to and received from ChainRules. This PR is a non-breaking change that uses ChainRules v0.7.0.

Fixes JuliaDiff/ChainRules.jl#210 and supersedes #720.

Co-authored-by: Seth Axen <[email protected]>
  • Loading branch information
bors[bot] and sethaxen authored Jun 30, 2020
2 parents 7890b75 + 6a2bf00 commit 27da443
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 12 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
[compat]
AbstractFFTs = "0.5"
ArrayLayouts = "0.1, 0.2, 0.3"
ChainRules = "0.6.0"
ChainRules = "0.7.0"
DiffRules = "1.0"
FillArrays = "0.8"
ForwardDiff = "0"
Expand Down
10 changes: 4 additions & 6 deletions src/compiler/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,9 @@ is_kwfunc(k, ::Type{<:NamedTuple}, f, args...) = k===Core.kwftype(f)
"""
wrap_chainrules_output(x)
Convert `x` from the differentials types ChainRules uses to the format Zygote uses internally
(including conjugating complex gradients).
Convert `x` from the differentials types ChainRules uses to the format Zygote uses internally.
"""
@inline wrap_chainrules_output(x) = conj(unthunk(x)) # For now we are just not going to deal with thunks
@inline wrap_chainrules_output(x) = unthunk(x) # For now we are just not going to deal with thunks
@inline wrap_chainrules_output(x::Tuple) = map(wrap_chainrules_output, x)
@inline wrap_chainrules_output(x::ChainRules.AbstractZero) = nothing
for T_outer in (:Tuple, :NamedTuple)
Expand All @@ -55,10 +54,9 @@ end
"""
wrap_chainrules_input(x)
Convert `x` from the format Zygote uses internally (including conjugated complex gradients)
to differentials types ChainRules uses.
Convert `x` from the format Zygote uses internally to differentials types ChainRules uses.
"""
@inline wrap_chainrules_input(x) = conj(x)
@inline wrap_chainrules_input(x) = x
@inline wrap_chainrules_input(::Nothing) = ChainRules.Zero()
@inline function wrap_chainrules_input(xs::Union{Tuple, NamedTuple})
xp = map(wrap_chainrules_input, xs)
Expand Down
11 changes: 6 additions & 5 deletions src/lib/number.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@ end

@adjoint (T::Type{<:Complex})(re, im) = T(re, im), c̄ -> (nothing, real(c̄), imag(c̄))

# we define these here because ChainRules.jl only defines them for x::Union{Real,Complex}

@adjoint abs2(x::Number) = abs2(x), Δ -> (real(Δ)*(x + x),)
@adjoint real(x::Number) = real(x), r̄ -> (real(r̄),)
@adjoint conj(x::Number) = conj(x), r̄ -> (conj(r̄),)
@adjoint imag(x::Number) = imag(x), ī -> (real(ī)*im,)

# we intentionally define these here rather than falling back on ChainRules.jl
# because ChainRules doesn't really handle nonanalytic complex functions
@adjoint abs(x::Real) = abs(x), Δ -> (real(Δ)*sign(x),)
@adjoint abs(x::Complex) = abs(x), Δ -> (real(Δ)*x/abs(x),)
@adjoint abs2(x::Number) = abs2(x), Δ -> (real(Δ)*(x + x),)
# for real x, ChainRules pulls back a zero real adjoint, whereas we treat x
# as embedded in the complex numbers and pull back a pure imaginary adjoint
@adjoint imag(x::Real) = zero(x), ī -> (real(ī)*im,)

0 comments on commit 27da443

Please sign in to comment.