diff --git a/Project.toml b/Project.toml index 494b4b779..9e5605a3c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.5.0" +version = "0.5.1" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -11,7 +11,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] ChainRulesCore = "0.7" -ChainRulesTestUtils = "0.2.1" +ChainRulesTestUtils = "0.2.2" Compat = "3" FiniteDifferences = "0.9" Reexport = "0.2" diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index db7672c32..562498342 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -4,7 +4,7 @@ @scalar_rule(abs(x::Real), sign(x)) @scalar_rule(abs2(x), 2x) -@scalar_rule(exp(x), Ω) +@scalar_rule(exp(x::Real), Ω) @scalar_rule(exp10(x), Ω * log(oftype(x, 10))) @scalar_rule(exp2(x), Ω * log(oftype(x, 2))) @scalar_rule(expm1(x), exp(x)) @@ -45,7 +45,8 @@ @scalar_rule(sinh(x), cosh(x)) @scalar_rule(tanh(x), 1-Ω^2) -@scalar_rule(acosh(x), inv(sqrt(x^2 - 1))) +# Can't multiply though sqrt in acosh because of negative complex case for x +@scalar_rule(acosh(x), inv(sqrt(x - 1) * sqrt(x + 1))) @scalar_rule(acoth(x), inv(1 - x^2)) @scalar_rule(acsch(x), -inv(x^2 * sqrt(1 + x^-2))) @scalar_rule(acsch(x::Real), -inv(abs(x) * sqrt(1 + x^2))) @@ -66,7 +67,9 @@ @scalar_rule(-(x, y), (One(), -1)) @scalar_rule(/(x, y), (inv(y), -(x / y / y))) @scalar_rule(\(x, y), (-(y / x / x), inv(x))) -@scalar_rule(^(x, y), (ifelse(iszero(y), zero(Ω), y * x^(y - 1)), Ω * log(x))) + +#log(complex(x)) is require so it give correct complex answer for x<0 +@scalar_rule(^(x, y), (ifelse(iszero(y), zero(Ω), y * x^(y - 1)), Ω * log(complex(x)))) @scalar_rule(cbrt(x), inv(3 * Ω^2)) @scalar_rule(inv(x), -Ω^2) @@ -117,7 +120,7 @@ end function rrule(::typeof(*), x::Number, y::Number) function times_pullback(ΔΩ) - return (NO_FIELDS, @thunk(ΔΩ * y'), @thunk(x' * ΔΩ)) + return (NO_FIELDS, @thunk(ΔΩ * y), @thunk(x * ΔΩ)) end return x * y, times_pullback end @@ -132,3 +135,14 @@ function rrule(::typeof(identity), x) end return x, identity_pullback end + +function rrule(::typeof(identity), x::Tuple) + # `identity(::Tuple)` returns multiple outputs;because that is how we think of + # returning a tuple, so its pullback needs to accept multiple inputs. + # `identity(::Tuple)` has one input, so its pullback should return 1 matching output + # see https://github.com/JuliaDiff/ChainRulesCore.jl/issues/152 + function identity_pullback(ȳs...) + return (NO_FIELDS, Composite{typeof(x)}(ȳs...)) + end + return x, identity_pullback +end diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index 322de69e6..cd4a37be8 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -40,17 +40,17 @@ end ##### `det` ##### -function frule((_, ẋ), ::typeof(det), x) +function frule((_, ẋ), ::typeof(det), x::Union{Number, AbstractMatrix}) Ω = det(x) # TODO Performance optimization: probably there is an efficent # way to compute this trace without during the full compution within return Ω, Ω * tr(inv(x) * ẋ) end -function rrule(::typeof(det), x) +function rrule(::typeof(det), x::Union{Number, AbstractMatrix}) Ω = det(x) function det_pullback(ΔΩ) - return NO_FIELDS, @thunk(Ω * ΔΩ * inv(x)') + return NO_FIELDS, Ω * ΔΩ * transpose(inv(x)) end return Ω, det_pullback end @@ -59,15 +59,15 @@ end ##### `logdet` ##### -function frule((_, Δx), ::typeof(logdet), x) +function frule((_, Δx), ::typeof(logdet), x::Union{Number, AbstractMatrix}) Ω = logdet(x) return Ω, tr(inv(x) * Δx) end -function rrule(::typeof(logdet), x) +function rrule(::typeof(logdet), x::Union{Number, AbstractMatrix}) Ω = logdet(x) function logdet_pullback(ΔΩ) - return (NO_FIELDS, @thunk(ΔΩ * inv(x)')) + return (NO_FIELDS, ΔΩ * transpose(inv(x))) end return Ω, logdet_pullback end @@ -81,6 +81,8 @@ function frule((_, Δx), ::typeof(tr), x) end function rrule(::typeof(tr), x) + # This should really be a FillArray + # see https://github.com/JuliaDiff/ChainRules.jl/issues/46 function tr_pullback(ΔΩ) return (NO_FIELDS, @thunk Diagonal(fill(ΔΩ, size(x, 1)))) end @@ -121,14 +123,11 @@ function rrule(::typeof(/), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R C, dC_pb = rrule(adjoint, Cᵀ) function slash_pullback(Ȳ) # Optimization note: dAᵀ, dBᵀ, dC are calculated no matter which partial you want - # this is not a problem if you want the 2nd or 3rd, but if you want the first, it - # is fairly wasteful _, dC = dC_pb(Ȳ) - _, dBᵀ, dAᵀ = dS_pb(extern(dC)) + _, dBᵀ, dAᵀ = dS_pb(unthunk(dC)) - # need to extern as dAᵀ, dBᵀ are generally `Thunk`s, which don't support adjoint - ∂A = @thunk last(dA_pb(extern(dAᵀ))) - ∂B = @thunk last(dA_pb(extern(dBᵀ))) + ∂A = last(dA_pb(unthunk(dAᵀ))) + ∂B = last(dA_pb(unthunk(dBᵀ))) (NO_FIELDS, ∂A, ∂B) end diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl index 45f026edb..666ffe7dc 100644 --- a/src/rulesets/LinearAlgebra/structured.jl +++ b/src/rulesets/LinearAlgebra/structured.jl @@ -6,8 +6,15 @@ ##### function rrule(::Type{<:Diagonal}, d::AbstractVector) - function Diagonal_pullback(ȳ) - return (NO_FIELDS, @thunk(diag(ȳ))) + function Diagonal_pullback(ȳ::AbstractMatrix) + return (NO_FIELDS, diag(ȳ)) + end + function Diagonal_pullback(ȳ::Composite) + # TODO: Assert about the primal type in the Composite, It should be Diagonal + # infact it should be exactly the type of `Diagonal(d)` + # but right now Zygote loses primal type information so we can't use it. + # See https://github.com/FluxML/Zygote.jl/issues/603 + return (NO_FIELDS, ȳ.diag) end return Diagonal(d), Diagonal_pullback end diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index a6d56debb..9a3142ef4 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -35,14 +35,15 @@ test_scalar(acsc, 1/x) test_scalar(acot, 1/x) end - @testset "Inverse hyperbolic" for x = (0.5, Complex(0.5, 0.25)) + @testset "Inverse hyperbolic" for x = (0.5, Complex(0.5, 0.25), Complex(-2.1 -3.1im)) test_scalar(asinh, x) - test_scalar(acosh, x + 1) # +1 accounts for domain + test_scalar(acosh, x + 1) # +1 accounts for domain for real test_scalar(atanh, x) test_scalar(asech, x) test_scalar(acsch, x) test_scalar(acoth, x + 1) end + @testset "Inverse degrees" for x = (0.5, Complex(0.5, 0.25)) test_scalar(asind, x) test_scalar(acosd, x) @@ -100,7 +101,23 @@ end end - @testset "*(x, y)" begin + @testset "*(x, y) (scalar)" begin + # This is pretty important so testing it fairly heavily + test_points = (0.0, -2.1, 3.2, 3.7+2.12im, 14.2-7.1im) + @testset "$x * $y; (perturbed by: $perturb)" for + x in test_points, y in test_points, perturb in test_points + + # give small off-set so as can't slip in symmetry + x̄ = ẋ = 0.5 + perturb + ȳ = ẏ = 0.6 + perturb + Δz = perturb + + frule_test(*, (x, ẋ), (y, ẏ)) + rrule_test(*, Δz, (x, x̄), (y, ȳ)) + end + end + + @testset "matmul *(x, y)" begin x, y = rand(3, 2), rand(2, 5) z, pullback = rrule(*, x, y) @@ -125,10 +142,26 @@ rrule_test(f, Δz, (x, x̄), (y, ȳ)) end + @testset "x^n for x<0" begin + rng = MersenneTwister(123456) + x = -15*rand(rng) + Δx, x̄ = 10rand(rng, 2) + y, Δy, ȳ = rand(rng, 3) + Δz = rand(rng) + + frule_test(^, (-x, Δx), (y, Δy)) + rrule_test(^, Δz, (-x, x̄), (y, ȳ)) + end + @testset "identity" begin rng = MersenneTwister(1) rrule_test(identity, randn(rng), (randn(rng), randn(rng))) rrule_test(identity, randn(rng, 4), (randn(rng, 4), randn(rng, 4))) + + rrule_test( + identity, Tuple(randn(rng, 3)), + (Composite{Tuple}(randn(rng, 3)...), Composite{Tuple}(randn(rng, 3)...)) + ) end @testset "Constants" for x in (-0.1, 6.4, 1.0+0.5im, -10.0+0im, 0+200im) diff --git a/test/rulesets/LinearAlgebra/structured.jl b/test/rulesets/LinearAlgebra/structured.jl index 1731aad32..07b8ef986 100644 --- a/test/rulesets/LinearAlgebra/structured.jl +++ b/test/rulesets/LinearAlgebra/structured.jl @@ -6,7 +6,15 @@ rrule_test(Diagonal, D, (randn(rng, N), randn(rng, N))) # Concrete type instead of UnionAll rrule_test(typeof(D), D, (randn(rng, N), randn(rng, N))) + + # TODO: replace this with a `rrule_test` once we have that working + # see https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/24 + res, pb = rrule(Diagonal, [1, 4]) + @test pb(10*res) == (NO_FIELDS, [10, 40]) + comp = Composite{typeof(res)}(; diag=10*res.diag) # this is the structure of Diagonal + @test pb(comp) == (NO_FIELDS, [10, 40]) end + @testset "::Diagonal * ::AbstractVector" begin rng, N = MersenneTwister(123456), 3 rrule_test(