Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove Composite Bundle #216

Merged
merged 15 commits into from
Sep 26, 2023
14 changes: 1 addition & 13 deletions src/AbstractDifferentiation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,7 @@ This is more or less the Diffractor equivalent of ForwardDiff.jl's `Dual` type.
"""
function bundle end
bundle(x, dx::ChainRulesCore.AbstractZero) = UniformBundle{1, typeof(x), typeof(dx)}(x, dx)
bundle(x::Number, dx::Number) = TaylorBundle{1}(x, (dx,))
bundle(x::AbstractArray{<:Number}, dx::AbstractArray{<:Number}) = TaylorBundle{1}(x, (dx,))
bundle(x::P, dx::Tangent{P}) where P = _bundle(x, ChainRulesCore.canonicalize(dx))

"helper that assumes tangent is in canonical form"
function _bundle(x::P, dx::Tangent{P}) where P
    # SoA to AoS flip: pair every field of the primal with the matching field of the
    # tangent (hate this, hate it even more cos we just undo it later when we hit chainrules)
    field_bundles = ntuple(Val{fieldcount(P)}()) do idx
        field_primal = getfield(x, idx)
        field_tangent = getproperty(dx, idx)
        bundle(field_primal, field_tangent)
    end
    return CompositeBundle{1, P}(field_bundles)
end

bundle(x, dx) = TaylorBundle{1}(x, (dx,))

AD.@primitive function pushforward_function(b::DiffractorForwardBackend, f, args...)
return function pushforward(vs)
Expand Down
47 changes: 18 additions & 29 deletions src/stage1/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,6 @@ partial(x::TaylorTangent, i) = getfield(getfield(x, :coeffs), i)
partial(x::UniformTangent, i) = getfield(x, :val)
partial(x::ProductTangent, i) = ProductTangent(map(x->partial(x, i), getfield(x, :factors)))
partial(x::AbstractZero, i) = x
# Tuple base space: the `i`th partial of every element, splatted into a `Tangent{B}`.
function partial(x::CompositeBundle{N, B}, i) where {N, B<:Tuple}
    element_partials = map(el -> partial(el, i), getfield(x, :tup))
    return Tangent{B}(element_partials...)
end

# Struct base space: the per-field partials are stored in a plain tuple, so the field
# names are added back from the primal type `B` to build the `Tangent` backing.
# TODO: If required this can be done as a `@generated` function so it is type-stable
function partial(x::CompositeBundle{N, B}, i) where {N, B}
    field_partials = map(el -> partial(el, i), getfield(x, :tup))
    backing = NamedTuple{fieldnames(B)}(field_partials)
    return Tangent{B, typeof(backing)}(backing)
end


primal(x::AbstractTangentBundle) = x.primal
Expand Down Expand Up @@ -42,27 +34,20 @@ function shuffle_down(b::TaylorBundle{N, B}) where {N, B}
ntuple(_sdown, N-1))
end

# Applies `shuffle_down` to every element bundle, wraps the results in an order `N-1`
# tuple bundle, and wraps that once more in a one-element order `N-1` bundle whose
# base space is `CompositeBundle{1, B}`.
function shuffle_down(b::CompositeBundle{N, B}) where {N, B}
    lowered = map(shuffle_down, b.tup)
    inner = CompositeBundle{N-1, Tuple}(lowered)
    return CompositeBundle{N-1, CompositeBundle{1, B}}((inner,))
end

function shuffle_up(r::CompositeBundle{1})
z₀ = primal(r.tup[1])
z₁ = partial(r.tup[1], 1)
z₂ = primal(r.tup[2])
z₁₂ = partial(r.tup[2], 1)
# Combines a first-order bundle over a 2-tuple into a second-order bundle.
# When the two middle terms agree a `TaylorBundle{2}` is returned; otherwise all
# three tangent terms are kept in an `ExplicitTangentBundle{2}`.
function shuffle_up(r::TaylorBundle{1, Tuple{B1,B2}}) where {B1,B2}
    p = primal(r)
    dp = partial(r, 1)
    z₀, z₂ = p[1], p[2]
    z₁, z₁₂ = dp[1], dp[2]
    z₁ == z₂ && return TaylorBundle{2}(z₀, (z₁, z₁₂))
    return ExplicitTangentBundle{2}(z₀, (z₁, z₂, z₁₂))
end

#==
function taylor_compatible(a::ATB{N}, b::ATB{N}) where {N}
primal(b) === a[TaylorTangentIndex(1)] || return false
return all(1:(N-1)) do i
Expand All @@ -76,6 +61,7 @@ isswifty(::UniformBundle) = true
isswifty(b::CompositeBundle) = all(isswifty, b.tup)
isswifty(::Any) = false

#TODO: port this to TaylorTangent over composite structures
function shuffle_up(r::CompositeBundle{N}) where {N}
a, b = r.tup
if isswifty(a) && isswifty(b) && taylor_compatible(a, b)
Expand All @@ -89,6 +75,7 @@ function shuffle_up(r::CompositeBundle{N}) where {N}
ntuple(i->partial(b,i), 1<<(N+1)-1)...))
end
end
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For reference, as a partway step to determining this, I first refactored the CompositeBundle version into:

# Reference refactoring of `shuffle_up` for a two-element `CompositeBundle{N}`.
# `a` and `b` are the two element bundles of `r`.
#  * Taylor-compatible case: build an order `N+1` `TaylorBundle` whose first `N`
#    coefficients come from `a` and whose last coefficient comes from `b`.
#  * Otherwise: build an order `N+1` `TangentBundle` from the partials of `a`,
#    the primal of `b`, and the partials of `b`.
function shuffle_up(r::CompositeBundle{N}) where {N}
    a, b = r.tup
    if isswifty(a) && isswifty(b) && taylor_compatible(a, b)
        the_partials = ntuple(N+1) do i
            if i <= N
                a[TaylorTangentIndex(i)]  # == b[TaylorTangentIndex(i-1)] (except first which is b.primal)
            else  # i == N+1
                b[TaylorTangentIndex(i-1)]
            end
        end
        return TaylorBundle{N+1}(primal(a), the_partials)
    else
        the_primal = r.tup[1].primal
        a_partials = r.tup[1].tangent.partials
        b_partials = ntuple(i->partial(b,i), 1<<(N+1)-1)
        # the primal of `b` sits between the partials of `a` and the partials of `b`
        the_partials = (a_partials..., primal(b), b_partials...)
        return TangentBundle{N+1}(the_primal, the_partials)
    end
end

==#
oxinabox marked this conversation as resolved.
Show resolved Hide resolved

function shuffle_up(r::UniformBundle{N, B, U}) where {N, B, U}
(a, b) = primal(r)
Expand Down Expand Up @@ -185,13 +172,6 @@ end
map(y->lifted_getfield(y, s), x.tangent.coeffs))
end

# `getfield` with an integer index on a composite bundle: the element bundle stored at
# that position is the answer.
@Base.constprop :aggressive function (::∂☆{N})(::ATB{N, typeof(getfield)}, x::CompositeBundle{N}, s::AbstractTangentBundle{N, Int}) where {N}
    field_index = primal(s)
    return x.tup[field_index]
end

# `getfield` with a field name on a composite bundle: translate the name into its
# position within `B`, then return the element bundle stored there.
@Base.constprop :aggressive function (::∂☆{N})(::ATB{N, typeof(getfield)}, x::CompositeBundle{N, B}, s::AbstractTangentBundle{N, Symbol}) where {N, B}
    field_index = Base.fieldindex(B, primal(s))
    return x.tup[field_index]
end
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this shouldn't be needed as we already have cases for TaylorBundle above.
Likewise several others of this kind below will not be needed.


@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::UniformBundle{N, <:Any, U}, s::AbstractTangentBundle{N}) where {N, U}
UniformBundle{N,<:Any,U}(getfield(primal(x), primal(s)), x.tangent.val)
Expand All @@ -210,9 +190,12 @@ struct FwdMap{N, T<:AbstractTangentBundle{N}}
end
(f::FwdMap{N})(args::AbstractTangentBundle{N}...) where {N} = ∂☆{N}()(f.f, args...)

#==
# TODO port this to TaylorBundle over composite structure
function (::∂☆{N})(::ZeroBundle{N, typeof(map)}, f::ATB{N}, tup::CompositeBundle{N, <:Tuple}) where {N}
∂vararg{N}()(map(FwdMap(f), tup.tup)...)
end
==#

function (::∂☆{N})(::ZeroBundle{N, typeof(map)}, f::ATB{N}, args::ATB{N, <:AbstractArray}...) where {N}
# TODO: This could do an inplace map! to avoid the extra rebundling
Expand Down Expand Up @@ -254,23 +237,28 @@ function (this::∂☆{N})(::ZeroBundle{N, typeof(Core._apply_iterate)}, iterate
Core._apply_iterate(FwdIterate(iterate), this, (f,), args...)
end

#==
#TODO: port this to TaylorTangent over composite structures
function (this::∂☆{N})(::ZeroBundle{N, typeof(iterate)}, t::CompositeBundle{N, <:Tuple}) where {N}
r = iterate(t.tup)
r === nothing && return ZeroBundle{N}(nothing)
∂vararg{N}()(r[1], ZeroBundle{N}(r[2]))
end

#TODO: port this to TaylorTangent over composite structures
function (this::∂☆{N})(::ZeroBundle{N, typeof(iterate)}, t::CompositeBundle{N, <:Tuple}, a::ATB{N}, args::ATB{N}...) where {N}
r = iterate(t.tup, primal(a), map(primal, args)...)
r === nothing && return ZeroBundle{N}(nothing)
∂vararg{N}()(r[1], ZeroBundle{N}(r[2]))
end

#TODO: port this to TaylorTangent over composite structures
function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::CompositeBundle{N, <:Tuple}, i::ATB{N}) where {N}
r = Base.indexed_iterate(t.tup, primal(i))
∂vararg{N}()(r[1], ZeroBundle{N}(r[2]))
end

#TODO: port this to TaylorTangent over composite structures
function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::CompositeBundle{N, <:Tuple}, i::ATB{N}, st1::ATB{N}, st::ATB{N}...) where {N}
r = Base.indexed_iterate(t.tup, primal(i), primal(st1), map(primal, st)...)
∂vararg{N}()(r[1], ZeroBundle{N}(r[2]))
Expand All @@ -280,10 +268,11 @@ function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::Tan
∂vararg{N}()(this(ZeroBundle{N}(getfield), t, i), ZeroBundle{N}(primal(i) + 1))
end


#TODO: port this to TaylorTangent over composite structures
function (this::∂☆{N})(::ZeroBundle{N, typeof(getindex)}, t::CompositeBundle{N, <:Tuple}, i::ZeroBundle) where {N}
t.tup[primal(i)]
end
==#

function (this::∂☆{N})(::ZeroBundle{N, typeof(typeof)}, x::ATB{N}) where {N}
DNEBundle{N}(typeof(primal(x)))
Expand Down
44 changes: 41 additions & 3 deletions src/stage1/recurse_fwd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,51 @@ struct ∂vararg{N}; end

# With no arguments there is nothing to bundle: return a zero bundle over the empty tuple.
function (::∂vararg{N})() where {N}
    return ZeroBundle{N}(())
end
function (::∂vararg{N})(a::AbstractTangentBundle{N}...) where N
CompositeBundle{N, Tuple{map(x->basespace(typeof(x)), a)...}}(a)
B = Tuple{map(x->basespace(Core.Typeof(x)), a)...}
return (∂☆new{N}())(B, a...)
end

struct ∂☆new{N}; end

(::∂☆new{N})(B::Type, a::AbstractTangentBundle{N}...) where {N} =
CompositeBundle{N, B}(a)
# we split out the 1st order derivative as a special case for performance
# but the nth order case does also work for this
function (::∂☆new{1})(B::Type, xs::AbstractTangentBundle{1}...)
oxinabox marked this conversation as resolved.
Show resolved Hide resolved
primal_args = map(primal, xs)
the_primal = _construct(B, primal_args)

tangent_tup = map(first_partial, xs)
the_partial = if B<:Tuple
Tangent{B, typeof(tangent_tup)}(tangent_tup)
else
names = fieldnames(B)
tangent_nt = NamedTuple{names}(tangent_tup)
Tangent{B, typeof(tangent_nt)}(tangent_nt)
end
return TaylorBundle{1, B}(the_primal, (the_partial,))
end

# General order-`N` construction of a `B` from the bundles of its fields.
# The primal is built from the field primals; the `ii`th partial is a `Tangent`
# collecting the `ii`th partial of every field bundle (backed by a tuple when `B` is a
# tuple type, and by a `NamedTuple` keyed by the field names of `B` otherwise).
function (::∂☆new{N})(B::Type, xs::AbstractTangentBundle{N}...) where {N}
    the_primal = _construct(B, map(primal, xs))

    the_partials = ntuple(Val{N}()) do order
        # the type of the higher order tangents isn't worth tracking
        tangent_primal_type = order == 1 ? B : Any
        components = map(x -> partial(x, order), xs)
        backing = B <: Tuple ? components : NamedTuple{fieldnames(B)}(components)
        Tangent{tangent_primal_type, typeof(backing)}(backing)
    end
    return TaylorBundle{N, B}(the_primal, the_partials)
end

# Tuple types are built by calling the tuple type on the tuple of arguments.
_construct(::Type{B}, args) where B<:Tuple = B(args)
# Hack for making things that do not have public constructors constructable:
# the generated body is a `:splatnew` expression over `B` and `args`, i.e. the object
# is created from the tuple of field values instead of through a constructor method.
@generated _construct(B::Type, args) = :($(Expr(:splatnew, :B, :args)))
oxinabox marked this conversation as resolved.
Show resolved Hide resolved

# No field bundles supplied: create `B` with a bare `:new` expression (no field values)
# and wrap the result in a `ZeroBundle{N}`.
@generated (::∂☆new{N})(B::Type) where {N} = return :(ZeroBundle{$N}($(Expr(:new, :B))))

Expand Down
31 changes: 1 addition & 30 deletions src/tangent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -208,9 +208,7 @@ end

function check_taylor_invariants(coeffs, primal, N)
@assert length(coeffs) == N
if isa(primal, TangentBundle)
@assert isa(coeffs[1], TangentBundle)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check is just wrong AFAICT: if the primal is a TangentBundle, then the first derivative is a Tangent{TangentBundle}.
The first derivative has no primal component.

But this was never hit before because we were not making TaylorBundles when taking derivatives of TaylorBundles (at least not in anything that happens in our test/forward.jl, which I checked); we were (I assume) making CompositeBundles.

end

end
@ChainRulesCore.non_differentiable check_taylor_invariants(coeffs, primal, N)

Expand Down Expand Up @@ -290,33 +288,6 @@ end

Base.getindex(u::UniformBundle, ::TaylorTangentIndex) = u.tangent.val

"""
CompositeBundle{N, B, T<:Tuple}

