Skip to content
This repository has been archived by the owner on Apr 18, 2023. It is now read-only.

Use ChainRules #189

Merged
merged 71 commits into from
Jul 5, 2021
Merged

Use ChainRules #189

merged 71 commits into from
Jul 5, 2021

Conversation

oxinabox
Copy link
Member

@oxinabox oxinabox commented Sep 3, 2020

Supercedes #178

Follows https://www.juliadiff.org/ChainRulesCore.jl/dev/autodiff/operator_overloading.html
(which will probably get updates during this based on practical learnings)

What this PR does:

  • Removes DiffRules + a ton of internal rules in favour of ChainRules.jl
    • Main Focus of this PR
    • Few files are outright deleted, but many are much much reduced
  • Support Differential types, including Composite, Thunks, and InplacableThunks
    • This actually took almost no work at all, Nabla doesn’t internally do much that can conflict with those.
    • It’s internal method for inplace accumulation just needed Nabla.update! swapped out for the very similar add!! so it would work for InplaceableThunks
    • Final output needed unthunking performed so that the public API did not contain thunks.
  • Remove DualNumbers.jl in favour of ForwardDiff.jl
    • Uses the internal DualNumber type of ForwardDiff.
    • Possibly this should be changed to be using a public higher level interface, but that might be better in a follow up PR.
    • Some parts are already change to use ForwardDiff.derivative but for multiargument things it still uses the Dual numbers directly.
    • The main reason for this change is it lets us not need to special-case higher-order functions operating on scalars that we have DiffRules for, as that happens internally in ForwardDiff.jl
    • There is an open issue to effectively do a ton of these special cases in ChainRules.jl Higher-Order functions: Low-hanging fruit JuliaDiff/ChainRules.jl#222
  • Introduce node_type tranform that is like unionise_type transform but without making the union
  • Change preprocess not receive its inputs pre-unboxe-d, but have the default fallback unbox them and recall process
    • We need it not to be unbox-ed as we need the original branch object in order to be able to able to get at the pullback
    • The fallback preserves backward compatibility. Though i don’t think anything is using it.
  • Add a bunch of comments about what is going on and why. To the existing code.
  • Fix handling in conde_transformations for when encountering VarArg{T, N} where N
  • Many lists of operations that we no longer defined in the source code are moved into tests so that we can check that the operations still work.
  • Drops support for SpecialFunctions 0.9. add support for SpecialFunction 0.10
    • Stop testing lgamma/loggamma as they don't both exist in nondeprecated form in same version
    • Stops testing lbeta as it is now weird and we don’t use it any where anyway.

Things i suggest leaving for potential future PRs

(but that reviewers might disagree with)

  • Move more rules into ChainRules
  • Remove use of :no_N in src/conde_tranformations/utils.jl. Which doesn’t seem to ever be hit
  • Remove a bunch of fields from the Branch type as they are not used
  • Maybe even go full ReverseDiffZero and just store a propagate function, a mutable partial, and the tape; and get rid of a lot of the internal machinery for the reverse tape etc.
  • Making Pair{Node} return Node{Pair} for consistency and so diagm will hit rules we define in ChainRules.jl

Notes on implementation

The core logic is to use of the Operator Overloading interface of ChainRules, which lets you register a hook that is triggered passing in a type- type representing the signature of every primal function that ChainRulesCore has an overload of rrule for.

This hook is the generate_overload function.

This filters out a bunch of things.

It then uses ExprTools to get a AST for function defination that would be suitable for overloading the primal function (as an overloading based AD like Nabla does).

From that it generates: overloads for that primal but with in turn each argument swapped out for the matching node (this is why node_type was added to the code tranformation functions).

And earlier version use unionise_type instead of swapping it out, but for things with primal type of Any (which shows up for nondifferentiable_rule), this just resulted in Union{Node{Any}, Any} which simplifies to Any. Which mean we were overwriting the original primal definition which will break everything.

The key thing these generated primal overloads do is create a Branch that stores the pullback.

We then generate a method for preprocess which invokes that pullback, computing the partials for all the arguments.
And we generate a method for that just talkes that partial computed by preprocess and return the right one for the specified Arg{N}.

Things to do before Review

  • Should this PR be broken up before the review?
    • Probably not, it is hard to do
  • Should this PR be partially squashed before review?
    • Probably, as much as it easily can be to remove some of the WIP commits and things that were undone.
    • Probably not worth the effort to do things that require reordering
  • Write a list of everything this PR does.
    • Add it to this document
  • Should we explicitly have multiple rounds of review, focussing on different things?
  • Should this be merged into master after being accepted or squashed merged into some staging branch.
    • Likely we will want to do a round of performance checking and follow up PRs there.
  • Who should review this PR?

