Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add missing complex tests and rules #216

Merged
merged 31 commits into from
Jun 30, 2020
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
90a369d
Fix indentation
sethaxen Jun 28, 2020
a49187b
Test \ on complex inputs
sethaxen Jun 28, 2020
aa26f04
Test ^ on complex inputs
sethaxen Jun 28, 2020
e47afd2
Test identity on complex inputs
sethaxen Jun 28, 2020
beaae0c
Test muladd on complex inputs
sethaxen Jun 28, 2020
b89b588
Test binary functions on complex inputs
sethaxen Jun 28, 2020
4738c48
Test functions on complex inputs
sethaxen Jun 28, 2020
0f34193
Release type constraint on exp
sethaxen Jun 28, 2020
0dd4023
Add _realconjtimes
sethaxen Jun 28, 2020
d737cdf
Use _realconjtimes in abs/abs2 rules
sethaxen Jun 28, 2020
e5276e6
Add complex rule for hypot
sethaxen Jun 28, 2020
ac58495
Add generic rule for adjoint
sethaxen Jun 28, 2020
7f6c709
Add generic rule for real
sethaxen Jun 28, 2020
b5fef9e
Add generic rule for imag
sethaxen Jun 28, 2020
45ba9b7
Add complex rule for hypot
sethaxen Jun 28, 2020
5971f4f
Add rules/tests for Complex
sethaxen Jun 28, 2020
45b2edc
Test frule for identity
sethaxen Jun 28, 2020
c678197
Add missing angle test
sethaxen Jun 28, 2020
3b3d11d
Make inline just in case
sethaxen Jun 29, 2020
f6f7c8a
Unify abs rules
sethaxen Jun 29, 2020
16f8307
Introduce _imagconjtimes utility function
sethaxen Jun 29, 2020
86dac82
Unify angle rules
sethaxen Jun 29, 2020
071b658
Unify sign rules
sethaxen Jun 29, 2020
3ff0457
Merge branch 'master' into complextests
sethaxen Jun 29, 2020
e0dd0d3
Multiply by correct variable
sethaxen Jun 29, 2020
2ce61d4
Fix argument order
sethaxen Jun 29, 2020
70f6d70
Merge branch 'complextests' of https://github.com/sethaxen/ChainRules…
sethaxen Jun 29, 2020
1c97506
Bump ChainRulesTestUtils version number
sethaxen Jun 29, 2020
c286f41
Restrict to Complex
sethaxen Jun 30, 2020
980c413
Use muladd
sethaxen Jun 30, 2020
34419ce
Update src/rulesets/Base/fastmath_able.jl
sethaxen Jun 30, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/ChainRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ if VERSION < v"1.3.0-DEV.142"
import LinearAlgebra: dot
end

include("rulesets/Base/utils.jl")
include("rulesets/Base/base.jl")
include("rulesets/Base/fastmath_able.jl")
include("rulesets/Base/array.jl")
Expand Down
69 changes: 67 additions & 2 deletions src/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,80 @@

@scalar_rule one(x) zero(x)
@scalar_rule zero(x) zero(x)
@scalar_rule adjoint(x::Real) One()
@scalar_rule transpose(x) One()

# `adjoint`

frule((_, Δz), ::typeof(adjoint), z::Number) = (z', Δz')

