-
Notifications
You must be signed in to change notification settings - Fork 31
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Remove Composite Bundle #216
Changes from 5 commits
7c0641c
be091aa
c205e34
cf137a0
13762e7
f4f996d
50383b6
bfd99c2
121ec20
b257eeb
b9f74ce
34cedf8
9794666
ba5841e
40da34b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,14 +4,6 @@ partial(x::TaylorTangent, i) = getfield(getfield(x, :coeffs), i) | |
partial(x::UniformTangent, i) = getfield(x, :val) | ||
partial(x::ProductTangent, i) = ProductTangent(map(x->partial(x, i), getfield(x, :factors))) | ||
partial(x::AbstractZero, i) = x | ||
partial(x::CompositeBundle{N, B}, i) where {N, B<:Tuple} = Tangent{B}(map(x->partial(x, i), getfield(x, :tup))...) | ||
function partial(x::CompositeBundle{N, B}, i) where {N, B} | ||
# This is tangent for a struct, but fields partials are each stored in a plain tuple | ||
# so we add the names back using the primal `B` | ||
# TODO: If required this can be done as a `@generated` function so it is type-stable | ||
backing = NamedTuple{fieldnames(B)}(map(x->partial(x, i), getfield(x, :tup))) | ||
return Tangent{B, typeof(backing)}(backing) | ||
end | ||
|
||
|
||
primal(x::AbstractTangentBundle) = x.primal | ||
|
@@ -42,27 +34,20 @@ function shuffle_down(b::TaylorBundle{N, B}) where {N, B} | |
ntuple(_sdown, N-1)) | ||
end | ||
|
||
function shuffle_down(b::CompositeBundle{N, B}) where {N, B} | ||
z = CompositeBundle{N-1, CompositeBundle{1, B}}( | ||
(CompositeBundle{N-1, Tuple}( | ||
map(shuffle_down, b.tup) | ||
),) | ||
) | ||
z | ||
end | ||
|
||
function shuffle_up(r::CompositeBundle{1}) | ||
z₀ = primal(r.tup[1]) | ||
z₁ = partial(r.tup[1], 1) | ||
z₂ = primal(r.tup[2]) | ||
z₁₂ = partial(r.tup[2], 1) | ||
function shuffle_up(r::TaylorBundle{1, Tuple{B1,B2}}) where {B1,B2} | ||
z₀ = primal(r)[1] | ||
z₁ = partial(r, 1)[1] | ||
z₂ = primal(r)[2] | ||
z₁₂ = partial(r, 1)[2] | ||
if z₁ == z₂ | ||
return TaylorBundle{2}(z₀, (z₁, z₁₂)) | ||
else | ||
return ExplicitTangentBundle{2}(z₀, (z₁, z₂, z₁₂)) | ||
end | ||
end | ||
|
||
#== | ||
function taylor_compatible(a::ATB{N}, b::ATB{N}) where {N} | ||
primal(b) === a[TaylorTangentIndex(1)] || return false | ||
return all(1:(N-1)) do i | ||
|
@@ -76,6 +61,7 @@ isswifty(::UniformBundle) = true | |
isswifty(b::CompositeBundle) = all(isswifty, b.tup) | ||
isswifty(::Any) = false | ||
|
||
#TODO: port this to TaylorTangent over composite structures | ||
function shuffle_up(r::CompositeBundle{N}) where {N} | ||
a, b = r.tup | ||
if isswifty(a) && isswifty(b) && taylor_compatible(a, b) | ||
|
@@ -89,6 +75,7 @@ function shuffle_up(r::CompositeBundle{N}) where {N} | |
ntuple(i->partial(b,i), 1<<(N+1)-1)...)) | ||
end | ||
end | ||
==# | ||
oxinabox marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
function shuffle_up(r::UniformBundle{N, B, U}) where {N, B, U} | ||
(a, b) = primal(r) | ||
|
@@ -185,13 +172,6 @@ end | |
map(y->lifted_getfield(y, s), x.tangent.coeffs)) | ||
end | ||
|
||
@Base.constprop :aggressive function (::∂☆{N})(::ATB{N, typeof(getfield)}, x::CompositeBundle{N}, s::AbstractTangentBundle{N, Int}) where {N} | ||
x.tup[primal(s)] | ||
end | ||
|
||
@Base.constprop :aggressive function (::∂☆{N})(::ATB{N, typeof(getfield)}, x::CompositeBundle{N, B}, s::AbstractTangentBundle{N, Symbol}) where {N, B} | ||
x.tup[Base.fieldindex(B, primal(s))] | ||
end | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this shouldn't be needed as we already have cases for |
||
|
||
@Base.constprop :aggressive function (::∂☆{N})(f::ATB{N, typeof(getfield)}, x::UniformBundle{N, <:Any, U}, s::AbstractTangentBundle{N}) where {N, U} | ||
UniformBundle{N,<:Any,U}(getfield(primal(x), primal(s)), x.tangent.val) | ||
|
@@ -210,9 +190,12 @@ struct FwdMap{N, T<:AbstractTangentBundle{N}} | |
end | ||
(f::FwdMap{N})(args::AbstractTangentBundle{N}...) where {N} = ∂☆{N}()(f.f, args...) | ||
|
||
#== | ||
# TODO port this to TaylorBundle over composite structure | ||
function (::∂☆{N})(::ZeroBundle{N, typeof(map)}, f::ATB{N}, tup::CompositeBundle{N, <:Tuple}) where {N} | ||
∂vararg{N}()(map(FwdMap(f), tup.tup)...) | ||
end | ||
==# | ||
|
||
function (::∂☆{N})(::ZeroBundle{N, typeof(map)}, f::ATB{N}, args::ATB{N, <:AbstractArray}...) where {N} | ||
# TODO: This could do an inplace map! to avoid the extra rebundling | ||
|
@@ -254,23 +237,28 @@ function (this::∂☆{N})(::ZeroBundle{N, typeof(Core._apply_iterate)}, iterate | |
Core._apply_iterate(FwdIterate(iterate), this, (f,), args...) | ||
end | ||
|
||
#== | ||
#TODO: port this to TaylorTangent over composite structures | ||
function (this::∂☆{N})(::ZeroBundle{N, typeof(iterate)}, t::CompositeBundle{N, <:Tuple}) where {N} | ||
r = iterate(t.tup) | ||
r === nothing && return ZeroBundle{N}(nothing) | ||
∂vararg{N}()(r[1], ZeroBundle{N}(r[2])) | ||
end | ||
|
||
#TODO: port this to TaylorTangent over composite structures | ||
function (this::∂☆{N})(::ZeroBundle{N, typeof(iterate)}, t::CompositeBundle{N, <:Tuple}, a::ATB{N}, args::ATB{N}...) where {N} | ||
r = iterate(t.tup, primal(a), map(primal, args)...) | ||
r === nothing && return ZeroBundle{N}(nothing) | ||
∂vararg{N}()(r[1], ZeroBundle{N}(r[2])) | ||
end | ||
|
||
#TODO: port this to TaylorTangent over composite structures | ||
function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::CompositeBundle{N, <:Tuple}, i::ATB{N}) where {N} | ||
r = Base.indexed_iterate(t.tup, primal(i)) | ||
∂vararg{N}()(r[1], ZeroBundle{N}(r[2])) | ||
end | ||
|
||
#TODO: port this to TaylorTangent over composite structures | ||
function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::CompositeBundle{N, <:Tuple}, i::ATB{N}, st1::ATB{N}, st::ATB{N}...) where {N} | ||
r = Base.indexed_iterate(t.tup, primal(i), primal(st1), map(primal, st)...) | ||
∂vararg{N}()(r[1], ZeroBundle{N}(r[2])) | ||
|
@@ -280,10 +268,11 @@ function (this::∂☆{N})(::ZeroBundle{N, typeof(Base.indexed_iterate)}, t::Tan | |
∂vararg{N}()(this(ZeroBundle{N}(getfield), t, i), ZeroBundle{N}(primal(i) + 1)) | ||
end | ||
|
||
|
||
#TODO: port this to TaylorTangent over composite structures | ||
function (this::∂☆{N})(::ZeroBundle{N, typeof(getindex)}, t::CompositeBundle{N, <:Tuple}, i::ZeroBundle) where {N} | ||
t.tup[primal(i)] | ||
end | ||
==# | ||
|
||
function (this::∂☆{N})(::ZeroBundle{N, typeof(typeof)}, x::ATB{N}) where {N} | ||
DNEBundle{N}(typeof(primal(x))) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -208,9 +208,7 @@ end | |
|
||
function check_taylor_invariants(coeffs, primal, N) | ||
@assert length(coeffs) == N | ||
if isa(primal, TangentBundle) | ||
@assert isa(coeffs[1], TangentBundle) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this check is just wrong AFAICT. But this was never hit before because we were not making |
||
end | ||
|
||
end | ||
@ChainRulesCore.non_differentiable check_taylor_invariants(coeffs, primal, N) | ||
|
||
|
@@ -290,33 +288,6 @@ end | |
|
||
Base.getindex(u::UniformBundle, ::TaylorTangentIndex) = u.tangent.val | ||
|
||
""" | ||
CompositeBundle{N, B, B <: Tuple} | ||
|
||
Represents the tagent bundle where the base space is some tuple or struct type. | ||
Mathematically, this tangent bundle is the product bundle of the individual | ||
element bundles. | ||
""" | ||
struct CompositeBundle{N, B, T<:Tuple{Vararg{AbstractTangentBundle{N}}}} <: AbstractTangentBundle{N, B} | ||
tup::T | ||
end | ||
CompositeBundle{N, B}(tup::T) where {N, B, T} = CompositeBundle{N, B, T}(tup) | ||
|
||
function Base.getindex(tb::CompositeBundle{N, B} where N, tti::TaylorTangentIndex) where {B} | ||
B <: SArray && error() | ||
return partial(tb, tti.i) | ||
end | ||
|
||
primal(b::CompositeBundle{N, <:Tuple} where N) = map(primal, b.tup) | ||
function primal(b::CompositeBundle{N, T} where N) where T<:CompositeBundle | ||
T(map(primal, b.tup)...) | ||
end | ||
@generated primal(b::CompositeBundle{N, B} where N) where {B} = | ||
quote | ||
x = map(primal, b.tup) | ||
$(Expr(:splatnew, B, :x)) | ||
end | ||
|
||
expand_singleton_to_array(asize, a::AbstractZero) = fill(a, asize...) | ||
expand_singleton_to_array(asize, a::AbstractArray) = a | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,8 +25,7 @@ let var"'" = Diffractor.PrimeDerivativeFwd | |
# Integration tests | ||
@test recursive_sin'(1.0) == cos(1.0) | ||
@test recursive_sin''(1.0) == -sin(1.0) | ||
# Error: ArgumentError: Tangent for the primal Tangent{Tuple{Float64, Float64}, Tuple{Float64, Float64}} | ||
# should be backed by a NamedTuple type, not by Tuple{Tangent{Tuple{Float64, Float64}, Tuple{Float64, Float64}}}. | ||
|
||
@test_broken recursive_sin'''(1.0) == -cos(1.0) | ||
@test_broken recursive_sin''''(1.0) == sin(1.0) | ||
@test_broken recursive_sin'''''(1.0) == cos(1.0) | ||
|
@@ -90,4 +89,46 @@ end | |
end | ||
end | ||
|
||
|
||
@testset "structs" begin | ||
struct IDemo | ||
x::Float64 | ||
y::Float64 | ||
end | ||
|
||
function foo(a) | ||
obj = IDemo(2.0, a) | ||
return obj.x * obj.y | ||
end | ||
|
||
let var"'" = Diffractor.PrimeDerivativeFwd | ||
@test foo'(100.0) == 2.0 | ||
@test foo''(100.0) == 0.0 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this second derivative is actually broken on main There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do we know what change broke it? |
||
end | ||
end | ||
|
||
@testset "tuples" begin | ||
function foo(a) | ||
tup = (2.0, a) | ||
return first(tup) * tup[2] | ||
end | ||
|
||
let var"'" = Diffractor.PrimeDerivativeFwd | ||
@test foo'(100.0) == 2.0 | ||
@test foo''(100.0) == 0.0 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this second derivative is actually broken on main |
||
end | ||
end | ||
|
||
@testset "vararg" begin | ||
function foo(a) | ||
tup = (2.0, a) | ||
return *(tup...) | ||
end | ||
|
||
let var"'" = Diffractor.PrimeDerivativeFwd | ||
@test foo'(100.0) == 2.0 | ||
@test foo''(100.0) == 0.0 | ||
end | ||
end | ||
|
||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For reference as a part way step to determining this , I first refactored the CompositeBundle version into