Represents the tangent bundle where the base space is some tuple or struct type.
Mathematically, this tangent bundle is the product bundle of the individual
element bundles.
"""
struct CompositeBundle{N, B, T<:Tuple{Vararg{AbstractTangentBundle{N}}}} <: AbstractTangentBundle{N, B}
# Tuple of order-`N` bundles, one per element bundle of the product.
tup::T
end
# Convenience constructor: takes the tuple type `T` from the argument so callers only
# spell out the order `N` and the base space `B`.
CompositeBundle{N, B}(tup::T) where {N, B, T} = CompositeBundle{N, B, T}(tup)

# Indexing a composite bundle with a Taylor index returns the partial of that order.
# Base spaces that are `SArray`s are rejected with a bare `error()`.
function Base.getindex(tb::CompositeBundle{N, B} where N, tti::TaylorTangentIndex) where {B}
    if B <: SArray
        error()
    end
    order = tti.i
    return partial(tb, order)
end

# Tuple base space: the primal is the tuple of the element primals.
function primal(b::CompositeBundle{N, <:Tuple} where N)
    return map(primal, b.tup)
end

# Base space that is itself a `CompositeBundle` type: rebuild it by calling `T` on the
# splatted element primals.
function primal(b::CompositeBundle{N, T} where N) where T<:CompositeBundle
    element_primals = map(primal, b.tup)
    return T(element_primals...)
end
# Generic base space `B`: take the primal of every element bundle and rebuild a `B`
# from that tuple with a `:splatnew` expression, i.e. from the field values directly
# rather than through a constructor method of `B`.
@generated primal(b::CompositeBundle{N, B} where N) where {B} =
quote
x = map(primal, b.tup)
$(Expr(:splatnew, B, :x))
end

# A single zero tangent is expanded into an array of size `asize` filled with it.
function expand_singleton_to_array(asize, a::AbstractZero)
    return fill(a, asize...)
end

# Arrays are returned unchanged; `asize` is not consulted.
function expand_singleton_to_array(asize, a::AbstractArray)
    return a
end

Expand Down
8 changes: 4 additions & 4 deletions test/AbstractDifferentiationTests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ backend = Diffractor.DiffractorForwardBackend()

@test bundle(1.0, 2.0) isa Diffractor.TaylorBundle{1}
@test bundle([1.0, 2.0], [2.0, 3.0]) isa Diffractor.TaylorBundle{1}
@test bundle(1.5=>2.5, Tangent{Pair{Float64, Float64}}(first=1.0, second=2.0)) isa Diffractor.CompositeBundle{1}
@test bundle(1.5=>2.5, Tangent{Pair{Float64, Float64}}(first=1.0, second=2.0)) isa Diffractor.TaylorBundle{1}
@test bundle(1.1, ChainRulesCore.ZeroTangent()) isa Diffractor.ZeroBundle{1}
@test bundle(1.5=>2.5=>3.5, Tangent{Pair{Float64, Pair{Float64, Float64}}}(first=1.0, second=Tangent{Pair{Float64, Float64}}(first=1.0, second=2.0))) isa Diffractor.CompositeBundle{1}
@test bundle(1.5=>2.5=>3.5, Tangent{Pair{Float64, Pair{Float64, Float64}}}(first=1.0, second=Tangent{Pair{Float64, Float64}}(first=1.0, second=2.0))) isa Diffractor.TaylorBundle{1}

# noncanonical structural tangent
b = bundle(1.5=>2.5=>3.5, Tangent{Pair{Float64, Pair{Float64, Float64}}}(second=Tangent{Pair{Float64, Float64}}(second=2.0, first=1.0)))
t = Diffractor.first_partial(b)
@test b isa Diffractor.CompositeBundle{1}
@test b isa Diffractor.TaylorBundle{1}
@test iszero(t.first)
@test t.second.first == 1.0
@test t.second.second == 2.0
Expand All @@ -29,7 +29,7 @@ end

# standard tests from AbstractDifferentiation.test_utils
include(joinpath(pathof(AbstractDifferentiation), "..", "..", "test", "test_utils.jl"))
@testset "ForwardDiffBackend" begin
@testset "Standard AbstractDifferentiation.test_utils tests" begin
backends = [
@inferred(Diffractor.DiffractorForwardBackend())
]
Expand Down
45 changes: 43 additions & 2 deletions test/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ let var"'" = Diffractor.PrimeDerivativeFwd
# Integration tests
@test recursive_sin'(1.0) == cos(1.0)
@test recursive_sin''(1.0) == -sin(1.0)
# Error: ArgumentError: Tangent for the primal Tangent{Tuple{Float64, Float64}, Tuple{Float64, Float64}}
# should be backed by a NamedTuple type, not by Tuple{Tangent{Tuple{Float64, Float64}, Tuple{Float64, Float64}}}.

@test_broken recursive_sin'''(1.0) == -cos(1.0)
@test_broken recursive_sin''''(1.0) == sin(1.0)
@test_broken recursive_sin'''''(1.0) == cos(1.0)
Expand Down Expand Up @@ -90,4 +89,46 @@ end
end
end


@testset "structs" begin
# Two-field immutable struct used as the fixture for the "structs" testset.
struct IDemo
x::Float64
y::Float64
end

# Builds an `IDemo` from the constant 2.0 and the argument, then multiplies its two
# fields, so the value is `2.0 * a`. Construction and field access are kept explicit
# because they are what the derivative tests below differentiate through.
function foo(a)
obj = IDemo(2.0, a)
return obj.x * obj.y
end

let var"'" = Diffractor.PrimeDerivativeFwd
@test foo'(100.0) == 2.0
@test foo''(100.0) == 0.0
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this second derivative is actually broken on main

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we know what change broke it?

end
end

@testset "tuples" begin
# Packs the constant 2.0 and the argument into a tuple and multiplies the two
# elements, reading them with `first` and with integer indexing; the value is `2.0 * a`.
function foo(a)
tup = (2.0, a)
return first(tup) * tup[2]
end

let var"'" = Diffractor.PrimeDerivativeFwd
@test foo'(100.0) == 2.0
@test foo''(100.0) == 0.0
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this second derivative is actually broken on main

end
end

@testset "vararg" begin
# Packs the constant 2.0 and the argument into a tuple and multiplies them by
# splatting the tuple into `*`; the value is `2.0 * a`.
function foo(a)
tup = (2.0, a)
return *(tup...)
end

let var"'" = Diffractor.PrimeDerivativeFwd
@test foo'(100.0) == 2.0
@test foo''(100.0) == 0.0
end
end

end
15 changes: 8 additions & 7 deletions test/tangent.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module tagent
using Diffractor
using Diffractor: AbstractZeroBundle, ZeroBundle, DNEBundle
using Diffractor: TaylorBundle, TaylorTangentIndex, CompositeBundle
using Diffractor: TaylorBundle, TaylorTangentIndex
using Diffractor: ExplicitTangent, TaylorTangent, truncate
using ChainRulesCore
using Test
Expand All @@ -28,21 +28,22 @@ using Test
end

@testset "AD through constructor" begin
#https://github.com/JuliaDiff/Diffractor.jl/issues/152
# hits `getindex(::CompositeBundle{Foo152}, ::TaylorTangentIndex)`
# https://github.com/JuliaDiff/Diffractor.jl/issues/152
# Though we have now removed the underlying cause, we keep this as a regression test just in case
# Minimal single-field struct used by the "AD through constructor" regression test.
struct Foo152
x::Float64
end

# Unit Test
cb = CompositeBundle{1, Foo152}((TaylorBundle{1, Float64}(23.5, (1.0,)),))
cb = TaylorBundle{1, Foo152}(Foo152(23.5), (Tangent{Foo152}(;x=1.0),))
oxinabox marked this conversation as resolved.
Show resolved Hide resolved
tti = TaylorTangentIndex(1,)
@test cb[tti] == Tangent{Foo152}(; x=1.0)

# Integration Test
var"'" = Diffractor.PrimeDerivativeFwd
f(x) = Foo152(x)
@test f'(23.5) == Tangent{Foo152}(; x=1.0)
let var"'" = Diffractor.PrimeDerivativeFwd
f(x) = Foo152(x)
@test f'(23.5) == Tangent{Foo152}(; x=1.0)
end
end

@testset "truncate" begin
Expand Down