Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove Composite Bundle #216

Merged
merged 15 commits into from
Sep 26, 2023
14 changes: 1 addition & 13 deletions src/AbstractDifferentiation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,7 @@ This is more or less the Diffractor equivalent of ForwardDiff.jl's `Dual` type.
"""
function bundle end
# Zero-like tangents: wrap in a first-order UniformBundle parameterised on both types.
function bundle(x, dx::ChainRulesCore.AbstractZero)
    return UniformBundle{1, typeof(x), typeof(dx)}(x, dx)
end

# Numbers and numeric arrays: the tangent becomes the single coefficient of a
# first-order TaylorBundle.
function bundle(x::Number, dx::Number)
    return TaylorBundle{1}(x, (dx,))
end
function bundle(x::AbstractArray{<:Number}, dx::AbstractArray{<:Number})
    return TaylorBundle{1}(x, (dx,))
end

# Structural tangents are canonicalized first and then handed to `_bundle`.
function bundle(x::P, dx::Tangent{P}) where P
    canonical_dx = ChainRulesCore.canonicalize(dx)
    return _bundle(x, canonical_dx)
end

"helper that assumes tangent is in canonical form"
function _bundle(x::P, dx::Tangent{P}) where P
    # Pair every field of the primal with the matching entry of the tangent, giving one
    # bundle per field (SoA to AoS flip). This gets undone again later when the value
    # reaches ChainRules.
    field_bundles = ntuple(Val{fieldcount(P)}()) do ii
        field_primal = getfield(x, ii)
        field_tangent = getproperty(dx, ii)
        return bundle(field_primal, field_tangent)
    end
    return CompositeBundle{1, P}(field_bundles)
end

# Fallback: any other primal/tangent pair becomes a first-order TaylorBundle.
function bundle(x, dx)
    return TaylorBundle{1}(x, (dx,))
end

AD.@primitive function pushforward_function(b::DiffractorForwardBackend, f, args...)
return function pushforward(vs)
Expand Down
4 changes: 1 addition & 3 deletions src/stage1/compiler_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@ function Base.push!(cfg::CFG, bb::BasicBlock)
push!(cfg.index, bb.stmts.start)
end

# Defined only when running on Julia versions up to and including 1.11.0-DEV.116.
if VERSION <= v"1.11.0-DEV.116"
    function Base.getindex(ir::IRCode, ssa::SSAValue)
        return Core.Compiler.getindex(ir, ssa)
    end
end
oxinabox marked this conversation as resolved.
Show resolved Hide resolved
# Forward `Base` indexing of an `IRCode` by SSA value to the `Core.Compiler` method.
function Base.getindex(ir::IRCode, ssa::SSAValue)
    return Core.Compiler.getindex(ir, ssa)
end

# Forward `Base.copy` of an `IRCode` to the `Core.Compiler` method.
function Base.copy(ir::IRCode)
    return Core.Compiler.copy(ir)
end

Expand Down
84 changes: 37 additions & 47 deletions src/stage1/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,6 @@ partial(x::TaylorTangent, i) = getfield(getfield(x, :coeffs), i)
# A UniformTangent returns its stored value regardless of `i`.
function partial(x::UniformTangent, i)
    return getfield(x, :val)
end

# A ProductTangent takes the `i`-th partial of every factor.
function partial(x::ProductTangent, i)
    factors = getfield(x, :factors)
    return ProductTangent(map(factor -> partial(factor, i), factors))
end

# Zero tangents are returned unchanged.
function partial(x::AbstractZero, i)
    return x
end

# Tuple-based CompositeBundle: the element partials are splatted into a `Tangent{B}`.
function partial(x::CompositeBundle{N, B}, i) where {N, B<:Tuple}
    element_partials = map(el -> partial(el, i), getfield(x, :tup))
    return Tangent{B}(element_partials...)
end
function partial(x::CompositeBundle{N, B}, i) where {N, B}
    # Struct primal: the per-field partials are stored in a plain tuple, so the field
    # names of the primal type `B` are re-attached to form the NamedTuple backing.
    # TODO: If required this can be done as a `@generated` function so it is type-stable
    field_partials = map(field -> partial(field, i), getfield(x, :tup))
    backing = NamedTuple{fieldnames(B)}(field_partials)
    return Tangent{B, typeof(backing)}(backing)
end


# Read the `primal` property of a tangent bundle.
function primal(x::AbstractTangentBundle)
    return x.primal
end
Expand Down Expand Up @@ -42,20 +34,12 @@ function shuffle_down(b::TaylorBundle{N, B}) where {N, B}
ntuple(_sdown, N-1))
end

# Apply `shuffle_down` to every element bundle, collect the results in an order `N-1`
# CompositeBundle over `Tuple`, and wrap that as the single element of an order `N-1`
# CompositeBundle whose base space is `CompositeBundle{1, B}`.
function shuffle_down(b::CompositeBundle{N, B}) where {N, B}
    lowered_elements = map(shuffle_down, b.tup)
    inner = CompositeBundle{N-1, Tuple}(lowered_elements)
    return CompositeBundle{N-1, CompositeBundle{1, B}}((inner,))
end

function shuffle_up(r::CompositeBundle{1})
z₀ = primal(r.tup[1])
z₁ = partial(r.tup[1], 1)
z₂ = primal(r.tup[2])
z₁₂ = partial(r.tup[2], 1)
function shuffle_up(r::TaylorBundle{1, Tuple{B1,B2}}) where {B1,B2}
z₀ = primal(r)[1]
z₁ = partial(r, 1)[1]
z₂ = primal(r)[2]
z₁₂ = partial(r, 1)[2]
if z₁ == z₂
return TaylorBundle{2}(z₀, (z₁, z₁₂))
else
Expand All @@ -70,26 +54,33 @@ function taylor_compatible(a::ATB{N}, b::ATB{N}) where {N}
end
end

# Predicate: is this tangent bundle element taylor-like?
# Anything not covered by a more specific method is not.
isswifty(::Any) = false
isswifty(::TaylorBundle) = true
isswifty(::UniformBundle) = true
# A composite is taylor-like only when every one of its element bundles is.
function isswifty(b::CompositeBundle)
    return all(isswifty, b.tup)
end

function shuffle_up(r::CompositeBundle{N}) where {N}
a, b = r.tup
if isswifty(a) && isswifty(b) && taylor_compatible(a, b)
return TaylorBundle{N+1}(primal(a),
ntuple(i->i == N+1 ?
b[TaylorTangentIndex(i-1)] : a[TaylorTangentIndex(i)],
N+1))
"""
    taylor_compatible(r::TaylorBundle{N, Tuple{B1,B2}})

