Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Refactor forward mode data structures #54

Merged
merged 1 commit into from
Dec 27, 2022
Merged

WIP: Refactor forward mode data structures #54

merged 1 commit into from
Dec 27, 2022

Conversation

Keno
Copy link
Collaborator

@Keno Keno commented Sep 13, 2021

To allow chunking like ForwardDiff.

I think this works on the Diffractor side at this point, but there's a question about the interfaces with ChainRules. In particular, to what extent do we want to want/need ChainRules to be chunking-aware. Previously ChainRules did allow tuples/arrays in its forwards rules, but we removed that support in preparation for Diffractor, because having broadcasting in the rules makes them a lot more expensive at higher orders. I think our options here are:

  1. Put the broadcasting back into ChainRules
  2. Have some sort of ImplicitBroadcast type that wraps our values and implicitly broadcasts them as appropriate (so the rules look "scalar", but actually work with chunked values).
  3. Option 2 + some compiler support to implement ImplicitBroadcast even if the type annotations are too restrictive.
  4. Change the frule interface to also use closures like rrule, with the semantics being that for chunked evaluation, you broadcast the pushforward over the tangent vectors.

My current thinking is that either option 3 or 4 is perhaps attractive, but I'd like @oxinabox @mcabbott and @simeonschaub to weigh in with their opinions.

@oxinabox
Copy link
Member

4 is a nonstarter.
We used to do that.
But we removed it because solving ODEs and their sensitivities (including chunked) is best done all at once in one DE solve step, that that's primal and sensibility inputs.
Search ChainRulesCore for fusing.

It's also breaking, so not doing that.
For some cases though you can use the new experimental feature for derivatives without augmented primals, but that is not available for all cases. (And in theory is available for only a smaller infinite number of functions)

2 sounds good to me

test/runtests.jl Outdated Show resolved Hide resolved
Copy link
Member

@simeonschaub simeonschaub left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd probably also be in favor of 2 or 3. I think it's probably best not to introduce another layer of closures into the frule convention.

Could you expand on what you mean by:

3\. Option 2 + some compiler support to implement `ImplicitBroadcast` even if the type annotations are too restrictive.

? In which cases is this necessary?

Comment on lines +111 to +112
This restriction forms a submanifold of the original manifold. The naming is
by analogy with the (truncated) Taylor series
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I must admit, I am still a little confused about the name TaylorTangent and why it's useful. That's mostly unrelated to this PR though.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It implements Taylor mode AD like https://github.com/JuliaDiff/TaylorSeries.jl.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought it still had separate coefficients for each variable, or did I misunderstand what the ∂ₙ are? https://github.com/JuliaDiff/TaylorSeries.jl#examples

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This docstring shows the analogy between the ∂ₙ and the coefficients of a taylor series. In particular, taylor mode takes advantage of symmetry of the derivative operators (i.e. the order of the operators is not distinguishable). For example, you can't use taylor mode to evaluate the second derivative in ∂₁ sin(∂₂ sin(x)), but you can in ∂^2 sin(sin(x)) = ∂₁ ∂₂ sin(sin(x)) = ∂₂ ∂₁. Because of this symmetry under exchange of the operators, you also get symmetry in the tangent space that you can then represent explicitly.

Comment on lines 9 to 12
first_partial(x::ExplicitTangentBundle{1}) = getfield(getfield(getfield(x, :tangent), :partials), 1)
first_partial(x::TaylorBundle{1}) = getfield(getfield(getfield(x, :tangent), :coeffs), 1)
first_partial(x::UniformBundle) = getfield(getfield(x, :tangent), :val)
first_partial(x::CompositeBundle) = map(first_partial, getfield(x, :tup))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we actually need to define first_partial separately? Can't we just use partial(x, 1) instead?

@simeonschaub
Copy link
Member

I am wondering whether it might even be enough to just overload + and scalar multiplication for ProductTangent (and perhaps some others ChainRulesCore defines for AbstractTangents). Or would that quickly run into issues?

@AnasAbdelR
Copy link

I had this PR open out of pure interest, accidently hit approved thinking I was on my other tab, please disregard 😅

@codecov-commenter
Copy link

codecov-commenter commented Sep 29, 2022

Codecov Report

Base: 55.54% // Head: 55.24% // Decreases project coverage by -0.29% ⚠️

Coverage data is based on head (c8fe4f9) compared to base (55d2871).
Patch coverage: 50.70% of modified lines in pull request are covered.

Additional details and impacted files
@@            Coverage Diff             @@
##             main      #54      +/-   ##
==========================================
- Coverage   55.54%   55.24%   -0.30%     
==========================================
  Files          21       21              
  Lines        2157     2163       +6     
==========================================
- Hits         1198     1195       -3     
- Misses        959      968       +9     
Impacted Files Coverage Δ
src/stage1/mixed.jl 0.00% <0.00%> (ø)
src/tangent.jl 32.67% <37.50%> (-1.37%) ⬇️
src/jet.jl 39.66% <40.00%> (ø)
src/stage1/forward.jl 73.28% <73.91%> (-0.87%) ⬇️
src/interface.jl 69.09% <100.00%> (-1.82%) ⬇️
src/stage1/recurse_fwd.jl 93.33% <0.00%> (-0.96%) ⬇️

Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here.

☔ View full report at Codecov.
📢 Do you have feedback about the report comment? Let us know in this issue.

@Keno
Copy link
Collaborator Author

Keno commented Dec 27, 2022

I think I'm pretty convinced that the broadcast problem should be solved in a separate package that support SIMT semantics over arbitrary julia functions that we would depend on here.

@Keno
Copy link
Collaborator Author

Keno commented Dec 27, 2022

I'm planning to merge this as is and then do further work on this on master. Looks like there may still be some bugs remaining, but I want to merge the various threads first.

To allow chunking like ForwardDiff.

import `∂⃖¹`

some fixes
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants