diff --git a/src/lib/number.jl b/src/lib/number.jl index ced2d0ca9..b05100b40 100644 --- a/src/lib/number.jl +++ b/src/lib/number.jl @@ -20,8 +20,13 @@ end @adjoint (T::Type{<:Complex})(re, im) = T(re, im), c̄ -> (nothing, real(c̄), imag(c̄)) +# we define these here because ChainRules.jl only defines them for x::Union{Real,Complex} + +@adjoint abs2(x::Number) = abs2(x), Δ -> (real(Δ)*(x + x),) @adjoint real(x::Number) = real(x), r̄ -> (real(r̄),) @adjoint conj(x::Number) = conj(x), r̄ -> (conj(r̄),) @adjoint imag(x::Number) = imag(x), ī -> (real(ī)*im,) -@adjoint abs2(x::Number) = abs2(x), Δ -> (real(Δ)*(x + x),) +# for real x, ChainRules pulls back a zero real adjoint, whereas we treat x +# as embedded in the complex numbers and pull back a pure imaginary adjoint +@adjoint imag(x::Real) = zero(x), ī -> (real(ī)*im,)