Return `true` when the two components of the tuple-valued bundle `r` agree with each
other shifted by one order, and `false` otherwise.

The checks are:
- the first partial of component 1 equals the primal of component 2, and
- for every `i` in `1:N-1`, the `(i+1)`-th partial of component 1 equals the `i`-th
  partial of component 2.
"""
function taylor_compatible(r::TaylorBundle{N, Tuple{B1,B2}}) where {N, B1,B2}
    # This must be a comparison (`==`); an assignment here would try to mutate the
    # partial and then use a non-Bool value in the `||`.
    partial(r, 1)[1] == primal(r)[2] || return false
    return all(1:N-1) do i
        partial(r, i+1)[1] == partial(r, i)[2]
    end
end
"""
    shuffle_up(r::TaylorBundle{N, Tuple{B1,B2}})

Combine the two components of the tuple-valued order-`N` bundle `r` into a single
bundle of order `N+1` whose primal is the first component of `primal(r)`.

When `taylor_compatible(r)` holds the result is a `TaylorBundle{N+1}`: its first `N`
partials are taken from component 1 and the last one from the `N`-th partial of
component 2. Otherwise a `TangentBundle{N+1}` is built from the partials of component 1,
the primal of component 2, and the partials of component 2, in that order.
"""
function shuffle_up(r::TaylorBundle{N, Tuple{B1,B2}}) where {N, B1,B2}
    the_primal = primal(r)[1]
    if taylor_compatible(r)
        the_partials = ntuple(N+1) do i
            if i <= N
                partial(r, i)[1]  # == `partial(r,i-1)[2]` (except first which is primal(r)[2])
            else  # i == N+1
                partial(r, i-1)[2]
            end
        end
        return TaylorBundle{N+1}(the_primal, the_partials)
    else
        # XXX: am dubious of the correctness of this.
        # Review concern: `b_partials` has length N, whereas the CompositeBundle version
        # this replaces took `1<<(N+1)-1` partials from the second component, and those
        # two lengths are never equal.
        a_partials = ntuple(i -> partial(r, i)[1], N)
        b_partials = ntuple(i -> partial(r, i)[2], N)
        primal_b = primal(r)[2]
        the_partials = (a_partials..., primal_b, b_partials...)
        return TangentBundle{N+1}(the_primal, the_partials)
    end
end
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For reference, as a partway step to determining this, I first refactored the CompositeBundle version into

# Intermediate refactoring of the CompositeBundle method, kept for reference.
# `a` and `b` are the two element bundles of `r`.
function shuffle_up(r::CompositeBundle{N}) where {N}
    a, b = r.tup
    if isswifty(a) && isswifty(b) && taylor_compatible(a, b)
        the_partials = ntuple(N+1) do i
            if i <= N
                a[TaylorTangentIndex(i)]  # == b[TaylorTangentIndex(i-1)] (except first which is b.primal)
            else  # i == N+1
                b[TaylorTangentIndex(i-1)]
            end
        end
        return TaylorBundle{N+1}(primal(a), the_partials)
    else
        the_primal = r.tup[1].primal
        a_partials = r.tup[1].tangent.partials
        # takes `1<<(N+1)-1` partials from `b`
        b_partials = ntuple(i->partial(b,i), 1<<(N+1)-1)
        the_partials = (a_partials..., primal(b), b_partials...)
        return TangentBundle{N+1}(the_primal, the_partials)
    end
end



function shuffle_up(r::UniformBundle{N, B, U}) where {N, B, U}
(a, b) = primal(r)
if r.tangent.val === b
Expand Down Expand Up @@ -185,13 +176,6 @@ end
map(y->lifted_getfield(y, s), x.tangent.coeffs))
end

# `getfield` by integer position: return the element bundle stored at that position.
@Base.constprop :aggressive function (::∂☆{N})(::ATB{N, typeof(getfield)}, x::CompositeBundle{N}, s::AbstractTangentBundle{N, Int}) where {N}
    field_index = primal(s)
    return x.tup[field_index]
end

# `getfield` by name: translate the symbol to its field index in `B`, then return the
# element bundle stored at that position.
@Base.constprop :aggressive function (::∂☆{N})(::ATB{N, typeof(getfield)}, x::CompositeBundle{N, B}, s::AbstractTangentBundle{N, Symbol}) where {N, B}
    field_index = Base.fieldindex(B, primal(s))
    return x.tup[field_index]
end
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this shouldn't be needed as we already have cases for TaylorBundle above.
Likewise several others of this kind below will not be needed.


@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::UniformBundle{N, <:Any, U}, s::AbstractTangentBundle{N}) where {N, U}
UniformBundle{N,<:Any,U}(getfield(primal(x), primal(s)), x.tangent.val)
Expand All @@ -210,8 +194,8 @@ struct FwdMap{N, T<:AbstractTangentBundle{N}}
end
# Calling a FwdMap applies `∂☆{N}()` to the wrapped `f.f` and the given bundles.
function (f::FwdMap{N})(args::AbstractTangentBundle{N}...) where {N}
    return ∂☆{N}()(f.f, args...)
end

function (::∂☆{N})(::ZeroBundle{N, typeof(map)}, f::ATB{N}, tup::CompositeBundle{N, <:Tuple}) where {N}
∂vararg{N}()(map(FwdMap(f), tup.tup)...)
# `map` over a tuple-valued TaylorBundle: split it into per-element bundles, apply the
# `FwdMap`-wrapped `f` to each, and regroup the results through `∂vararg`.
function (::∂☆{N})(::ZeroBundle{N, typeof(map)}, f::ATB{N}, tup::TaylorBundle{N, <:Tuple}) where {N}
    element_bundles = destructure(tup)
    mapped = map(FwdMap(f), element_bundles)
    return ∂vararg{N}()(mapped...)
end

function (::∂☆{N})(::ZeroBundle{N, typeof(map)}, f::ATB{N}, args::ATB{N, <:AbstractArray}...) where {N}
Expand Down Expand Up @@ -254,23 +238,28 @@ function (this::∂☆{N})(::ZeroBundle{N, typeof(Core._apply_iterate)}, iterate
Core._apply_iterate(FwdIterate(iterate), this, (f,), args...)
end

#==
#TODO: port this to TaylorTangent over composite structures
function (this::∂☆{N})(::ZeroBundle{N, typeof(iterate)}, t::CompositeBundle{N, <:Tuple}) where {N}
r = iterate(t.tup)
r === nothing && return ZeroBundle{N}(nothing)
∂vararg{N}()(r[1], ZeroBundle{N}(r[2]))
end

#TODO: port this to TaylorTangent over composite structures
function (this::∂☆{N})(::ZeroBundle{N, typeof(iterate)}, t::CompositeBundle{N, <:Tuple}, a::ATB{N}, args::ATB{N}...) where {N}
r = iterate(t.tup, primal(a), map(primal, args)...)
r === nothing && return ZeroBundle{N}(nothing)
∂vararg{N}()(r[1], ZeroBundle{N}(r[2]))
end

#TODO: port this to TaylorTangent over composite structures
function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::CompositeBundle{N, <:Tuple}, i::ATB{N}) where {N}
r = Base.indexed_iterate(t.tup, primal(i))
∂vararg{N}()(r[1], ZeroBundle{N}(r[2]))
end

#TODO: port this to TaylorTangent over composite structures
function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::CompositeBundle{N, <:Tuple}, i::ATB{N}, st1::ATB{N}, st::ATB{N}...) where {N}
r = Base.indexed_iterate(t.tup, primal(i), primal(st1), map(primal, st)...)
∂vararg{N}()(r[1], ZeroBundle{N}(r[2]))
Expand All @@ -280,10 +269,11 @@ function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::Tan
∂vararg{N}()(this(ZeroBundle{N}(getfield), t, i), ZeroBundle{N}(primal(i) + 1))
end


#TODO: port this to TaylorTangent over composite structures
function (this::∂☆{N})(::ZeroBundle{N, typeof(getindex)}, t::CompositeBundle{N, <:Tuple}, i::ZeroBundle) where {N}
t.tup[primal(i)]
end
==#

function (this::∂☆{N})(::ZeroBundle{N, typeof(typeof)}, x::ATB{N}) where {N}
DNEBundle{N}(typeof(primal(x)))
Expand Down
44 changes: 41 additions & 3 deletions src/stage1/recurse_fwd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,51 @@ struct ∂vararg{N}; end

# No arguments: the empty tuple wrapped in a ZeroBundle.
(::∂vararg{N})() where {N} = ZeroBundle{N}(())

# One or more bundles: form the tuple type `B` of their base spaces and delegate the
# construction of the combined bundle to `∂☆new`.
# (A leftover `CompositeBundle{...}(a)` expression whose result was discarded has been
# dropped; only the `∂☆new` path produces the return value.)
function (::∂vararg{N})(a::AbstractTangentBundle{N}...) where N
    B = Tuple{map(x->basespace(Core.Typeof(x)), a)...}
    return (∂☆new{N}())(B, a...)
end

struct ∂☆new{N}; end

# Wrap the element bundles in a CompositeBundle over the base space `B`.
function (::∂☆new{N})(B::Type, a::AbstractTangentBundle{N}...) where {N}
    return CompositeBundle{N, B}(a)
end
# First-order case, split out as a special case for performance;
# the Nth-order method also works for this.
function (::∂☆new{1})(B::Type, xs::AbstractTangentBundle{1}...)
    # Build the primal object from the primal of every argument bundle.
    the_primal = _construct(B, map(primal, xs))

    # Tuple primals get a Tuple-backed Tangent; anything else gets a NamedTuple backing
    # keyed by the field names of `B`.
    field_tangents = map(first_partial, xs)
    backing = B <: Tuple ? field_tangents : NamedTuple{fieldnames(B)}(field_tangents)
    the_partial = Tangent{B, typeof(backing)}(backing)
    return TaylorBundle{1, B}(the_primal, (the_partial,))
end

# General Nth-order case: one structural Tangent per derivative order.
function (::∂☆new{N})(B::Type, xs::AbstractTangentBundle{N}...) where {N}
    the_primal = _construct(B, map(primal, xs))

    the_partials = ntuple(Val{N}()) do order
        # the type of the higher order tangents isn't worth tracking, so only the
        # first-order Tangent is tagged with `B`
        tangent_primal_type = order == 1 ? B : Any
        field_partials = map(x -> partial(x, order), xs)
        # Tuple backing for tuple primals, NamedTuple keyed by field name otherwise.
        backing = B <: Tuple ? field_partials : NamedTuple{fieldnames(B)}(field_partials)
        return Tangent{tangent_primal_type, typeof(backing)}(backing)
    end
    return TaylorBundle{N, B}(the_primal, the_partials)
end

# Build a value of type `B` from the collection of field values `args`.
# For tuple types the type itself is called on `args`.
_construct(::Type{B}, args) where B<:Tuple = B(args)
# Hack for making things that do not have public constructors constructable:
# the generated body is a `:splatnew` expression on `B` and `args`
# (presumably this instantiates `B` directly from `args`, bypassing constructors — TODO confirm).
@generated _construct(B::Type, args) = :($(Expr(:splatnew, :B, :args)))
oxinabox marked this conversation as resolved.
Show resolved Hide resolved

# Zero-argument form: the generated body wraps an `Expr(:new, :B)` value in `ZeroBundle{N}`.
@generated (::∂☆new{N})(B::Type) where {N} = return :(ZeroBundle{$N}($(Expr(:new, :B))))

Expand Down
43 changes: 13 additions & 30 deletions src/tangent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -208,9 +208,7 @@ end

"""
    check_taylor_invariants(coeffs, primal, N)

