diff --git a/Project.toml b/Project.toml index 5af68933..0c28b962 100644 --- a/Project.toml +++ b/Project.toml @@ -13,8 +13,8 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" [compat] -ChainRules = "1.5" -ChainRulesCore = "1.2" +ChainRules = "1.44.6" +ChainRulesCore = "1.15.3" Combinatorics = "1" StaticArrays = "1" StatsBase = "0.33" diff --git a/src/extra_rules.jl b/src/extra_rules.jl index cd70a6df..cb6a103d 100644 --- a/src/extra_rules.jl +++ b/src/extra_rules.jl @@ -16,7 +16,7 @@ function ChainRulesCore.rrule(::DiffractorRuleConfig, g::∇getindex, Δ) g(Δ), Δ′′->(nothing, Δ′′[1][g.i...]) end -function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(getindex), xs::Array, i...) +function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(getindex), xs::Array{<:Number}, i...) xs[i...], ∇getindex(xs, i) end @@ -150,12 +150,6 @@ end ChainRulesCore.canonicalize(::ChainRulesCore.ZeroTangent) = ChainRulesCore.ZeroTangent() -# Skip AD'ing through the axis computation -function ChainRules.rrule(::DiffractorRuleConfig, ::typeof(Base.Broadcast.instantiate), bc::Base.Broadcast.Broadcasted) - return Base.Broadcast.instantiate(bc), Δ->begin - Core.tuple(NoTangent(), Δ) - end -end using StaticArrays @@ -187,10 +181,6 @@ end @ChainRulesCore.non_differentiable StaticArrays.promote_tuple_eltype(T) -function ChainRules.frule((_, ∂A), ::typeof(getindex), A::AbstractArray, args...) - getindex(A, args...), getindex(∂A, args...) -end - function ChainRules.rrule(::DiffractorRuleConfig, ::typeof(map), ::typeof(+), A::AbstractArray, B::AbstractArray) map(+, A, B), Δ->(NoTangent(), NoTangent(), Δ, Δ) end @@ -225,27 +215,28 @@ struct BackMap{T} f::T end (f::BackMap{N})(args...) where {N} = ∂⃖¹(getfield(f, :f), args...) -back_apply(x, y) = x(y) -back_apply_zero(x) = x(Zero()) +back_apply(x, y) = x(y) # this is just |> with arguments reversed +back_apply_zero(x) = x(Zero()) # Zero is not defined function ChainRules.rrule(::DiffractorRuleConfig, ::typeof(map), f, args::Tuple) a, b = unzip_tuple(map(BackMap(f), args)) - function back(Δ) + function map_back(Δ) (fs, xs) = unzip_tuple(map(back_apply, b, Δ)) (NoTangent(), sum(fs), xs) end - function back(Δ::ZeroTangent) - (fs, xs) = unzip_tuple(map(back_apply_zero, b)) - (NoTangent(), sum(fs), xs) - end - a, back + map_back(Δ::AbstractZero) = (NoTangent(), NoTangent(), NoTangent()) + a, map_back end +ChainRules.rrule(::DiffractorRuleConfig, ::typeof(map), f, args::Tuple{}) = (), _ -> (NoTangent(), NoTangent(), NoTangent()) + function ChainRules.rrule(::DiffractorRuleConfig, ::typeof(Base.ntuple), f, n) a, b = unzip_tuple(ntuple(BackMap(f), n)) - a, function (Δ) + function ntuple_back(Δ) (NoTangent(), sum(map(back_apply, b, Δ)), NoTangent()) end + ntuple_back(::AbstractZero) = (NoTangent(), NoTangent(), NoTangent()) + a, ntuple_back end function ChainRules.frule(::DiffractorRuleConfig, _, ::Type{Vector{T}}, undef::UndefInitializer, dims::Int...) where {T} @@ -267,5 +258,4 @@ function ChainRulesCore.rrule(::DiffractorRuleConfig, ::Type{InplaceableThunk}, val, Δ->(NoTangent(), NoTangent(), Δ) end -Base.real(z::ZeroTangent) = z # TODO should be in CRC -Base.real(z::NoTangent) = z +Base.real(z::NoTangent) = z # TODO should be in CRC, https://github.com/JuliaDiff/ChainRulesCore.jl/pull/581 diff --git a/src/runtime.jl b/src/runtime.jl index e3c1edec..48776239 100644 --- a/src/runtime.jl +++ b/src/runtime.jl @@ -27,3 +27,4 @@ accum(x::Tangent{T}, y::Tangent) where T = _tangent(T, accum(backing(x), backing _tangent(::Type{T}, z) where T = Tangent{T,typeof(z)}(z) _tangent(::Type, ::NamedTuple{()}) = NoTangent() +_tangent(::Type, ::NamedTuple{<:Any, <:Tuple{Vararg{AbstractZero}}}) = NoTangent() diff --git a/src/stage1/broadcast.jl b/src/stage1/broadcast.jl index 30cf98ef..e04bd60a 100644 --- a/src/stage1/broadcast.jl +++ b/src/stage1/broadcast.jl @@ -28,46 +28,3 @@ function (∂ₙ::∂☆{N})(zc::ZeroBundle{N, typeof(copy)}, end return r end - -# Broadcast over one element is just map -function (∂⃖ₙ::∂⃖{N})(::typeof(broadcasted), f, a::Array) where {N} - ∂⃖ₙ(map, f, a) -end - -# The below is from Zygote: TODO: DO we want to do something better here? - -accum_sum(xs::Nothing; dims = :) = NoTangent() -accum_sum(xs::AbstractArray{Nothing}; dims = :) = NoTangent() -accum_sum(xs::AbstractArray{<:Number}; dims = :) = sum(xs, dims = dims) -accum_sum(xs::AbstractArray{<:AbstractArray{<:Number}}; dims = :) = sum(xs, dims = dims) -accum_sum(xs::Number; dims = :) = xs - -# https://github.com/FluxML/Zygote.jl/issues/594 -function Base.reducedim_init(::typeof(identity), ::typeof(accum), A::AbstractArray, region) - Base.reducedim_initarray(A, region, NoTangent(), Union{Nothing,eltype(A)}) -end - -trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x)))) - -unbroadcast(x::AbstractArray, x̄) = - size(x) == size(x̄) ? x̄ : - length(x) == length(x̄) ? trim(x, x̄) : - trim(x, accum_sum(x̄, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(x̄)+1, Val(ndims(x̄))))) - -unbroadcast(x::Number, x̄) = accum_sum(x̄) -unbroadcast(x::Tuple{<:Any}, x̄) = (accum_sum(x̄),) -unbroadcast(x::Base.RefValue, x̄) = (x=accum_sum(x̄),) - -unbroadcast(x::AbstractArray, x̄::Nothing) = NoTangent() - -const Numeric = Union{Number, AbstractArray{<:Number, N} where N} - -function ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(broadcasted), ::typeof(+), xs::Numeric...) - broadcast(+, xs...), ȳ -> (NoTangent(), NoTangent(), map(x -> unbroadcast(x, unthunk(ȳ)), xs)...) -end - -ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(broadcasted), ::typeof(-), x::Numeric, y::Numeric) = x .- y, - Δ -> let Δ=unthunk(Δ); (NoTangent(), NoTangent(), unbroadcast(x, Δ), -unbroadcast(y, Δ)); end - -ChainRulesCore.rrule(::DiffractorRuleConfig, ::typeof(broadcasted), ::typeof(*), x::Numeric, y::Numeric) = x.*y, - z̄ -> let z̄=unthunk(z̄); (NoTangent(), NoTangent(), unbroadcast(x, z̄ .* conj.(y)), unbroadcast(y, z̄ .* conj.(x))); end \ No newline at end of file diff --git a/src/stage1/generated.jl b/src/stage1/generated.jl index 5ba3588d..0e723135 100644 --- a/src/stage1/generated.jl +++ b/src/stage1/generated.jl @@ -315,13 +315,13 @@ function (::∂⃖{N})(::typeof(Core.getfield), s, field::Symbol) where {N} end # TODO: Temporary - make better -function (::∂⃖{N})(::typeof(Base.getindex), a::Array, inds...) where {N} +function (::∂⃖{N})(::typeof(Base.getindex), a::Array{<:Number}, inds...) where {N} getindex(a, inds...), let EvenOddOdd{1, c_order(N)}( (@Base.constprop :aggressive Δ->begin Δ isa AbstractZero && return (NoTangent(), Δ, map(Returns(Δ), inds)...) BB = zero(a) - BB[inds...] = Δ + BB[inds...] = unthunk(Δ) (NoTangent(), BB, map(x->NoTangent(), inds)...) end), (@Base.constprop :aggressive (_, Δ, _)->begin @@ -334,6 +334,7 @@ struct tuple_back{M}; end (::tuple_back)(Δ::Tuple) = Core.tuple(NoTangent(), Δ...) (::tuple_back{N})(Δ::AbstractZero) where {N} = Core.tuple(NoTangent(), ntuple(i->Δ, N)...) (::tuple_back{N})(Δ::Tangent) where {N} = Core.tuple(NoTangent(), ntuple(i->lifted_getfield(Δ, i), N)...) +(t::tuple_back)(Δ::AbstractThunk) = t(unthunk(Δ)) function (::∂⃖{N})(::typeof(Core.tuple), args::Vararg{Any, M}) where {N, M} Core.tuple(args...), diff --git a/test/runtests.jl b/test/runtests.jl index b0014bf5..3ff1732e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -92,9 +92,9 @@ let var"'" = Diffractor.PrimeDerivativeBack @test @inferred(sin'(1.0)) == cos(1.0) @test @inferred(sin''(1.0)) == -sin(1.0) @test sin'''(1.0) == -cos(1.0) - @test sin''''(1.0) == sin(1.0) broken = VERSION >= v"1.8" - @test sin'''''(1.0) == cos(1.0) broken = VERSION >= v"1.8" - @test sin''''''(1.0) == -sin(1.0) broken = VERSION >= v"1.8" + @test sin''''(1.0) == sin(1.0) + @test sin'''''(1.0) == cos(1.0) # broken = VERSION >= v"1.8" + @test sin''''''(1.0) == -sin(1.0) # broken = VERSION >= v"1.8" f_getfield(x) = getfield((x,), 1) @test f_getfield'(1) == 1 @@ -219,6 +219,68 @@ z45, delta45 = frule_via_ad(DiffractorRuleConfig(), (0,1), x -> log(exp(x)), 2) @test z45 ≈ 2.0 @test delta45 ≈ 1.0 +# PR #82 - getindex on non-numeric arrays +@test gradient(ls -> ls[1](1.), [Base.Fix1(*, 1.)])[1][1] isa Tangent{<:Base.Fix1} + +@testset "broadcast" begin + @test gradient(x -> sum(x ./ x), [1,2,3]) == ([0,0,0],) # derivatives_given_output + @test gradient(x -> sum(sqrt.(atan.(x, transpose(x)))), [1,2,3])[1] ≈ [0.2338, -0.0177, -0.0661] atol=1e-3 + @test gradient(x -> sum(exp.(log.(x))), [1,2,3]) == ([1,1,1],) + + @test gradient(x -> sum((exp∘log).(x)), [1,2,3]) == ([1,1,1],) # frule_via_ad + exp_log(x) = exp(log(x)) + @test gradient(x -> sum(exp_log.(x)), [1,2,3]) == ([1,1,1],) + @test gradient((x,y) -> sum(x ./ y), [1 2; 3 4], [1,2]) == ([1 1; 0.5 0.5], [-3, -1.75]) + @test gradient((x,y) -> sum(x ./ y), [1 2; 3 4], 5) == ([0.2 0.2; 0.2 0.2], -0.4) + @test gradient(x -> sum((y -> y/x).([1,2,3])), 4) == (-0.375,) # closure + + @test gradient(x -> sum(sum, (x,) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3 # array of arrays + @test gradient(x -> sum(sum, Ref(x) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3 + @test gradient(x -> sum(sum, (x,) ./ x), [1,2,3])[1] ≈ [-4.1666, 0.3333, 1.1666] atol=1e-3 + @test gradient(x -> sum(sum, (x,) .* transpose(x)), [1,2,3])[1] ≈ [12, 12, 12] # must not take the * fast path + + @test gradient(x -> sum(x ./ 4), [1,2,3]) == ([0.25, 0.25, 0.25],) + @test gradient(x -> sum([1,2,3] ./ x), 4) == (-0.375,) # x/y rule + @test gradient(x -> sum(x.^2), [1,2,3]) == ([2.0, 4.0, 6.0],) # x.^2 rule + @test gradient(x -> sum([1,2,3] ./ x.^2), 4) == (-0.1875,) # scalar^2 rule + + @test gradient(x -> sum((1,2,3) .- x), (1,2,3)) == (Tangent{Tuple{Int,Int,Int}}(-1.0, -1.0, -1.0),) + @test gradient(x -> sum(transpose([1,2,3]) .- x), (1,2,3)) == (Tangent{Tuple{Int,Int,Int}}(-3.0, -3.0, -3.0),) + @test gradient(x -> sum([1 2 3] .+ x .^ 2), (1,2,3)) == (Tangent{Tuple{Int,Int,Int}}(6.0, 12.0, 18.0),) + + @test gradient(x -> sum(x .> 2), [1,2,3]) |> only |> iszero # Bool output + @test gradient(x -> sum(1 .+ iseven.(x)), [1,2,3]) |> only |> iszero + @test gradient((x,y) -> sum(x .== y), [1,2,3], [1 2 3]) == (NoTangent(), NoTangent()) + @test gradient(x -> sum(x .+ [1,2,3]), true) |> only |> iszero # Bool input + @test gradient(x -> sum(x ./ [1,2,3]), [true false]) |> only |> iszero + @test gradient(x -> sum(x .* transpose([1,2,3])), (true, false)) |> only |> iszero + + tup_adj = gradient((x,y) -> sum(2 .* x .+ log.(y)), (1,2), transpose([3,4,5])) + @test tup_adj[1] == Tangent{Tuple{Int64, Int64}}(6.0, 6.0) + @test tup_adj[2] ≈ [0.6666666666666666 0.5 0.4] + @test tup_adj[2] isa Transpose + @test gradient(x -> sum(atan.(x, (1,2,3))), Diagonal([4,5,6]))[1] isa Diagonal + + @test gradient(x -> sum((y -> (x*y)).([1,2,3])), 4.0) == (6.0,) # closure +end + +@testset "broadcast, 2nd order" begin + @test gradient(x -> gradient(y -> sum(y .* y), x)[1] |> sum, [1,2,3.0])[1] == [2,2,2] # calls "split broadcasting generic" with f = unthunk + @test gradient(x -> gradient(y -> sum(y .* x), x)[1].^3 |> sum, [1,2,3.0])[1] == [3,12,27] + @test_broken gradient(x -> gradient(y -> sum(y .* 2 .* y'), x)[1] |> sum, [1,2,3.0])[1] == [12, 12, 12] # Control flow support not fully implemented yet for higher-order + + @test_broken gradient(x -> sum(gradient(x -> sum(x .^ 2 .+ x'), x)[1]), [1,2,3.0])[1] == [6,6,6] # BoundsError: attempt to access 18-element Vector{Core.Compiler.BasicBlock} at index [0] + @test_broken gradient(x -> sum(gradient(x -> sum((x .+ 1) .* x .- x), x)[1]), [1,2,3.0])[1] == [2,2,2] + @test_broken gradient(x -> sum(gradient(x -> sum(x .* x ./ 2), x)[1]), [1,2,3.0])[1] == [1,1,1] + + @test_broken gradient(x -> sum(gradient(x -> sum(exp.(x)), x)[1]), [1,2,3])[1] ≈ exp.(1:3) # MethodError: no method matching copy(::Nothing) + @test_broken gradient(x -> sum(gradient(x -> sum(atan.(x, x')), x)[1]), [1,2,3.0])[1] ≈ [0,0,0] + @test_broken gradient(x -> sum(gradient(x -> sum(transpose(x) .* x), x)[1]), [1,2,3]) == ([6,6,6],) # accum(a::Transpose{Float64, Vector{Float64}}, b::ChainRulesCore.Tangent{Transpose{Int64, Vector{Int64}}, NamedTuple{(:parent,), Tuple{ChainRulesCore.NoTangent}}}) + @test_broken gradient(x -> sum(gradient(x -> sum(transpose(x) ./ x.^2), x)[1]), [1,2,3])[1] ≈ [27.675925925925927, -0.824074074074074, -2.1018518518518516] + + @test_broken gradient(z -> gradient(x -> sum((y -> (x^2*y)).([1,2,3])), z)[1], 5.0) == (12.0,) +end + # Higher order control flow not yet supported (https://github.com/JuliaDiff/Diffractor.jl/issues/24) #include("pinn.jl")