Remove Composite Bundle #216

Merged
merged 15 commits into from
Sep 26, 2023

oxinabox
Member

This is a prerequisite for #189 as how CompositeBundle works, causes a problem when it comes to mutation.

Right now CompositeBundle stores everything as a tuple of other Bundles, one per field of the primal,
rather than storing a primal and a tangent (or really one or more tangents).
This is a problem because when we need to get things back separated into primal and tangent (which we need to do to apply frules or to actually get at the derivative result), we need to assemble the Tangent (note the capital T; this is the type that represents an immutable structural tangent) on the fly.
Which we do here:

function partial(x::CompositeBundle{N, B}, i) where {N, B}
    # This is tangent for a struct, but fields partials are each stored in a plain tuple
    # so we add the names back using the primal `B`
    # TODO: If required this can be done as a `@generated` function so it is type-stable
    backing = NamedTuple{fieldnames(B)}(map(x->partial(x, i), getfield(x, :tup)))
    return Tangent{B, typeof(backing)}(backing)
end

That is all well and good if everything is immutable, since there is no need to preserve mutation or worry about aliasing when things are immutable.
If, on the other hand, things are mutable, you do need to worry about these things. In particular, the MutableTangent type must be aliased in the same places its primal is (in fact, in JuliaDiff/ChainRulesCore.jl#626, the PR that created it, I made a note of this in the docstring); you very much cannot break it apart and recreate it on demand.
Ergo, I think CompositeBundle needs to be removed and replaced with something that works like the other tangent types: a pair of primal and tangent values.
Namely TaylorTangent in most cases, though in theory it might end up with an ExplicitTangent.
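
The aliasing concern above can be illustrated with a small self-contained sketch. All of the types here (`ToyMutableTangent`, `ToyPairBundle`, `ToyFieldBundle`) are hypothetical stand-ins invented for this illustration; they are not the ChainRulesCore or Diffractor types, and the sketch only shows the general identity problem, not how Diffractor stores things.

```julia
# Hypothetical stand-ins; these are NOT the ChainRulesCore or Diffractor types.
mutable struct ToyMutableTangent
    x::Float64
end

# "Pair" layout: the bundle holds the one tangent object directly.
struct ToyPairBundle{P}
    primal::P
    tangent::ToyMutableTangent
end
tangent_of(b::ToyPairBundle) = b.tangent

# "Per-field" layout: only the field values are stored, so the tangent
# has to be rebuilt every time it is requested.
struct ToyFieldBundle{P}
    primal::P
    dx::Float64
end
tangent_of(b::ToyFieldBundle) = ToyMutableTangent(b.dx)

pair = ToyPairBundle(Ref(1.0), ToyMutableTangent(0.0))
t1 = tangent_of(pair)
t1.x = 5.0                                        # mutate through one handle
pair_sees_write = tangent_of(pair).x == 5.0       # same object, so the write is visible

perfield = ToyFieldBundle(Ref(1.0), 0.0)
t2 = tangent_of(perfield)
t2.x = 5.0                                        # mutates a freshly built copy only
field_sees_write = tangent_of(perfield).x == 5.0  # rebuilt object, so the write is lost

same_object_pair = tangent_of(pair) === tangent_of(pair)
same_object_field = tangent_of(perfield) === tangent_of(perfield)
```

In the pair layout every request returns the identical object, so mutation and aliasing are preserved; in the per-field layout each request builds a new object, which is the behaviour that is harmless for immutable tangents and wrong for mutable ones.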

Beyond that, removing CompositeBundle seems to be a significant net simplification of the code base.
This PR is net negative in lines of code in src/ and will get even more negative once I work out which of the remaining `@Base.constprop :aggressive` special cases can be deleted.
It also fixes some existing issues, and I think it will make others easier to fix in the future.

src/stage1/forward.jl (outdated)

@Base.constprop :aggressive function (::∂☆{N})(::ATB{N, typeof(getfield)}, x::CompositeBundle{N, B}, s::AbstractTangentBundle{N, Symbol}) where {N, B}
    x.tup[Base.fieldindex(B, primal(s))]
end
Member Author

this shouldn't be needed as we already have cases for TaylorBundle above.
Likewise several others of this kind below will not be needed.

test/forward.jl (outdated)

let var"'" = Diffractor.PrimeDerivativeFwd
    @test foo'(100.0) == 2.0
    @test foo''(100.0) == 0.0
Member Author

this second derivative is actually broken on main

Member

do we know what change broke it?

@codecov

codecov bot commented Sep 20, 2023

Codecov Report

Attention: 12 lines in your changes are missing coverage. Please review.

Comparison is base (ee8af85) 54.90% compared to head (40da34b) 55.05%.
Report is 2 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #216      +/-   ##
==========================================
+ Coverage   54.90%   55.05%   +0.14%     
==========================================
  Files          28       28              
  Lines        2801     2810       +9     
==========================================
+ Hits         1538     1547       +9     
  Misses       1263     1263              
Files                             Coverage Δ
src/AbstractDifferentiation.jl    100.00% <100.00%> (ø)
src/stage1/compiler_utils.jl      95.45% <ø> (ø)
src/stage1/recurse_fwd.jl         94.82% <100.00%> (+3.91%) ⬆️
src/tangent.jl                    46.45% <100.00%> (-3.55%) ⬇️
src/stage1/forward.jl             68.32% <65.71%> (-0.17%) ⬇️

... and 1 file with indirect coverage changes

    _TangentBundle(Val{N}(), primal, TaylorTangent(coeffs))
end

function check_taylor_invariants(coeffs, primal, N)
    @assert length(coeffs) == N
    if isa(primal, TangentBundle)
        @assert isa(coeffs[1], TangentBundle)
Member Author

This check is just wrong, AFAICT,
because if the primal is a TangentBundle, then the first derivative is a Tangent{TangentBundle}.
The first derivative has no primal component.

But this was never hit before because we were not making TaylorBundles when taking derivatives of TaylorBundles (at least not in anything that happens in our test/forward.jl; I checked). We were, I assume, making CompositeBundles.

        a_partials = ntuple(i->partial(r, i)[1], N)
        b_partials = ntuple(i->partial(r, i)[2], N)
        the_partials = (a_partials..., primal_b, b_partials...)
        return TangentBundle{N+1}(the_primal, the_partials)
    end
end
Member Author

For reference, as an intermediate step in working this out, I first refactored the CompositeBundle version into:

function shuffle_up(r::CompositeBundle{N}) where {N}
    a, b = r.tup
    if isswifty(a) && isswifty(b) && taylor_compatible(a, b)
        the_partials = ntuple(N+1) do i
            if i <= N
                a[TaylorTangentIndex(i)]  # == b[TaylorTangentIndex(i-1)] (except first which is b.primal)
            else  # i == N+1
                b[TaylorTangentIndex(i-1)]
            end
        end
        return TaylorBundle{N+1}(primal(a), the_partials)
    else
        the_primal = r.tup[1].primal
        a_partials = r.tup[1].tangent.partials
        b_partials = ntuple(i->partial(b,i), 1<<(N+1)-1)
        the_partials = (a_partials..., primal(b), b_partials...)
        return TangentBundle{N+1}(the_primal, the_partials)
    end
end
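
To make the Taylor-compatible branch of the refactor above easier to follow, here is a toy version on a hypothetical `ToyTaylor` type (invented for this sketch, not Diffractor's `TaylorBundle`). It assumes, following the inline comment in the snippet, that "compatible" means the coefficients of `a` are the values of `b` shifted by one place; `toy_compatible` and `toy_shuffle_up` are illustrative names only.

```julia
# Hypothetical toy type, not Diffractor's TaylorBundle.
struct ToyTaylor{N,T}
    primal::T
    coeffs::NTuple{N,T}
end

# Assumed notion of compatibility: a.coeffs[1] == b.primal and
# a.coeffs[i] == b.coeffs[i-1] for i > 1.
function toy_compatible(a::ToyTaylor{N}, b::ToyTaylor{N}) where {N}
    shifted = (b.primal, b.coeffs[1:N-1]...)
    return a.coeffs == shifted
end

# Merge two order-N toy bundles into one of order N+1:
# the first N coefficients come from `a`, the last one from `b`.
function toy_shuffle_up(a::ToyTaylor{N,T}, b::ToyTaylor{N,T}) where {N,T}
    toy_compatible(a, b) || error("inputs are not Taylor-compatible")
    coeffs = ntuple(N + 1) do i
        i <= N ? a.coeffs[i] : b.coeffs[i-1]
    end
    return ToyTaylor{N+1,T}(a.primal, coeffs)
end

a = ToyTaylor(1.0, (2.0, 3.0))
b = ToyTaylor(2.0, (3.0, 4.0))
r = toy_shuffle_up(a, b)   # order-3 result built from two order-2 inputs
```

Because the two inputs overlap in all but one value, only `b`'s highest coefficient carries new information, which is why the merged result needs just N+1 coefficients.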

use fieldcount

taylor bundle not taylor tangent
Comment on lines +231 to +240
"for a TaylorBundle{N, <:Tuple} this breaks it up into 1 TaylorBundle{N} for each element of the primal tuple"
function destructure(r::TaylorBundle{N, B}) where {N, B<:Tuple}
    return ntuple(fieldcount(B)) do field_ii
        the_primal = primal(r)[field_ii]
        the_partials = ntuple(N) do order_ii
            partial(r, order_ii)[field_ii]
        end
        return TaylorBundle{N}(the_primal, the_partials)
    end
end
Member Author

Do we need this, or can we just use unbundle?
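
For readers unfamiliar with what the `destructure` hunk above does, the following sketch performs the same per-field split on plain tuples. `toy_destructure` is a hypothetical helper written for this illustration; it does not use Diffractor's types and says nothing about whether `unbundle` would do the same job.

```julia
# Hypothetical helper on plain tuples; it does not use Diffractor's types.
# `prim` is the primal tuple; `partials[k]` is the order-k tangent,
# assumed to be a tuple with the same shape as `prim`.
function toy_destructure(prim::Tuple, partials::Tuple)
    return ntuple(length(prim)) do field_i
        # Collect this field's entry from every derivative order.
        field_partials = map(p -> p[field_i], partials)
        (primal = prim[field_i], partials = field_partials)
    end
end

parts = toy_destructure((1.0, 2.0), ((10.0, 20.0), (100.0, 200.0)))
```

Each element of `parts` pairs one field of the primal with that field's partials across all orders, so a single bundle over a tuple becomes one small bundle per tuple element.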

Comment on lines 76 to 79
a_partials = ntuple(i->partial(r, i)[1], N)
b_partials = ntuple(i->partial(r, i)[2], N)
the_partials = (a_partials..., primal_b, b_partials...)
return TangentBundle{N+1}(the_primal, the_partials)
Member Author

My problem with this is that I don't think b_partials has length 1<<(N+1)-1:
it has length N,
and those two are never equal.
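
The arithmetic behind this comment can be checked directly. The two helper names below are made up for the sketch; the only inputs are the two lengths named in the comment, N and 1<<(N+1)-1.

```julia
# Lengths compared in the comment above, for a range of orders N.
taylor_len(N) = N                        # length of the ntuple(..., N) version
requested_len(N) = (1 << (N + 1)) - 1    # same value as 1<<(N+1)-1, i.e. 2^(N+1) - 1

# Print the two lengths side by side for small N.
for N in 0:4
    println("N = ", N, ": ", taylor_len(N), " vs ", requested_len(N))
end

# The two lengths disagree for every order checked here.
mismatch_everywhere = all(taylor_len(N) != requested_len(N) for N in 0:20)
```

Since 2^(N+1) - 1 grows exponentially while N grows linearly, and the former is already larger at N = 0, the two can never coincide for any non-negative N.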

@oxinabox oxinabox marked this pull request as ready for review September 20, 2023 11:55
@oxinabox
Member Author

Nightly is failing just because of reverse mode, unrelated to this PR.

@Keno
Collaborator

Keno commented Sep 20, 2023

I don't exactly remember what the original motivation for this was, but if everything goes through fine, I have no particular attachment to this type.

@oxinabox oxinabox merged commit ffd8840 into main Sep 26, 2023
@oscardssmith oscardssmith deleted the ox/nocompo branch September 26, 2023 12:25