Assert the internal invariant that `coeffs` holds exactly `N` Taylor coefficients.
Returns `nothing`; an `AssertionError` is raised when the lengths disagree.
`primal` is accepted but not inspected.
"""
function check_taylor_invariants(coeffs, primal, N)
    @assert length(coeffs) == N
    # No constraint is placed on `coeffs[1]` when `primal` is itself a TangentBundle.
    # An earlier check required `coeffs[1] isa TangentBundle` in that case, which is
    # wrong: the first derivative of a TangentBundle is a `Tangent{TangentBundle}` and
    # has no primal component. It was never hit before because derivatives of
    # TaylorBundles were not being represented as TaylorBundles.
    return nothing
end
@ChainRulesCore.non_differentiable check_taylor_invariants(coeffs, primal, N)

Expand All @@ -230,6 +228,18 @@ function Base.getindex(tb::TaylorBundle, tti::CanonicalTangentIndex)
tb.tangent.coeffs[count_ones(tti.i)]
end

"For a `TaylorBundle{N, <:Tuple}`, this breaks it up into one `TaylorBundle{N}` for each element of the primal tuple."
function destructure(r::TaylorBundle{N, B}) where {N, B<:Tuple}
    # One output bundle per element of the primal tuple.
    return ntuple(fieldcount(B)) do field_ii
        element_primal = primal(r)[field_ii]
        # Element `field_ii` of the partial at each order 1..N.
        element_partials = ntuple(order_ii -> partial(r, order_ii)[field_ii], N)
        return TaylorBundle{N}(element_primal, element_partials)
    end
end
Comment on lines +231 to +240
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this, or can we just use unbundle?



# Keep only the first `N` coefficients of `tt`; `N` is taken from the `Val` argument.
function truncate(tt::TaylorTangent, order::Val{N}) where {N}
    leading_coeffs = tt.coeffs[1:N]
    return TaylorTangent(leading_coeffs)
end
Expand Down Expand Up @@ -290,33 +300,6 @@ end

# Indexing a UniformBundle by any TaylorTangentIndex returns the stored tangent value.
function Base.getindex(u::UniformBundle, ::TaylorTangentIndex)
    return u.tangent.val
end

"""
CompositeBundle{N, B, T <: Tuple}