function rrule(::typeof(adjoint), z::Number)
adjoint_pullback(ΔΩ) = (NO_FIELDS, ΔΩ')
return (z', adjoint_pullback)
end

# `real`

@scalar_rule real(x::Real) One()

frule((_, Δz), ::typeof(real), z::Number) = (real(z), real(Δz))

function rrule(::typeof(real), z::Number)
# add zero(z) to embed the real number in the same number type as z
real_pullback(ΔΩ) = (NO_FIELDS, real(ΔΩ) + zero(z))
return (real(z), real_pullback)
end

# `imag`

@scalar_rule imag(x::Real) Zero()

frule((_, Δz), ::typeof(imag), z::Number) = (imag(z), imag(Δz))

function rrule(::typeof(imag), z::Complex)
imag_pullback(ΔΩ) = (NO_FIELDS, real(ΔΩ) * im)
sethaxen marked this conversation as resolved.
Show resolved Hide resolved
return (imag(z), imag_pullback)
end

# `Complex`

frule((_, Δz), ::Type{T}, z::Number) where {T<:Complex} = (T(z), Complex(Δz))
function frule((_, Δx, Δy), ::Type{T}, x::Number, y::Number) where {T<:Complex}
return (T(x, y), Complex(Δx, Δy))
end

function rrule(::Type{T}, z::Complex) where {T<:Complex}
Complex_pullback(ΔΩ) = (NO_FIELDS, Complex(ΔΩ))
return (T(z), Complex_pullback)
end
function rrule(::Type{T}, x::Real) where {T<:Complex}
Complex_pullback(ΔΩ) = (NO_FIELDS, real(ΔΩ))
return (T(x), Complex_pullback)
end
function rrule(::Type{T}, x::Number, y::Number) where {T<:Complex}
Complex_pullback(ΔΩ) = (NO_FIELDS, real(ΔΩ), imag(ΔΩ))
return (T(x, y), Complex_pullback)
end

# `hypot`

@scalar_rule hypot(x::Real) sign(x)

function frule((_, Δz), ::typeof(hypot), z::Complex)
Ω = hypot(z)
∂Ω = _realconjtimes(z, Δz) / ifelse(iszero(Ω), one(Ω), Ω)
return Ω, ∂Ω
end

function rrule(::typeof(hypot), z::Complex)
Ω = hypot(z)
function hypot_pullback(ΔΩ)
return (NO_FIELDS, (real(ΔΩ) / ifelse(iszero(Ω), one(Ω), Ω)) * z)
end
return (Ω, hypot_pullback)
end

@scalar_rule fma(x, y, z) (y, x, One())
@scalar_rule muladd(x, y, z) (y, x, One())
@scalar_rule real(x::Real) One()
@scalar_rule rem2pi(x, r::RoundingMode) (One(), DoesNotExist())
@scalar_rule(
mod(x, y),
Expand Down
36 changes: 28 additions & 8 deletions src/rulesets/Base/fastmath_able.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ let
@scalar_rule cbrt(x) inv(3 * Ω ^ 2)
@scalar_rule inv(x) -(Ω ^ 2)
@scalar_rule sqrt(x) inv(2Ω)
@scalar_rule exp(x::Real) Ω
@scalar_rule exp(x) Ω
@scalar_rule exp10(x) Ω * log(oftype(x, 10))
@scalar_rule exp2(x) Ω * log(oftype(x, 2))
@scalar_rule expm1(x) exp(x)
Expand All @@ -42,7 +42,7 @@ let
end
function frule((_, Δz), ::typeof(abs), z::Complex)
Ω = abs(z)
return Ω, (real(z) * real(Δz) + imag(z) * imag(Δz)) / ifelse(iszero(z), one(Ω), Ω)
return Ω, _realconjtimes(z, Δz) / ifelse(iszero(z), one(Ω), Ω)
# `ifelse` is applied only to denominator to ensure type-stability.
end

Expand All @@ -63,11 +63,8 @@ let
end

## abs2
function frule((_, Δx), ::typeof(abs2), x::Real)
return abs2(x), 2x * real(Δx)
end
function frule((_, Δz), ::typeof(abs2), z::Complex)
return abs2(z), 2 * (real(z) * real(Δz) + imag(z) * imag(Δz))
function frule((_, Δz), ::typeof(abs2), z::Union{Real, Complex})
return abs2(z), 2 * _realconjtimes(z, Δz)
end

function rrule(::typeof(abs2), z::Union{Real, Complex})
Expand Down Expand Up @@ -126,7 +123,30 @@ let
end

# Binary functions
@scalar_rule hypot(x::Real, y::Real) (x / Ω, y / Ω)

# `hypot`

function frule(
(_, Δx, Δy),
::typeof(hypot),
x::T,
y::T,
) where {T<:Union{Real,Complex}}
sethaxen marked this conversation as resolved.
Show resolved Hide resolved
Ω = hypot(x, y)
n = ifelse(iszero(Ω), one(Ω), Ω)
∂Ω = (_realconjtimes(x, Δx) + _realconjtimes(y, Δy)) / n
return Ω, ∂Ω
end

function rrule(::typeof(hypot), x::T, y::T) where {T<:Union{Real,Complex}}
Ω = hypot(x, y)
function hypot_pullback(ΔΩ)
c = real(ΔΩ) / ifelse(iszero(Ω), one(Ω), Ω)
return (NO_FIELDS, @thunk(c * x), @thunk(c * y))
end
return (Ω, hypot_pullback)
end

@scalar_rule x + y (One(), One())
@scalar_rule x - y (One(), -1)
@scalar_rule x / y (inv(y), -((x / y) / y))
Expand Down
5 changes: 5 additions & 0 deletions src/rulesets/Base/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# real(x * conj(y)) avoiding computing the imaginary part
_realconjtimes(x, y) = real(x) * real(y) + imag(x) * imag(y)
_realconjtimes(x::Real, y) = x * real(y)
_realconjtimes(x, y::Real) = real(x) * y
_realconjtimes(x::Real, y::Real) = x * y
88 changes: 58 additions & 30 deletions test/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,21 +46,31 @@
end # Trig

@testset "Angles" begin
for x in (-0.1, 6.4)
for x in (-0.1, 6.4, 0.5 + 0.25im)
test_scalar(deg2rad, x)
test_scalar(rad2deg, x)
end
end

@testset "Unary complex functions" begin
for x in (-4.1, 6.4)
for x in (-4.1, 6.4, 0.0, 0.0 + 0.0im, 0.5 + 0.25im)
test_scalar(real, x)
test_scalar(imag, x)
test_scalar(hypot, x)
test_scalar(adjoint, x)
end
end

@testset "Complex" begin
test_scalar(Complex, randn())
test_scalar(Complex, randn(ComplexF64))
x, ẋ, x̄ = randn(3)
y, ẏ, ȳ = randn(3)
Δz = randn(ComplexF64)
frule_test(Complex, (x, ẋ), (y, ẏ))
rrule_test(Complex, Δz, (x, x̄), (y, ȳ))
end

@testset "*(x, y) (scalar)" begin
# This is pretty important so testing it fairly heavily
test_points = (0.0, -2.1, 3.2, 3.7+2.12im, 14.2-7.1im)
Expand Down Expand Up @@ -95,44 +105,51 @@
@test extern(dy) == extern(zeros(2, 5) .+ dy)
end

@testset "ldexp" begin
x, Δx, x̄ = 10rand(3)
Δz = rand()
@testset "ldexp" begin
x, Δx, x̄ = 10rand(3)
Δz = rand()

for n in (0,1,20)
# TODO: Forward test does not work when parameter is Integer
# See: https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/22
#frule_test(ldexp, (x, Δx), (n, nothing))
rrule_test(ldexp, Δz, (x, x̄), (n, nothing))
end
end

for n in (0,1,20)
# TODO: Forward test does not work when parameter is Integer
# See: https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/22
#frule_test(ldexp, (x, Δx), (n, nothing))
rrule_test(ldexp, Δz, (x, x̄), (n, nothing))
end
end
@testset "\\(x::$T, y::$T) (scalar)" for T in (Float64, ComplexF64)
x, ẋ, x̄, y, ẏ, ȳ, Δz = randn(T, 7)
frule_test(*, (x, ẋ), (y, ẏ))
rrule_test(*, Δz, (x, x̄), (y, ȳ))
end

@testset "binary function ($f)" for f in (mod, \)
@testset "mod" begin
x, Δx, x̄ = 10rand(3)
y, Δy, ȳ = rand(3)
Δz = rand()

frule_test(f, (x, Δx), (y, Δy))
rrule_test(f, Δz, (x, x̄), (y, ȳ))
frule_test(mod, (x, Δx), (y, Δy))
rrule_test(mod, Δz, (x, x̄), (y, ȳ))
end

@testset "x^n for x<0" begin
x = -15*rand()
Δx, x̄ = 10rand(2)
y, Δy, ȳ = rand(3)
Δz = rand()
@testset "^(x::$T, n::$T)" for T in (Float64, ComplexF64)
# for real x and n, x must be >0
x = T <: Real ? 15rand() : 15randn(ComplexF64)
Δx, x̄ = 10rand(T, 2)
y, Δy, ȳ = rand(T, 3)
Δz = rand(T)

frule_test(^, (-x, Δx), (y, Δy))
rrule_test(^, Δz, (-x, x̄), (y, ȳ))
frule_test(^, (x, Δx), (y, Δy))
rrule_test(^, Δz, (x, x̄), (y, ȳ))
end

@testset "identity" begin
rrule_test(identity, randn(), (randn(), randn()))
rrule_test(identity, randn(4), (randn(4), randn(4)))
@testset "identity" for T in (Float64, ComplexF64)
rrule_test(identity, randn(T), (randn(T), randn(T)))
rrule_test(identity, randn(T, 4), (randn(T, 4), randn(T, 4)))

rrule_test(
identity, Tuple(randn(3)),
(Composite{Tuple}(randn(3)...), Composite{Tuple}(randn(3)...))
identity, Tuple(randn(T, 3)),
(Composite{Tuple}(randn(T, 3)...), Composite{Tuple}(randn(T, 3)...))
)
end

Expand All @@ -141,15 +158,26 @@
test_scalar(zero, x)
end

@testset "trinary ($f)" for f in (muladd, fma)
@testset "muladd(x::$T, y::$T, z::$T)" for T in (Float64, ComplexF64)
x, Δx, x̄ = 10randn(T, 3)
y, Δy, ȳ = randn(T, 3)
z, Δz, z̄ = randn(T, 3)
Δk = randn(T)

frule_test(muladd, (x, Δx), (y, Δy), (z, Δz))
rrule_test(muladd, Δk, (x, x̄), (y, ȳ), (z, z̄))
end

@testset "fma" begin
x, Δx, x̄ = 10randn(3)
y, Δy, ȳ = randn(3)
z, Δz, z̄ = randn(3)
Δk = randn()

frule_test(f, (x, Δx), (y, Δy), (z, Δz))
rrule_test(f, Δk, (x, x̄), (y, ȳ), (z, z̄))
frule_test(fma, (x, Δx), (y, Δy), (z, Δz))
rrule_test(fma, Δk, (x, x̄), (y, ȳ), (z, z̄))
end

@testset "clamp" begin
x̄, ȳ, z̄ = randn(3)
Δx, Δy, Δz = randn(3)
Expand Down
40 changes: 27 additions & 13 deletions test/rulesets/Base/fastmath_able.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ const FASTABLE_AST = quote
test_scalar(atan, x)
end
@testset "Multivariate" begin
@testset "sincos" begin
x, Δx, x̄ = randn(3)
Δz = (randn(), randn())
@testset "sincos(x::$T)" for T in (Float64, ComplexF64)
x, Δx, x̄ = randn(T, 3)
Δz = (randn(T), randn(T))

frule_test(sincos, (x, Δx))
rrule_test(sincos, Δz, (x, x̄))
Expand All @@ -66,17 +66,19 @@ const FASTABLE_AST = quote
end

@testset "exponents" begin
for x in (-0.1, 6.4)
for x in (-0.1, 6.4, 0.5 + 0.25im)
test_scalar(inv, x)

test_scalar(exp, x)
test_scalar(exp2, x)
test_scalar(exp10, x)
test_scalar(expm1, x)

test_scalar(cbrt, x)
if x isa Real
test_scalar(cbrt, x)
end

if x >= 0
if x isa Complex || x >= 0
test_scalar(sqrt, x)
test_scalar(log, x)
test_scalar(log2, x)
Expand All @@ -103,19 +105,31 @@ const FASTABLE_AST = quote
end

@testset "Unary functions" begin
for x in (-4.1, 6.4)
for x in (-4.1, 6.4, 0.0, 0.0 + 0.0im, 0.5 + 0.25im)
test_scalar(+, x)
test_scalar(-, x)
test_scalar(atan, x)
end
end

@testset "binary function ($f)" for f in (/, +, -, hypot, atan, rem, ^, max, min)
x, Δx, x̄ = 10rand(3)
y, Δy, ȳ = rand(3)
Δz = rand()
@testset "binary functions" begin
@testset "$f(x, y)" for f in (atan, rem, max, min)
x, Δx, x̄ = 10rand(3)
y, Δy, ȳ = rand(3)
Δz = rand()

frule_test(f, (x, Δx), (y, Δy))
rrule_test(f, Δz, (x, x̄), (y, ȳ))
end

@testset "$f(x::$T, y::$T)" for f in (/, +, -, hypot), T in (Float64, ComplexF64)
x, Δx, x̄ = 10rand(T, 3)
y, Δy, ȳ = rand(T, 3)
Δz = randn(typeof(f(x, y)))

frule_test(f, (x, Δx), (y, Δy))
rrule_test(f, Δz, (x, x̄), (y, ȳ))
frule_test(f, (x, Δx), (y, Δy))
rrule_test(f, Δz, (x, x̄), (y, ȳ))
end
end

@testset "sign" begin
Expand Down