Skip to content

Commit

Permalink
Change == to ignore measure-zero branches (#481)
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott authored Aug 30, 2022
1 parent 61e4dd4 commit 6a6443b
Show file tree
Hide file tree
Showing 9 changed files with 196 additions and 64 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ForwardDiff"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.32"
version = "0.10.33"

[deps]
CommonSubexpressions = "bbf7d656-a473-5ed7-a52c-81e309532950"
Expand Down
27 changes: 25 additions & 2 deletions src/dual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -384,17 +384,40 @@ for pred in UNARY_PREDICATES
@eval Base.$(pred)(d::Dual) = $(pred)(value(d))
end

for pred in BINARY_PREDICATES
# Before PR#481 this loop ran over this list:
# BINARY_PREDICATES = Symbol[:isequal, :isless, :<, :>, :(==), :(!=), :(<=), :(>=)]
# Not a minimal set, as Base defines some in terms of others.
for pred in [:isless, :<, :>, :(<=), :(>=)]
@eval begin
@define_binary_dual_op(
Base.$(pred),
$(pred)(value(x), value(y)),
$(pred)(value(x), y),
$(pred)(x, value(y))
$(pred)(x, value(y)),
)
end
end

Base.iszero(x::Dual) = iszero(value(x)) && iszero(partials(x)) # shortcut, equivalent to x == zero(x)

for pred in [:isequal, :(==)]
@eval begin
@define_binary_dual_op(
Base.$(pred),
$(pred)(value(x), value(y)) && $(pred)(partials(x), partials(y)),
$(pred)(value(x), y) && iszero(partials(x)),
$(pred)(x, value(y)) && iszero(partials(y)),
)
end
end

@define_binary_dual_op(
Base.:(!=),
(!=)(value(x), value(y)) || (!=)(partials(x), partials(y)),
(!=)(value(x), y) || !iszero(partials(x)),
(!=)(x, value(y)) || !iszero(partials(y)),
)

########################
# Promotion/Conversion #
########################
Expand Down
2 changes: 1 addition & 1 deletion src/prelude.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ const AMBIGUOUS_TYPES = (AbstractFloat, Irrational, Integer, Rational, Real, Rou

const UNARY_PREDICATES = Symbol[:isinf, :isnan, :isfinite, :iseven, :isodd, :isreal, :isinteger]

const BINARY_PREDICATES = Symbol[:isequal, :isless, :<, :>, :(==), :(!=), :(<=), :(>=)]
const DEFAULT_CHUNK_THRESHOLD = 12

struct Chunk{N} end

Expand Down
62 changes: 46 additions & 16 deletions test/DualTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ ForwardDiff.:≺(::Int, ::Type{TestTag()}) = false
ForwardDiff.:(::Type{TestTag}, ::Type{OuterTestTag}) = true
ForwardDiff.:(::Type{OuterTestTag}, ::Type{TestTag}) = false

for N in (0,3), M in (0,4), V in (Int, Float32)
@testset "Dual{Z,$V,$N} and Dual{Z,Dual{Z,$V,$M},$N}" for N in (0,3), M in (0,4), V in (Int, Float32)
println(" ...testing Dual{TestTag(),$V,$N} and Dual{TestTag(),Dual{TestTag(),$V,$M},$N}")

PARTIALS = Partials{N,V}(ntuple(n -> intrand(V), N))
Expand All @@ -44,6 +44,13 @@ for N in (0,3), M in (0,4), V in (Int, Float32)
PARTIALS3 = Partials{N,V}(ntuple(n -> intrand(V), N))
PRIMAL3 = intrand(V)
FDNUM3 = Dual{TestTag()}(PRIMAL3, PARTIALS3)

if !allunique([PRIMAL, PRIMAL2, PRIMAL3])
@info "testing with non-unique primals" PRIMAL PRIMAL2 PRIMAL3
end
if N > 0 && !allunique([PARTIALS, PARTIALS2, PARTIALS3])
@info "testing with non-unique partials" PARTIALS PARTIALS2 PARTIALS3
end

M_PARTIALS = Partials{M,V}(ntuple(m -> intrand(V), M))
NESTED_PARTIALS = convert(Partials{N,Dual{TestTag(),V,M}}, PARTIALS)
Expand Down Expand Up @@ -231,15 +238,27 @@ for N in (0,3), M in (0,4), V in (Int, Float32)
@test ForwardDiff.isconstant(one(NESTED_FDNUM))
@test ForwardDiff.isconstant(NESTED_FDNUM) == (N == 0)

@test isequal(FDNUM, Dual{TestTag()}(PRIMAL, PARTIALS2))
@test isequal(PRIMAL, PRIMAL2) == isequal(FDNUM, FDNUM2)

@test isequal(NESTED_FDNUM, Dual{TestTag()}(Dual{TestTag()}(PRIMAL, M_PARTIALS2), NESTED_PARTIALS2))
@test isequal(PRIMAL, PRIMAL2) == isequal(NESTED_FDNUM, NESTED_FDNUM2)

@test FDNUM == Dual{TestTag()}(PRIMAL, PARTIALS2)
@test (PRIMAL == PRIMAL2) == (FDNUM == FDNUM2)
@test (PRIMAL == PRIMAL2) == (NESTED_FDNUM == NESTED_FDNUM2)
# Recall that FDNUM = Dual{TestTag()}(PRIMAL, PARTIALS) has N partials,
# and FDNUM2 has everything with a 2, and all random numbers nonzero.
# M is the length of M_PARTIALS, which affects:
# NESTED_FDNUM = Dual{TestTag()}(Dual{TestTag()}(PRIMAL, M_PARTIALS), NESTED_PARTIALS)

@test (FDNUM == Dual{TestTag()}(PRIMAL, PARTIALS2)) == (PARTIALS == PARTIALS2)
@test isequal(FDNUM, Dual{TestTag()}(PRIMAL, PARTIALS2)) == (PARTIALS == PARTIALS2)
@test isequal(NESTED_FDNUM, Dual{TestTag()}(Dual{TestTag()}(PRIMAL, M_PARTIALS2), NESTED_PARTIALS2)) == ((M_PARTIALS == M_PARTIALS2) && (NESTED_PARTIALS == NESTED_PARTIALS2))

if PRIMAL == PRIMAL2
@test isequal(FDNUM, Dual{TestTag()}(PRIMAL, PARTIALS2)) == (PARTIALS == PARTIALS2)
@test isequal(FDNUM, FDNUM2) == (PARTIALS == PARTIALS2)

@test (FDNUM == FDNUM2) == (PARTIALS == PARTIALS2)
@test (NESTED_FDNUM == NESTED_FDNUM2) == ((M_PARTIALS == M_PARTIALS2) && (NESTED_PARTIALS == NESTED_PARTIALS2))
else
@test !isequal(FDNUM, FDNUM2)

@test FDNUM != FDNUM2
@test NESTED_FDNUM != NESTED_FDNUM2
end

@test isless(Dual{TestTag()}(1, PARTIALS), Dual{TestTag()}(2, PARTIALS2))
@test !(isless(Dual{TestTag()}(1, PARTIALS), Dual{TestTag()}(1, PARTIALS2)))
Expand Down Expand Up @@ -344,7 +363,7 @@ for N in (0,3), M in (0,4), V in (Int, Float32)
@test typeof(WIDE_NESTED_FDNUM) === Dual{TestTag(),Dual{TestTag(),WIDE_T,M},N}

@test value(WIDE_FDNUM) == PRIMAL
@test value(WIDE_NESTED_FDNUM) == PRIMAL
@test (value(WIDE_NESTED_FDNUM) == PRIMAL) == (M == 0)

@test convert(Dual, FDNUM) === FDNUM
@test convert(Dual, NESTED_FDNUM) === NESTED_FDNUM
Expand Down Expand Up @@ -395,6 +414,8 @@ for N in (0,3), M in (0,4), V in (Int, Float32)
#----------#

if M > 0 && N > 0
# Recall that FDNUM = Dual{TestTag()}(PRIMAL, PARTIALS) has N partials,
# all random numbers nonzero, and FDNUM2 another draw. M only affects NESTED_FDNUM.
@test Dual{1}(FDNUM) / Dual{1}(PRIMAL) === Dual{1}(FDNUM / PRIMAL)
@test Dual{1}(PRIMAL) / Dual{1}(FDNUM) === Dual{1}(PRIMAL / FDNUM)
@test_broken Dual{1}(FDNUM) / FDNUM2 === Dual{1}(FDNUM / FDNUM2)
Expand All @@ -413,6 +434,7 @@ for N in (0,3), M in (0,4), V in (Int, Float32)

# Exponentiation #
#----------------#

# If V == Int, the LHS terms are Int's. Large inputs cause integer overflow
# within the generic fallback of `isapprox`, resulting in a DomainError.
# Promote to Float64 to avoid issues.
Expand Down Expand Up @@ -442,7 +464,7 @@ for N in (0,3), M in (0,4), V in (Int, Float32)
@test abs(NESTED_FDNUM) === NESTED_FDNUM

if V != Int
for (M, f, arity) in DiffRules.diffrules(filter_modules = nothing)
@testset "$f" for (M, f, arity) in DiffRules.diffrules(filter_modules = nothing)
if f in (:/, :rem2pi)
continue # Skip these rules
elseif !(isdefined(@__MODULE__, M) && isdefined(getfield(@__MODULE__, M), f))
Expand Down Expand Up @@ -502,10 +524,14 @@ for N in (0,3), M in (0,4), V in (Int, Float32)
else
@test dx isa Complex{<:Dual{TestTag()}}
@test dy isa Complex{<:Dual{TestTag()}}
@test real(value(dx)) == real(actualval)
@test real(value(dy)) == real(actualval)
@test imag(value(dx)) == imag(actualval)
@test imag(value(dy)) == imag(actualval)
# @test real(value(dx)) == real(actualval)
# @test real(value(dy)) == real(actualval)
# @test imag(value(dx)) == imag(actualval)
# @test imag(value(dy)) == imag(actualval)
@test value(real(dx)) == real(actualval)
@test value(real(dy)) == real(actualval)
@test value(imag(dx)) == imag(actualval)
@test value(imag(dy)) == imag(actualval)
@test partials(real(dx), 1) real(actualdx) nans=true
@test partials(real(dy), 1) real(actualdy) nans=true
@test partials(imag(dx), 1) imag(actualdx) nans=true
Expand Down Expand Up @@ -568,6 +594,10 @@ for N in (0,3), M in (0,4), V in (Int, Float32)
end
end

#############
# bug fixes #
#############

@testset "Exponentiation of zero" begin
x0 = 0.0
x1 = Dual{:t1}(x0, 1.0)
Expand Down
57 changes: 53 additions & 4 deletions test/GradientTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module GradientTest
import Calculus

using Test
using LinearAlgebra
using ForwardDiff
using ForwardDiff: Dual, Tag
using StaticArrays
Expand All @@ -19,7 +20,7 @@ x = [0.1, 0.2, 0.3]
v = f(x)
g = [-9.4, 15.6, 52.0]

for c in (1, 2, 3), tag in (nothing, Tag(f, eltype(x)))
@testset "Rosenbrock, chunk size = $c and tag = $(repr(tag))" for c in (1, 2, 3), tag in (nothing, Tag(f, eltype(x)))
println(" ...running hardcoded test with chunk size = $c and tag = $(repr(tag))")
cfg = ForwardDiff.GradientConfig(f, x, ForwardDiff.Chunk{c}(), tag)

Expand Down Expand Up @@ -55,7 +56,7 @@ cfgx = ForwardDiff.GradientConfig(sin, x)
# test vs. Calculus.jl #
########################

for f in DiffTests.VECTOR_TO_NUMBER_FUNCS
@testset "$f" for f in DiffTests.VECTOR_TO_NUMBER_FUNCS
v = f(X)
g = ForwardDiff.gradient(f, X)
@test isapprox(g, Calculus.gradient(f, X), atol=FINITEDIFF_ERROR)
Expand Down Expand Up @@ -83,9 +84,9 @@ end

println(" ...testing specialized StaticArray codepaths")

x = rand(3, 3)
@testset "$T" for T in (StaticArrays.SArray, StaticArrays.MArray)
x = rand(3, 3)

for T in (StaticArrays.SArray, StaticArrays.MArray)
sx = T{Tuple{3,3}}(x)

cfg = ForwardDiff.GradientConfig(nothing, x)
Expand Down Expand Up @@ -148,6 +149,10 @@ end
@test isequal(ForwardDiff.gradient(t -> t[1]^t[2], [0.0, 1.5]), [0.0, 0.0])
end

#############
# bug fixes #
#############

# Issue 399
@testset "chunk size zero" begin
f_const(x) = 1.0
Expand All @@ -162,11 +167,55 @@ end
@test_throws DimensionMismatch ForwardDiff.gradient(identity, fill(2pi, 10^6)) # chunk_mode_gradient
end

# Issue 548
@testset "ArithmeticStyle" begin
function f(p)
sum(collect(0.0:p[1]:p[2]))
end
@test ForwardDiff.gradient(f, [0.2,25.0]) == [7875.0, 0.0]
end

@testset "det with branches" begin
# Issue 197
det2(A) = return (
A[1,1]*(A[2,2]*A[3,3]-A[2,3]*A[3,2]) -
A[1,2]*(A[2,1]*A[3,3]-A[2,3]*A[3,1]) +
A[1,3]*(A[2,1]*A[3,2]-A[2,2]*A[3,1])
)

A = [1 0 0; 0 2 0; 0 pi 3]
@test det2(A) == det(A) == 6
@test istril(A)

∇A = [6 0 0; 0 3 -pi; 0 0 2]
@test ForwardDiff.gradient(det2, A) ∇A
@test ForwardDiff.gradient(det, A) ∇A

# And issue 407
@test ForwardDiff.hessian(det, A) ForwardDiff.hessian(det2, A)

# https://discourse.julialang.org/t/forwarddiff-and-zygote-return-wrong-jacobian-for-log-det-l/77961
S = [1.0 0.8; 0.8 1.0]
L = cholesky(S).L
@test ForwardDiff.gradient(L -> log(det(L)), Matrix(L)) [1.0 -1.3333333333333337; 0.0 1.666666666666667]
@test ForwardDiff.gradient(L -> logdet(L), Matrix(L)) [1.0 -1.3333333333333337; 0.0 1.666666666666667]
end

@testset "branches in mul!" begin
a, b = rand(3,3), rand(3,3)

# Issue 536, version with 3-arg *, Julia 1.7:
@test ForwardDiff.derivative(x -> sum(x*a*b), 0.0) sum(a * b)

if VERSION >= v"1.3"
# version with just mul!
dx = ForwardDiff.derivative(0.0) do x
c = similar(a, typeof(x))
mul!(c, a, b, x, false)
sum(c)
end
@test dx sum(a * b)
end
end

end # module
8 changes: 8 additions & 0 deletions test/HessianTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module HessianTest
import Calculus

using Test
using LinearAlgebra
using ForwardDiff
using ForwardDiff: Dual, Tag
using StaticArrays
Expand Down Expand Up @@ -157,4 +158,11 @@ for T in (StaticArrays.SArray, StaticArrays.MArray)
@test DiffResults.hessian(sresult3) == DiffResults.hessian(result)
end

@testset "branches in dot" begin
# https://github.com/JuliaDiff/ForwardDiff.jl/issues/551
H = [1 2 3; 4 5 6; 7 8 9];
@test ForwardDiff.hessian(x->dot(x,H,x), fill(0.00001, 3)) [2 6 10; 6 10 14; 10 14 18]
@test ForwardDiff.hessian(x->dot(x,H,x), zeros(3)) [2 6 10; 6 10 14; 10 14 18]
end

end # module
4 changes: 4 additions & 0 deletions test/JacobianTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,10 @@ for T in (StaticArrays.SArray, StaticArrays.MArray)
@test DiffResults.jacobian(sresult3) == DiffResults.jacobian(result)
end

#########
# misc. #
#########

@testset "dimension errors for jacobian" begin
@test_throws DimensionMismatch ForwardDiff.jacobian(identity, 2pi) # input
@test_throws DimensionMismatch ForwardDiff.jacobian(sum, fill(2pi, 2)) # vector_mode_jacobian
Expand Down
2 changes: 1 addition & 1 deletion test/PartialsTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using ForwardDiff: Partials

samerng() = MersenneTwister(1)

for N in (0, 3), T in (Int, Float32, Float64)
@testset "Partials{$N,$T}" for N in (0, 3), T in (Int, Float32, Float64)
println(" ...testing Partials{$N,$T}")

VALUES = (rand(T,N)...,)
Expand Down
Loading

0 comments on commit 6a6443b

Please sign in to comment.