Represents the tangent bundle where the base space is some tuple or struct type.
Mathematically, this tangent bundle is the product bundle of the individual
element bundles.
"""
struct CompositeBundle{N, B, T<:Tuple{Vararg{AbstractTangentBundle{N}}}} <: AbstractTangentBundle{N, B}
# tuple of order-`N` element bundles
tup::T
end
# Outer constructor: the tuple type parameter `T` is taken from the argument.
CompositeBundle{N, B}(tup::T) where {N, B, T} = CompositeBundle{N, B, T}(tup)

# Taylor-index lookup on a CompositeBundle: the `tti.i`-th partial of the bundle.
function Base.getindex(tb::CompositeBundle{N, B} where N, tti::TaylorTangentIndex) where {B}
    # SArray base spaces are rejected outright.
    if B <: SArray
        error()
    end
    return partial(tb, tti.i)
end

# Tuple base space: the primal is the tuple of the element primals.
function primal(b::CompositeBundle{N, <:Tuple} where N)
    return map(primal, b.tup)
end

# Base space that is itself a CompositeBundle type `T`: call `T` on the element primals.
function primal(b::CompositeBundle{N, T} where N) where T<:CompositeBundle
    element_primals = map(primal, b.tup)
    return T(element_primals...)
end
# Any other base space `B`: collect the primals of the element bundles and build a `B`
# from them with a `:splatnew` expression
# (presumably this bypasses `B`'s constructors — TODO confirm).
@generated primal(b::CompositeBundle{N, B} where N) where {B} =
quote
x = map(primal, b.tup)
$(Expr(:splatnew, B, :x))
end

# A zero tangent is replicated into an array with dimensions `asize`.
function expand_singleton_to_array(asize, a::AbstractZero)
    return fill(a, asize...)
end

# An array is passed through untouched.
function expand_singleton_to_array(asize, a::AbstractArray)
    return a
end

Expand Down
8 changes: 4 additions & 4 deletions test/AbstractDifferentiationTests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ backend = Diffractor.DiffractorForwardBackend()

@test bundle(1.0, 2.0) isa Diffractor.TaylorBundle{1}
@test bundle([1.0, 2.0], [2.0, 3.0]) isa Diffractor.TaylorBundle{1}
@test bundle(1.5=>2.5, Tangent{Pair{Float64, Float64}}(first=1.0, second=2.0)) isa Diffractor.CompositeBundle{1}
@test bundle(1.5=>2.5, Tangent{Pair{Float64, Float64}}(first=1.0, second=2.0)) isa Diffractor.TaylorBundle{1}
@test bundle(1.1, ChainRulesCore.ZeroTangent()) isa Diffractor.ZeroBundle{1}
@test bundle(1.5=>2.5=>3.5, Tangent{Pair{Float64, Pair{Float64, Float64}}}(first=1.0, second=Tangent{Pair{Float64, Float64}}(first=1.0, second=2.0))) isa Diffractor.CompositeBundle{1}
@test bundle(1.5=>2.5=>3.5, Tangent{Pair{Float64, Pair{Float64, Float64}}}(first=1.0, second=Tangent{Pair{Float64, Float64}}(first=1.0, second=2.0))) isa Diffractor.TaylorBundle{1}

# noncanonical structural tangent
b = bundle(1.5=>2.5=>3.5, Tangent{Pair{Float64, Pair{Float64, Float64}}}(second=Tangent{Pair{Float64, Float64}}(second=2.0, first=1.0)))
t = Diffractor.first_partial(b)
@test b isa Diffractor.CompositeBundle{1}
@test b isa Diffractor.TaylorBundle{1}
@test iszero(t.first)
@test t.second.first == 1.0
@test t.second.second == 2.0
Expand All @@ -29,7 +29,7 @@ end

# standard tests from AbstractDifferentiation.test_utils
include(joinpath(pathof(AbstractDifferentiation), "..", "..", "test", "test_utils.jl"))
@testset "ForwardDiffBackend" begin
@testset "Standard AbstractDifferentiation.test_utils tests" begin
backends = [
@inferred(Diffractor.DiffractorForwardBackend())
]
Expand Down
45 changes: 43 additions & 2 deletions test/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ let var"'" = Diffractor.PrimeDerivativeFwd
# Integration tests
@test recursive_sin'(1.0) == cos(1.0)
@test recursive_sin''(1.0) == -sin(1.0)
# Error: ArgumentError: Tangent for the primal Tangent{Tuple{Float64, Float64}, Tuple{Float64, Float64}}
# should be backed by a NamedTuple type, not by Tuple{Tangent{Tuple{Float64, Float64}, Tuple{Float64, Float64}}}.

@test_broken recursive_sin'''(1.0) == -cos(1.0)
@test_broken recursive_sin''''(1.0) == sin(1.0)
@test_broken recursive_sin'''''(1.0) == cos(1.0)
Expand Down Expand Up @@ -90,4 +89,46 @@ end
end
end


@testset "structs" begin
struct IDemo
x::Float64
y::Float64
end

function foo(a)
obj = IDemo(2.0, a)
return obj.x * obj.y
end

let var"'" = Diffractor.PrimeDerivativeFwd
@test foo'(100.0) == 2.0
@test foo''(100.0) == 0.0
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this second derivative is actually broken on main

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we know what change broke it?

end
end

@testset "tuples" begin
function foo(a)
tup = (2.0, a)
return first(tup) * tup[2]
end

let var"'" = Diffractor.PrimeDerivativeFwd
@test foo'(100.0) == 2.0
@test foo''(100.0) == 0.0
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this second derivative is actually broken on main

end
end

@testset "vararg" begin
function foo(a)
tup = (2.0, a)
return *(tup...)
end

let var"'" = Diffractor.PrimeDerivativeFwd
@test foo'(100.0) == 2.0
@test foo''(100.0) == 0.0
end
end

end
Loading