Things for reviewers to consider:

  • How is our testing?
    • This PR doesn’t really delete any of the tests even of things that have moved.
      • Figure that leaving them there gives an extensive set of integration tests.
    • This PR doesn’t really add many tests of its own, even of the rule generation code.
      • It’s not part of the public API
      • In effect it is extensively covered by the integration tests on all the sensitivities.
      • Is this enough?
  • Which parts should move to ExprTools.jl?
    • In particular from the src/sensitivities/chainrules.jl file
    • There are two key reasons we might want to move some of this into ExprTools.
        1. Similar things will be needed by other packages wanting to also use the overload generation API of ChainRules. E.g. ReverseDiff.jl is planning on using it. As is ForwardDiff2.
        1. We are actually basically entirely using internal APIs right now. So we either need to move something out from Nabla that exposes what Nabla needs as a single public API. Or we need to make everything Nabla uses part of the ExprTools public API.
  • Are the sensitivities left in Nabla sensible? Should more be moved to ChainRules?
    • It was not a goal of this PR to move things to ChainRules.
      • But it was a goal to remove things that are redundant given they are now in ChainRules.
      • Some things were moved because moving them was easier than getting them to work as they currently were.
    • Things that remain generally fall into a few categories
      • Obscure: some of the BLAS rules, noone has cared enough to move them.
      • (currently) Impossibly to implement in ChainRules: namely any kind of higher-order function like map.
      • Nabla is being weird: e.g. defining Pair{<:Node, <:Node}, rather than Node{<:Pair}
      • Those that remain and do not fall into these categories are worth commenting on. We should compile a list of them.
  • If a new rrule is added to ChainRules for something Nabla has it will cause Nabla to break due to ambiguity.
    • We could work around this via making sure when rules define overloads they also hard-code the Arg{1} and Arg{2} etc cases. That would remove ambiguities i think.
    • Downside is we would not immediately find out about the redundancy
    • We could simply leave it as is, allow the redundancy to throw an error, which we will pick-up in Nightly CI, and then we can just delete the redundant code. Nabla has extensive tests so it will be caught.
  • Usual stuff:
    • Are there TODOs that were added in this PR? (there are lots already there)
    • Is there commented out code added in this PR? (there was lots already there)

@oxinabox
Copy link
Member Author

oxinabox commented Sep 7, 2020

to do the inplace I would like to have
JuliaDiff/ChainRulesCore.jl#113 (comment)

but i don't need it since can just overload update! for InplaceableThunk

src/core.jl Show resolved Hide resolved
end

"like `ExprTools.signature` but on a signature type-tuple, not a Method"
function build_def(sig)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Possibly this should move into ExprTools.

@oxinabox
Copy link
Member Author

oxinabox commented Sep 8, 2020

Needs invenia/ExprTools.jl#12

@oxinabox
Copy link
Member Author

oxinabox commented Sep 8, 2020

Probably what this should do is look at the method table and check if the simple unionized overload would eclipse any in the wrong way (what is that? I need to think carefully).
And if not, use that.
But if so use the one where it generate all the combinatoric overloads.

Though maybe that check would take longer than the extra processing time to generate and load all of them

Project.toml Outdated Show resolved Hide resolved
src/core.jl Show resolved Hide resolved
src/core.jl Outdated Show resolved Hide resolved
src/sensitivity.jl Outdated Show resolved Hide resolved
@oxinabox oxinabox force-pushed the ox/chainrules branch 2 times, most recently from 0993838 to a335e03 Compare October 21, 2020 15:15
@mattBrzezinski mattBrzezinski self-assigned this Oct 21, 2020
@oxinabox
Copy link
Member Author

oxinabox commented Jul 2, 2021

I thought i was done,
then I realizes that i could block it from erroring when new rules were added to chainrules that were also still in Nabla,
by making a list of all rules that we still have, and adding them to our block list.

Also I realised the docs wouldn't build anymore.
because we were using a version of Documenter that was so old that it wasn't compatible with Compat.jl 0.3 (which ChainRules uses)

Should be all sorted now

@oxinabox oxinabox merged commit 8d3dc2b into master Jul 5, 2021
@oxinabox oxinabox deleted the ox/chainrules branch July 5, 2021 18:36
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants