From 41722e96c51f49785360a4c75cbc0578e5f90b3e Mon Sep 17 00:00:00 2001
From: Lyndon White
Date: Tue, 21 Jan 2020 16:13:52 +0000
Subject: [PATCH 1/3] Break up giant single page of docs

---
 docs/src/FAQ.md                |  69 ++++++++++++
 docs/src/index.md              | 185 +++------------------------------
 docs/src/writing_good_rules.md |  79 ++++++++++++++
 3 files changed, 163 insertions(+), 170 deletions(-)
 create mode 100644 docs/src/FAQ.md
 create mode 100644 docs/src/writing_good_rules.md

diff --git a/docs/src/FAQ.md b/docs/src/FAQ.md
new file mode 100644
index 000000000..354edea1a
--- /dev/null
+++ b/docs/src/FAQ.md
@@ -0,0 +1,69 @@
+# FAQ
+
+## What is up with the different symbols?
+
+### `Δx`, `∂x`, `dx`
+ChainRules uses these perhaps atypically,
+as a notation that is the same across propagators, regardless of direction (in contrast, see `ẋ` and `x̄` below).
+
+ - `Δx` is the input to a propagator (i.e. a _seed_ for a _pullback_, or a _perturbation_ for a _pushforward_)
+ - `∂x` is the output of a propagator
+ - `dx` could be either
+
+
+### dots and bars: ``\dot{y} = \dfrac{∂y}{∂x} = \overline{x}``
+ - `v̇` is a derivative with respect to the input, moving forward: ``v̇ = \frac{∂v}{∂x}`` for input ``x``, intermediate value ``v``.
+ - `v̄` is a derivative of the output, moving backward: ``v̄ = \frac{∂y}{∂v}`` for output ``y``, intermediate value ``v``.
+
+### others
+ - `Ω` is often used as the return value of the function, especially (but not exclusively) for scalar functions.
+    - `ΔΩ` is thus a seed for the pullback.
+    - `∂Ω` is thus the output of a pushforward.
+
+
+## Why does `rrule` return the primal function evaluation?
+You might wonder why `frule(f, x)` returns `f(x)` and the derivative of `f` at `x`, and similarly why `rrule` returns `f(x)` and the pullback for `f` at `x`.
+Why not just return the pushforward/pullback, and let the user call `f(x)` to get the primal result separately?
+
+There are three reasons why the rules also calculate `f(x)`.
+1.
For some rules an alternative way of calculating `f(x)` can give the same answer while also generating intermediate values that can be used in the calculations required to propagate the derivative.
+2. For many `rrule`s the output value is used in the definition of the pullback. For example, `tan`, `sigmoid`, etc.
+3. For some `frule`s there exists a single, non-separable operation that will compute both the derivative and the primal result. For example, many of the methods for [differential equation sensitivity analysis](https://docs.juliadiffeq.org/latest/analysis/sensitivity/#sensitivity-1).
+
+## Where are the derivatives for keyword arguments?
+_pullbacks_ do not return a sensitivity for keyword arguments;
+similarly, _pushforwards_ do not accept a perturbation for keyword arguments.
+This is because in practice functions are very rarely differentiable with respect to keyword arguments.
+As a rule, keyword arguments tend to control side-effects, like logging verbosity,
+or to change which operation is performed, e.g. `dims=3`, and are thus not differentiable.
+To the best of our knowledge, no Julia AD system with support for the definition of custom primitives supports differentiating with respect to keyword arguments.
+At some point in the future ChainRules may support these. Maybe.
+
+
+## What is the difference between `Zero` and `DoesNotExist`?
+`Zero` and `DoesNotExist` act almost exactly the same in practice: they result in no change whenever added to anything.
+Odds are that if you write a rule that returns the wrong one, everything will still work fine.
+We provide both to allow for clearer writing of rules, and easier debugging.
+
+`Zero()` represents the fact that if one perturbs (adds a small change to) the matching primal, there will be no change in the behaviour of the primal function.
+For example, in `fst(x, y) = x`, the derivative of `fst` with respect to `y` is `Zero()`.
+`fst(10, 5) == 10`, and if we add `0.1` to `5` we still get `fst(10, 5.1) == 10`.
+
+`DoesNotExist()` represents the fact that if one perturbs the matching primal, the primal function will now error.
+For example, in `access(xs, n) = xs[n]`, the derivative of `access` with respect to `n` is `DoesNotExist()`.
+`access([10, 20, 30], 2) == 20`, but if we add `0.1` to `2` we get `access([10, 20, 30], 2.1)`, which errors, as indexing can't be applied at fractional indices.
+
+
+## When to use ChainRules vs ChainRulesCore?
+
+[ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl) is a light-weight dependency for defining rules for functions in your packages, without you needing to depend on ChainRules itself. It has no dependencies of its own.
+
+[ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl) provides the full functionality; in particular, it has all the rules for Base Julia and the standard libraries. It's thus a much heavier package to load.
+
+If you only want to define rules, not use them, then you probably only want to load ChainRulesCore.
+AD systems making use of ChainRules should load ChainRules (rather than ChainRulesCore).
+
+## Where should I put my rules?
+In general, we recommend adding custom sensitivities to your own packages with ChainRulesCore, rather than adding them to ChainRules.jl.
+
+A few packages (currently SpecialFunctions.jl and NaNMath.jl) are in ChainRules.jl, but this is a short-term measure.
diff --git a/docs/src/index.md b/docs/src/index.md
index 20fa11f92..9cf870259 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -6,7 +6,7 @@ DocTestSetup = :(using ChainRulesCore, ChainRules)
 
 [ChainRules](https://github.com/JuliaDiff/ChainRules.jl) provides a variety of common utilities that can be used by downstream [automatic differentiation (AD)](https://en.wikipedia.org/wiki/Automatic_differentiation) tools to define and execute forward-, reverse-, and mixed-mode primitives.
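The `fst` and `access` examples in the FAQ above can be run as plain Julia, independent of any ChainRules types (the function names come from the text; `Zero()` and `DoesNotExist()` themselves are not needed to observe the behaviour being described):

```julia
fst(x, y) = x           # ignores `y`, so perturbing `y` cannot change the output
access(xs, n) = xs[n]   # indexing: a non-integer `n` makes the primal itself error

fst(10, 5)     # 10
fst(10, 5.1)   # still 10 -> the derivative w.r.t. `y` is a hard `Zero()`

access([10, 20, 30], 2)      # 20
# access([10, 20, 30], 2.1)  # throws, since fractional indexing is undefined,
#                            # so the derivative w.r.t. `n` is `DoesNotExist()`
```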
-### Introduction +## Introduction ChainRules is all about providing a rich set of rules for differentiation. When a person learns introductory calculus, they learn that the derivative (with respect to `x`) of `a*x` is `a`, and the derivative of `sin(x)` is `cos(x)`, etc. @@ -18,7 +18,6 @@ Knowing rules for more complicated functions speeds up the autodiff process as i **ChainRules is an AD-independent collection of rules to use in a differentiation system.** -### Introduction !!! note "The whole field is a mess for terminology" It isn't just ChainRules, it is everyone. @@ -35,7 +34,7 @@ computing `foo(x)` is doing the _primal_ computation. `typeof(y)` and `typeof(x)` are both _primal_ types. -### `frule` and `rrule` +## `frule` and `rrule` !!! terminology "`frule` and `rrule`" `frule` and `rrule` are ChainRules specific terms. @@ -82,7 +81,7 @@ This operation is only possible in forward mode (where `frule` is used) because Thus the reverse mode returns the pullback function which the caller (usually an AD system) keeps hold of until derivative information about the output is available. -### The propagators: pushforward and pullback +## The propagators: pushforward and pullback !!! terminology "pushforward and pullback" @@ -96,21 +95,21 @@ This operation is only possible in forward mode (where `frule` is used) because These are also good names because effectively they propagate wiggles and wobbles through them, via the chain rule. (the term **backpropagator** may originate with ["Lambda The Ultimate Backpropagator"](http://www-bcl.cs.may.ie/~barak/papers/toplas-reverse.pdf) by Pearlmutter and Siskind, 2008) -#### Core Idea +### Core Idea -##### Less formally +#### Less formally - The **pushforward** takes a wiggle in the _input space_, and tells what wobble you would create in the output space, by passing it through the function. 
- The **pullback** takes wobbliness information with respect to the function's output, and tells the equivalent wobbliness with respect to the functions input. -##### More formally +#### More formally The **pushforward** of ``f`` takes the _sensitivity_ of the input of ``f`` to a quantity, and gives the _sensitivity_ of the output of ``f`` to that quantity The **pullback** of ``f`` takes the _sensitivity_ of a quantity to the output of ``f``, and gives the _sensitivity_ of that quantity to the input of ``f``. -#### Math +### Math This is all a bit simplified by talking in 1D. -##### Lighter Math +#### Lighter Math For a chain of expressions: ``` a = f(x) @@ -124,7 +123,7 @@ applies the chain rule to go from `∂c/∂b` to `∂c/∂a`. The pushforward of `g`, which also incorporates the knowledge of `∂b/∂a`, applies the chain rule to go from `∂a/∂x` to `∂b/∂x`. -#### Heavier Math +### Heavier Math If I have some functions: ``g(a)``, ``h(b)`` and ``f(x)=g(h(x))``, and I know the pullback of ``g``, at ``h(x)`` written: ``\mathrm{pullback}_{g(a)|a=h(x)}``, and I know the derivative of ``h`` with respect to its input ``b`` at ``g(x)``, @@ -141,7 +140,7 @@ pushforward to find ``\dfrac{∂f}{∂x}``: ``\dfrac{∂f}{∂x}=\mathrm{pushforward}_{h(b)|b=g(x)}\left(\left.\dfrac{∂g}{∂a}\right|_{a=x}\right)`` -#### The anatomy of pullback and pushforward +### The anatomy of pullback and pushforward For our function `foo(args...; kwargs...) = y`: @@ -190,7 +189,7 @@ If the function is `y = f(x)` often the pushforward will be written `ẏ = last( `ẏ` is commonly used to represent the perturbation for `y`. !!! note - + In the `frule`/pushforward, there is one `Δarg` per `arg` to the original function. The `Δargs` are similar in type/structure to the corresponding inputs `args` (`Δself` is explained below). 
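The anatomy described above can be made concrete with a hand-written pushforward and pullback for `sin`, written as plain closures with the `Δself`/`∂self` slots included. This is a simplified sketch for illustration, not the actual `frule`/`rrule` machinery (the `_sketch` names are invented here):

```julia
function sin_frule_sketch(x)
    Ω = sin(x)
    # one Δarg per primal argument: Δself for the function itself, ẋ for `x`
    sin_pushforward(Δself, ẋ) = cos(x) * ẋ
    return Ω, sin_pushforward
end

function sin_rrule_sketch(x)
    Ω = sin(x)
    # one ∂arg per primal argument, plus ∂self (no fields here, so `nothing`)
    sin_pullback(ΔΩ) = (nothing, ΔΩ * cos(x))
    return Ω, sin_pullback
end

Ω, pushforward = sin_frule_sketch(1.2)
ẏ = pushforward(nothing, 1.0)   # directional derivative: cos(1.2)

Ω, pullback = sin_rrule_sketch(1.2)
∂self, x̄ = pullback(1.0)        # gradient w.r.t. x: cos(1.2)
```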
@@ -218,7 +217,7 @@ It is common to write `function foo_pushforward(_, Δargs...)` in the case when Similarly every `pullback` returns an extra `∂self`, which for things without fields is the constant `NO_FIELDS`, indicating there are no fields within the function itself. -#### Pushforward / Pullback summary +### Pushforward / Pullback summary - **Pullback** - returned by `rrule` @@ -233,7 +232,7 @@ Similarly every `pullback` returns an extra `∂self`, which for things without - 1 return per original function return -#### Pullback/Pushforward and Directional Derivative/Gradient +### Pullback/Pushforward and Directional Derivative/Gradient The most trivial use of the `pushforward` from within `frule` is to calculate the directional derivative: @@ -261,7 +260,7 @@ s̄elf, ā, b̄, c̄ = ∇f Then we have that `∇f` is the _gradient_ of `f` at `(a, b, c)`. And we thus have the partial derivatives ``\overline{\mathrm{self}}, = \dfrac{∂f}{∂\mathrm{self}}``, ``\overline{a} = \dfrac{∂f}{∂a}``, ``\overline{b} = \dfrac{∂f}{∂b}``, ``\overline{c} = \dfrac{∂f}{∂c}``, including the and the self-partial derivative, ``\overline{\mathrm{self}}``. -### Differentials +## Differentials The values that come back from pullbacks or pushforwards are not always the same type as the input/outputs of the primal function. They are differentials, which correspond roughly to something able to represent the difference between two values of the primal types. @@ -276,7 +275,7 @@ The most important `AbstractDifferential`s when getting started are the ones abo - `Thunk`: this is a deferred computation. A thunk is a [word for a zero argument closure](https://en.wikipedia.org/wiki/Thunk). A computation wrapped in a `@thunk` doesn't get evaluated until `unthunk` is called on the thunk. `unthunk` is a no-op on non-thunked inputs. - `One`, `Zero`: There are special representations of `1` and `0`. They do great things around avoiding expanding `Thunks` in multiplication and (for `Zero`) addition. 
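Since a thunk is just a deferred zero-argument computation, the core idea behind `@thunk`/`unthunk` can be sketched in a few lines (a toy illustration with invented names, not the `ChainRulesCore` implementation):

```julia
struct ToyThunk{F}
    f::F   # a zero-argument closure holding the deferred work
end
toy_unthunk(t::ToyThunk) = t.f()   # the work only happens here
toy_unthunk(x) = x                 # no-op on non-thunked values, as described above

ncalls = Ref(0)
t = ToyThunk(() -> (ncalls[] += 1; 21 * 2))
# nothing has been computed yet: ncalls[] == 0
toy_unthunk(t)    # 42, and ncalls[] is now 1
toy_unthunk(3.0)  # 3.0, unchanged
```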
-#### Other `AbstractDifferential`s: +### Other `AbstractDifferential`s: - `Composite{P}`: this is the differential for tuples and structs. Use it like a `Tuple` or `NamedTuple`. The type parameter `P` is for the primal type. - `DoesNotExist`: Zero-like, represents that the operation on this input is not differentiable. Its primal type is normally `Integer` or `Bool`. - `InplaceableThunk`: it is like a Thunk but it can do `store!` and `accumulate!` in-place. @@ -346,157 +345,3 @@ using Zygote Zygote.gradient(foo, x) # (-2.0638950738662625,) ``` - - ------------------------------- - -## On writing good `rrule` / `frule` methods - -### Use `Zero()` or `One()` as return value - -The `Zero()` and `One()` differential objects exist as an alternative to directly returning -`0` or `zeros(n)`, and `1` or `I`. -They allow more optimal computation when chaining pullbacks/pushforwards, to avoid work. -They should be used where possible. - -### Use `Thunk`s appropriately: - -If work is only required for one of the returned differentials, then it should be wrapped in a `@thunk` (potentially using a `begin`-`end` block). - -If there are multiple return values, their computation should almost always be wrapped in a `@thunk`. - -Do _not_ wrap _variables_ in a `@thunk`; wrap the _computations_ that fill those variables in `@thunk`: - -```julia -# good: -∂A = @thunk(foo(x)) -return ∂A - -# bad: -∂A = foo(x) -return @thunk(∂A) -``` -In the bad example `foo(x)` gets computed eagerly, and all that the thunk is doing is wrapping the already calculated result in a function that returns it. - -### Be careful with using `adjoint` when you mean `transpose` - -Remember for complex numbers `a'` (i.e. `adjoint(a)`) takes the complex conjugate. -Instead you probably want `transpose(a)`, unless you've already restricted `a` to be a `AbstractMatrix{<:Real}`. 
- -### Code Style - -Use named local functions for the `pushforward`/`pullback`: - -```julia -# good: -function frule(::typeof(foo), x) - Y = foo(x) - function foo_pushforward(_, ẋ) - return bar(ẋ) - end - return Y, foo_pushforward -end -#== output -julia> frule(foo, 2) -(4, var"#foo_pushforward#11"()) -==# - -# bad: -function frule(::typeof(foo), x) - return foo(x), (_, ẋ) -> bar(ẋ) -end -#== output: -julia> frule(foo, 2) -(4, var"##9#10"()) -==# -``` - -While this is more verbose, it ensures that if an error is thrown during the `pullback`/`pushforward` the [`gensym`](https://docs.julialang.org/en/v1/base/base/#Base.gensym) name of the local function will include the name you gave it. -This makes it a lot simpler to debug from the stacktrace. - -### Write tests - -There are fairly decent tools for writing tests based on [FiniteDifferences.jl](https://github.com/JuliaDiff/FiniteDifferences.jl). -They are in [`tests/test_utils.jl`](https://github.com/JuliaDiff/ChainRules.jl/blob/master/test/test_util.jl). -Take a look at existing test and you should see how to do stuff. - -!!! warning - Use finite differencing to test derivatives. - Don't use analytical derivations for derivatives in the tests. - Those are what you use to define the rules, and so can not be confidently used in the test. - If you misread/misunderstood them, then your tests/implementation will have the same mistake. - -### CAS systems are your friends. - -It is very easy to check gradients or derivatives with a computer algebra system (CAS) like [WolframAlpha](https://www.wolframalpha.com/input/?i=gradient+atan2%28x%2Cy%29). - ------------------------------------------- - -## FAQ - -### What is up with the different symbols? - -#### `Δx`, `∂x`, `dx` -ChainRules uses these perhaps atyptically. -As a notation that is the same across propagators, regardless of direction (incontrast see `ẋ` and `x̄` below). 
- - - `Δx` is the input to a propagator, (i.e a _seed_ for a _pullback_; or a _perturbation_ for a _pushforward_) - - `∂x` is the output of a propagator - - `dx` could be either - - -#### dots and bars: ``\dot{y} = \dfrac{∂y}{∂x} = \overline{x}`` - - `v̇` is a derivative of the input moving forward: ``v̇ = \frac{∂v}{∂x}`` for input ``x``, intermediate value ``v``. - - `v̄` is a derivative of the output moving backward: ``v̄ = \frac{∂y}{∂v}`` for output ``y``, intermediate value ``v``. - -#### others - - `Ω` is often used as the return value of the function. Especially, but not exclusively, for scalar functions. - - `ΔΩ` is thus a seed for the pullback. - - `∂Ω` is thus the output of a pushforward. - - -### Why does `rrule` return the primal function evaluation? -You might wonder why `frule(f, x)` returns `f(x)` and the derivative of `f` at `x`, and similarly for `rrule` returning `f(x)` and the pullback for `f` at `x`. -Why not just return the pushforward/pullback, and let the user call `f(x)` to get the answer separately? - -There are three reasons the rules also calculate the `f(x)`. -1. For some rules an alternative way of calculating `f(x)` can give the same answer while also generating intermediate values that can be used in the calculations required to propagate the derivative. -2. For many `rrule`s the output value is used in the definition of the pullback. For example `tan`, `sigmoid` etc. -3. For some `frule`s there exists a single, non-separable operation that will compute both derivative and primal result. For example many of the methods for [differential equation sensitivity analysis](https://docs.juliadiffeq.org/latest/analysis/sensitivity/#sensitivity-1). - -### Where are the derivatives for keyword arguments? -_pullbacks_ do not return a sensitivity for keyword arguments; -similarly _pushfowards_ do not accept a perturbation for keyword arguments. -This is because in practice functions are very rarely differentiable with respect to keyword arguments. 
-As a rule keyword arguments tend to control side-effects, like logging verbosity, -or to be functionality changing to perform a different operation, e.g. `dims=3`, and thus not differentiable. -To the best of our knowledge no Julia AD system, with support for the definition of custom primitives, supports differentiating with respect to keyword arguments. -At some point in the future ChainRules may support these. Maybe. - - -### What is the difference between `Zero` and `DoesNotExist` ? -`Zero` and `DoesNotExist` act almost exactly the same in practice: they result in no change whenever added to anything. -Odds are if you write a rule that returns the wrong one everything will just work fine. -We provide both to allow for clearer writing of rules, and easier debugging. - -`Zero()` represents the fact that if one perturbs (adds a small change to) the matching primal there will be no change in the behavour of the primal function. -For example in `fst(x,y) = x`, then the derivative of `fst` with respect to `y` is `Zero()`. -`fst(10, 5) == 10` and if we add `0,1` to `5` we still get `fst(10, 5.1)=10`. - -`DoesNotExist()` represents the fact that if one perturbs the matching primal, the primal function will now error. -For example in `access(xs, n) = xs[n]` then the derivative of `access` with respect to `n` is `DoesNotExist()`. -`access([10, 20, 30], 2) = 20`, but if we add `0.1` to `2` we get `access([10, 20, 30], 2.1)` which errors as indexing can't be applied at fractional indexes. - - -### When to use ChainRules vs ChainRulesCore? - -[ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl) is a light-weight dependency for defining rules for functions in your packages, without you needing to depend on ChainRules itself. It has no dependencies of its own. - -[ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl) provides the full functionality, in particular it has all the rules for Base Julia and the standard libraries. 
Its thus a much heavier package to load. - -If you only want to define rules, not use them then you probably only want to load ChainRulesCore. -AD systems making use of ChainRules should load ChainRules (rather than ChainRulesCore). - -### Where should I put my rules? -In general, we recommend adding custom sensitivities to your own packages with ChainRulesCore, rather than adding them to ChainRules.jl. - -A few packages currently SpecialFunctions.jl and NaNMath.jl are in ChainRules.jl but this is a short-term measure. \ No newline at end of file diff --git a/docs/src/writing_good_rules.md b/docs/src/writing_good_rules.md new file mode 100644 index 000000000..2fbba0d86 --- /dev/null +++ b/docs/src/writing_good_rules.md @@ -0,0 +1,79 @@ +# On writing good `rrule` / `frule` methods + +## Use `Zero()` or `One()` as return value + +The `Zero()` and `One()` differential objects exist as an alternative to directly returning +`0` or `zeros(n)`, and `1` or `I`. +They allow more optimal computation when chaining pullbacks/pushforwards, to avoid work. +They should be used where possible. + +## Use `Thunk`s appropriately: + +If work is only required for one of the returned differentials, then it should be wrapped in a `@thunk` (potentially using a `begin`-`end` block). + +If there are multiple return values, their computation should almost always be wrapped in a `@thunk`. + +Do _not_ wrap _variables_ in a `@thunk`; wrap the _computations_ that fill those variables in `@thunk`: + +```julia +# good: +∂A = @thunk(foo(x)) +return ∂A + +# bad: +∂A = foo(x) +return @thunk(∂A) +``` +In the bad example `foo(x)` gets computed eagerly, and all that the thunk is doing is wrapping the already calculated result in a function that returns it. + +## Be careful with using `adjoint` when you mean `transpose` + +Remember for complex numbers `a'` (i.e. `adjoint(a)`) takes the complex conjugate. 
+Instead you probably want `transpose(a)`, unless you've already restricted `a` to be an `AbstractMatrix{<:Real}`.
+
+## Code Style
+
+Use named local functions for the `pushforward`/`pullback`:
+
+```julia
+# good:
+function frule(::typeof(foo), x)
+    Y = foo(x)
+    function foo_pushforward(_, ẋ)
+        return bar(ẋ)
+    end
+    return Y, foo_pushforward
+end
+#== output
+julia> frule(foo, 2)
+(4, var"#foo_pushforward#11"())
+==#
+
+# bad:
+function frule(::typeof(foo), x)
+    return foo(x), (_, ẋ) -> bar(ẋ)
+end
+#== output:
+julia> frule(foo, 2)
+(4, var"##9#10"())
+==#
+```
+
+While this is more verbose, it ensures that if an error is thrown during the `pullback`/`pushforward`, the [`gensym`](https://docs.julialang.org/en/v1/base/base/#Base.gensym) name of the local function will include the name you gave it.
+This makes it a lot simpler to debug from the stacktrace.
+
+## Write tests
+
+There are fairly decent tools for writing tests based on [FiniteDifferences.jl](https://github.com/JuliaDiff/FiniteDifferences.jl).
+They are in [`test/test_util.jl`](https://github.com/JuliaDiff/ChainRules.jl/blob/master/test/test_util.jl).
+Take a look at the existing tests to see how it is done.
+
+!!! warning
+    Use finite differencing to test derivatives.
+    Don't use analytical derivations for derivatives in the tests.
+    Those are what you use to define the rules, and so cannot be confidently used in the test.
+    If you misread/misunderstood them, then your tests/implementation will have the same mistake.
+
+## CAS systems are your friends
+
+It is very easy to check gradients or derivatives with a computer algebra system (CAS) like [WolframAlpha](https://www.wolframalpha.com/input/?i=gradient+atan2%28x%2Cy%29).
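The finite-differencing advice above can be followed even without FiniteDifferences.jl: a hand-rolled central difference is enough to sanity-check a scalar rule. This is a sketch of the idea, not the package's API, and `central_difference`/`dsin` are names invented for the example:

```julia
# Central difference: (f(x + h) - f(x - h)) / 2h approximates f'(x) for small h.
central_difference(f, x; h=1e-6) = (f(x + h) - f(x - h)) / (2h)

# The rule under test: a hand-written derivative for `sin`.
dsin(x) = cos(x)

# Compare the rule against the numeric estimate, rather than against
# a re-derivation that could repeat the same mistake:
x = 0.7
abs(dsin(x) - central_difference(sin, x)) < 1e-8   # true
```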
From 1e021216305d26dce17721c44dcd1f8f667bec3d Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 21 Jan 2020 17:45:25 +0000 Subject: [PATCH 2/3] Update docs/src/writing_good_rules.md Co-Authored-By: Nick Robinson --- docs/src/writing_good_rules.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/writing_good_rules.md b/docs/src/writing_good_rules.md index 2fbba0d86..f5738d183 100644 --- a/docs/src/writing_good_rules.md +++ b/docs/src/writing_good_rules.md @@ -7,7 +7,7 @@ The `Zero()` and `One()` differential objects exist as an alternative to directl They allow more optimal computation when chaining pullbacks/pushforwards, to avoid work. They should be used where possible. -## Use `Thunk`s appropriately: +## Use `Thunk`s appropriately If work is only required for one of the returned differentials, then it should be wrapped in a `@thunk` (potentially using a `begin`-`end` block). From 5b8a145211e8322b1a57d4aa8a5775b836eb7cb7 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 21 Jan 2020 19:02:22 +0000 Subject: [PATCH 3/3] Update docs make.jl --- docs/make.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/make.jl b/docs/make.jl index a6cf03dea..123c15dba 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -11,6 +11,8 @@ makedocs( authors="Jarrett Revels and other contributors", pages=[ "Introduction" => "index.md", + "FAQ" => "FAQ.md", + "Writing Good Rules" => "writing_good_rules.md", "API" => "api.md", ], )
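After patch 3, the `pages` layout in `docs/make.jl` presumably ends up as follows (only the keywords visible in the diff are shown; the remaining `makedocs` arguments and the surrounding docs source tree are assumed, so this fragment is not runnable on its own):

```julia
using Documenter  # assumed: ChainRules docs are built with Documenter.jl

makedocs(
    authors="Jarrett Revels and other contributors",
    pages=[
        "Introduction" => "index.md",
        "FAQ" => "FAQ.md",
        "Writing Good Rules" => "writing_good_rules.md",
        "API" => "api.md",
    ],
)
```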