Break up giant single page of docs #158

Merged · 3 commits · Jan 21, 2020
69 changes: 69 additions & 0 deletions docs/src/FAQ.md
@@ -0,0 +1,69 @@
# FAQ

## What is up with the different symbols?

### `Δx`, `∂x`, `dx`
ChainRules uses these perhaps atypically,
as a notation that is the same across propagators, regardless of direction (in contrast, see `ẋ` and `x̄` below).

- `Δx` is the input to a propagator (i.e. a _seed_ for a _pullback_, or a _perturbation_ for a _pushforward_).
- `∂x` is the output of a propagator.
- `dx` could be either.
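As a sketch of this naming convention, here is a hand-written pullback for a hypothetical `triple(x) = 3x` in plain Julia (no package required): the `Δ`-prefixed name is the seed coming in, and the `∂`-prefixed name is what the propagator produces.

```julia
triple(x) = 3x

# The pullback for `triple`: ΔΩ is the incoming seed, ∂x is the output.
function triple_pullback(ΔΩ)
    ∂x = 3 * ΔΩ   # ∂(3x)/∂x = 3, scaled by the incoming seed
    return ∂x
end
```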


### dots and bars: ``\dot{y} = \dfrac{∂y}{∂x} = \overline{x}``
- `v̇` is a derivative of the input moving forward: ``v̇ = \frac{∂v}{∂x}`` for input ``x``, intermediate value ``v``.
- `v̄` is a derivative of the output moving backward: ``v̄ = \frac{∂y}{∂v}`` for output ``y``, intermediate value ``v``.

### others
- `Ω` is often used as the return value of the function, especially, but not exclusively, for scalar functions.
- `ΔΩ` is thus a seed for the pullback.
- `∂Ω` is thus the output of a pushforward.


## Why does `rrule` return the primal function evaluation?
You might wonder why `frule(f, x)` returns `f(x)` and the derivative of `f` at `x`, and similarly for `rrule` returning `f(x)` and the pullback for `f` at `x`.
Why not just return the pushforward/pullback, and let the user call `f(x)` to get the answer separately?

There are three reasons why the rules also calculate `f(x)`.
1. For some rules an alternative way of calculating `f(x)` can give the same answer while also generating intermediate values that can be used in the calculations required to propagate the derivative.
2. For many `rrule`s the output value is used in the definition of the pullback. For example `tan`, `sigmoid` etc.
3. For some `frule`s there exists a single, non-separable operation that will compute both derivative and primal result. For example many of the methods for [differential equation sensitivity analysis](https://docs.juliadiffeq.org/latest/analysis/sensitivity/#sensitivity-1).
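Reason 2 can be sketched in plain Julia (mimicking the `rrule` shape, with the `∂self` component omitted for brevity): for `tan`, the identity `d/dx tan(x) = 1 + tan(x)²` lets the pullback reuse the already-computed primal output `Ω` instead of recomputing `tan(x)`.

```julia
# Sketch of an rrule-like function for tan, reusing the primal output Ω.
function tan_rule(x)
    Ω = tan(x)
    tan_pullback(ΔΩ) = (1 + Ω^2) * ΔΩ   # no second call to tan
    return Ω, tan_pullback
end

Ω, back = tan_rule(0.0)   # Ω == 0.0; back(1.0) == 1.0
```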

## Where are the derivatives for keyword arguments?
_Pullbacks_ do not return a sensitivity for keyword arguments;
similarly, _pushforwards_ do not accept a perturbation for keyword arguments.
This is because in practice functions are very rarely differentiable with respect to keyword arguments.
As a rule, keyword arguments tend to control side-effects, like logging verbosity,
or to change which operation is performed, e.g. `dims=3`, and thus are not differentiable.
To the best of our knowledge no Julia AD system, with support for the definition of custom primitives, supports differentiating with respect to keyword arguments.
At some point in the future ChainRules may support these. Maybe.
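A minimal sketch of what this looks like, using plain Julia and hypothetical names (`mysum`, `mysum_rule` are not real ChainRules functions): the keyword is forwarded to compute the primal, but the pullback only returns a sensitivity for the positional argument.

```julia
mysum(x; dims) = sum(x; dims=dims)

function mysum_rule(x; dims)
    y = mysum(x; dims=dims)
    # No sensitivity for `dims`: it selects the operation, it is not differentiable.
    mysum_pullback(Δy) = Δy .* ones(size(x))
    return y, mysum_pullback
end

y, back = mysum_rule([1.0 2.0; 3.0 4.0]; dims=1)
∂x = back(ones(1, 2))   # seed matching y's shape; broadcasts back to size(x)
```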


## What is the difference between `Zero` and `DoesNotExist`?
`Zero` and `DoesNotExist` act almost exactly the same in practice: they result in no change whenever added to anything.
Odds are if you write a rule that returns the wrong one everything will just work fine.
We provide both to allow for clearer writing of rules, and easier debugging.

`Zero()` represents the fact that if one perturbs (adds a small change to) the matching primal, there will be no change in the behaviour of the primal function.
For example, for `fst(x, y) = x`, the derivative of `fst` with respect to `y` is `Zero()`:
`fst(10, 5) == 10`, and if we add `0.1` to `5` we still get `fst(10, 5.1) == 10`.

`DoesNotExist()` represents the fact that if one perturbs the matching primal, the primal function will now error.
For example, for `access(xs, n) = xs[n]`, the derivative of `access` with respect to `n` is `DoesNotExist()`:
`access([10, 20, 30], 2) == 20`, but if we add `0.1` to `2` we get `access([10, 20, 30], 2.1)`, which errors, as indexing is not defined at fractional indices.
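The two examples above, runnable in plain Julia:

```julia
fst(x, y) = x
access(xs, n) = xs[n]

# Perturbing `y` provably changes nothing -- a `Zero()`-style derivative:
@assert fst(10, 5) == 10
@assert fst(10, 5.1) == 10

# Perturbing `n` makes the primal error -- a `DoesNotExist()`-style derivative:
@assert access([10, 20, 30], 2) == 20
# access([10, 20, 30], 2.1) throws, since 2.1 is not a valid index
```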


## When to use ChainRules vs ChainRulesCore?

[ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl) is a light-weight dependency for defining rules for functions in your packages, without you needing to depend on ChainRules itself. It has no dependencies of its own.

[ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl) provides the full functionality; in particular it has all the rules for Base Julia and the standard libraries. It is thus a much heavier package to load.

If you only want to define rules, not use them, then you probably only want to load ChainRulesCore.
AD systems making use of ChainRules should load ChainRules (rather than ChainRulesCore).
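As a sketch (assuming ChainRulesCore is installed, and using a hypothetical function `myrelu`): a package defining a rule for its own function needs only the lightweight dependency.

```julia
using ChainRulesCore   # lightweight; carries no rules of its own

myrelu(x) = max(x, zero(x))

# Extend ChainRulesCore.rrule for our own function.
function ChainRulesCore.rrule(::typeof(myrelu), x)
    y = myrelu(x)
    # NO_FIELDS: a plain function has no fields, so no ∂self to report.
    myrelu_pullback(Δy) = (NO_FIELDS, (x > 0) * Δy)
    return y, myrelu_pullback
end
```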

## Where should I put my rules?
In general, we recommend adding custom sensitivities to your own packages with ChainRulesCore, rather than adding them to ChainRules.jl.

A few packages (currently SpecialFunctions.jl and NaNMath.jl) have their rules in ChainRules.jl, but this is a short-term measure.
185 changes: 15 additions & 170 deletions docs/src/index.md
@@ -6,7 +6,7 @@ DocTestSetup = :(using ChainRulesCore, ChainRules)

[ChainRules](https://github.com/JuliaDiff/ChainRules.jl) provides a variety of common utilities that can be used by downstream [automatic differentiation (AD)](https://en.wikipedia.org/wiki/Automatic_differentiation) tools to define and execute forward-, reverse-, and mixed-mode primitives.

### Introduction
## Introduction

ChainRules is all about providing a rich set of rules for differentiation.
When a person learns introductory calculus, they learn that the derivative (with respect to `x`) of `a*x` is `a`, and the derivative of `sin(x)` is `cos(x)`, etc.
@@ -18,7 +18,6 @@ Knowing rules for more complicated functions speeds up the autodiff process as i

**ChainRules is an AD-independent collection of rules to use in a differentiation system.**

### Introduction

!!! note "The whole field is a mess for terminology"
    It isn't just ChainRules, it is everyone.
@@ -35,7 +34,7 @@ computing `foo(x)` is doing the _primal_ computation.
`typeof(y)` and `typeof(x)` are both _primal_ types.


### `frule` and `rrule`
## `frule` and `rrule`

!!! terminology "`frule` and `rrule`"
    `frule` and `rrule` are ChainRules specific terms.
@@ -82,7 +81,7 @@ This operation is only possible in forward mode (where `frule` is used) because
Thus the reverse mode returns the pullback function which the caller (usually an AD system) keeps hold of until derivative information about the output is available.


### The propagators: pushforward and pullback
## The propagators: pushforward and pullback


!!! terminology "pushforward and pullback"
@@ -96,21 +95,21 @@
These are also good names because effectively they propagate wiggles and wobbles through them, via the chain rule.
(the term **backpropagator** may originate with ["Lambda The Ultimate Backpropagator"](http://www-bcl.cs.may.ie/~barak/papers/toplas-reverse.pdf) by Pearlmutter and Siskind, 2008)

#### Core Idea
### Core Idea

##### Less formally
#### Less formally

- The **pushforward** takes a wiggle in the _input space_, and tells what wobble it would create in the _output space_, by passing it through the function.
- The **pullback** takes wobbliness information with respect to the function's output, and tells the equivalent wobbliness with respect to the function's input.

##### More formally
#### More formally
The **pushforward** of ``f`` takes the _sensitivity_ of the input of ``f`` to a quantity, and gives the _sensitivity_ of the output of ``f`` to that quantity.
The **pullback** of ``f`` takes the _sensitivity_ of a quantity to the output of ``f``, and gives the _sensitivity_ of that quantity to the input of ``f``.
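A minimal hand-written pushforward illustrates this (plain Julia, hypothetical names): it maps a perturbation `ẋ` of the input of `square(x) = x^2` to the induced perturbation of the output.

```julia
square(x) = x^2

# The pushforward at x knows ∂(x^2)/∂x = 2x, and scales the perturbation by it.
square_pushforward(x) = ẋ -> 2x * ẋ

square_pushforward(3.0)(1.0)   # derivative at x = 3 is 6.0
```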

#### Math
### Math
This is all a bit simplified by working in 1D.

##### Lighter Math
#### Lighter Math
For a chain of expressions:
```
a = f(x)
@@ -124,7 +123,7 @@ applies the chain rule to go from `∂c/∂b` to `∂c/∂a`.
The pushforward of `g`, which also incorporates the knowledge of `∂b/∂a`,
applies the chain rule to go from `∂a/∂x` to `∂b/∂x`.
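The chain above can be written out by hand in plain Julia (toy choices `f = sin`, `g = 2a`, `h = b + 1` are assumptions for illustration), with one pullback per step applying one link of the chain rule:

```julia
f(x) = sin(x);  g(a) = 2a;  h(b) = b + 1

pullback_f(x) = Δ -> cos(x) * Δ   # knows ∂a/∂x at this x
pullback_g(a) = Δ -> 2 * Δ        # knows ∂b/∂a
pullback_h(b) = Δ -> 1 * Δ        # knows ∂c/∂b

x = 0.0
a = f(x); b = g(a); c = h(b)      # forward (primal) pass

# Reverse pass: seed with ∂c/∂c = 1 and pull back step by step.
∂c_∂b = pullback_h(b)(1.0)
∂c_∂a = pullback_g(a)(∂c_∂b)
∂c_∂x = pullback_f(x)(∂c_∂a)      # cos(0) * 2 * 1 == 2.0
```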

#### Heavier Math
### Heavier Math
If I have some functions: ``g(a)``, ``h(b)`` and ``f(x)=g(h(x))``, and I know
the pullback of ``g``, at ``h(x)`` written: ``\mathrm{pullback}_{g(a)|a=h(x)}``,
and I know the derivative of ``h`` with respect to its input ``b`` at ``g(x)``,
@@ -141,7 +140,7 @@ pushforward to find ``\dfrac{∂f}{∂x}``:
``\dfrac{∂f}{∂x}=\mathrm{pushforward}_{h(b)|b=g(x)}\left(\left.\dfrac{∂g}{∂a}\right|_{a=x}\right)``


#### The anatomy of pullback and pushforward
### The anatomy of pullback and pushforward

For our function `foo(args...; kwargs...) = y`:

@@ -190,7 +189,7 @@ If the function is `y = f(x)` often the pushforward will be written `ẏ = last(
`ẏ` is commonly used to represent the perturbation for `y`.

!!! note

In the `frule`/pushforward,
there is one `Δarg` per `arg` to the original function.
The `Δargs` are similar in type/structure to the corresponding inputs `args` (`Δself` is explained below).
@@ -218,7 +217,7 @@ It is common to write `function foo_pushforward(_, Δargs...)` in the case when
Similarly every `pullback` returns an extra `∂self`, which for things without fields is the constant `NO_FIELDS`, indicating there are no fields within the function itself.


#### Pushforward / Pullback summary
### Pushforward / Pullback summary

- **Pullback**
- returned by `rrule`
@@ -233,7 +232,7 @@ Similarly every `pullback` returns an extra `∂self`, which for things without
- 1 return per original function return


#### Pullback/Pushforward and Directional Derivative/Gradient
### Pullback/Pushforward and Directional Derivative/Gradient

The most trivial use of the `pushforward` from within `frule` is to calculate the directional derivative:

@@ -261,7 +260,7 @@ s̄elf, ā, b̄, c̄ = ∇f
Then we have that `∇f` is the _gradient_ of `f` at `(a, b, c)`.
And we thus have the partial derivatives ``\overline{a} = \dfrac{∂f}{∂a}``, ``\overline{b} = \dfrac{∂f}{∂b}``, ``\overline{c} = \dfrac{∂f}{∂c}``, including the self-partial derivative ``\overline{\mathrm{self}} = \dfrac{∂f}{∂\mathrm{self}}``.
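A toy version of this gradient computation, written with a plain Julia closure that mimics the `rrule`/pullback convention (the `mul_rrule` name is hypothetical; the real `rrule` would use the differential types described below):

```julia
mul(a, b) = a * b

function mul_rrule(a, b)
    y = mul(a, b)
    # First slot is ∂self; `nothing` stands in for it in this sketch.
    mul_pullback(Δy) = (nothing, b * Δy, a * Δy)
    return y, mul_pullback
end

y, back = mul_rrule(3.0, 4.0)
s̄elf, ā, b̄ = back(1.0)   # seeding with 1 yields the gradient
# ā == 4.0 (= ∂f/∂a), b̄ == 3.0 (= ∂f/∂b)
```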

### Differentials
## Differentials

The values that come back from pullbacks or pushforwards are not always the same type as the input/outputs of the primal function.
They are differentials, which correspond roughly to something able to represent the difference between two values of the primal types.
@@ -276,7 +275,7 @@ The most important `AbstractDifferential`s when getting started are the ones abo
- `Thunk`: this is a deferred computation. A thunk is a [word for a zero argument closure](https://en.wikipedia.org/wiki/Thunk). A computation wrapped in a `@thunk` doesn't get evaluated until `unthunk` is called on the thunk. `unthunk` is a no-op on non-thunked inputs.
- `One`, `Zero`: These are special representations of `1` and `0`. They do great things around avoiding expanding `Thunks` in multiplication and (for `Zero`) addition.
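The thunk behaviour can be modelled with a plain zero-argument closure (a sketch of what `@thunk`/`unthunk` provide, not the real implementation): the work runs only when the value is demanded.

```julia
calls = Ref(0)                       # count how many times the work runs
expensive() = (calls[] += 1; 42.0)   # stand-in for a costly computation

thunked = () -> expensive()   # like `@thunk expensive()`: nothing has run yet
@assert calls[] == 0

value = thunked()             # like `unthunk(thunked)`: now it runs, once
@assert value == 42.0 && calls[] == 1
```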

#### Other `AbstractDifferential`s:
#### Other `AbstractDifferential`s
- `Composite{P}`: this is the differential for tuples and structs. Use it like a `Tuple` or `NamedTuple`. The type parameter `P` is for the primal type.
- `DoesNotExist`: Zero-like, represents that the operation on this input is not differentiable. Its primal type is normally `Integer` or `Bool`.
- `InplaceableThunk`: it is like a Thunk but it can do `store!` and `accumulate!` in-place.
@@ -346,157 +345,3 @@ using Zygote
Zygote.gradient(foo, x)
# (-2.0638950738662625,)
```

-------------------------------

## On writing good `rrule` / `frule` methods

### Use `Zero()` or `One()` as return value

The `Zero()` and `One()` differential objects exist as an alternative to directly returning
`0` or `zeros(n)`, and `1` or `I`.
They allow more optimal computation when chaining pullbacks/pushforwards, by avoiding unneeded work.
They should be used where possible.

### Use `Thunk`s appropriately

If work is only required for one of the returned differentials, then it should be wrapped in a `@thunk` (potentially using a `begin`-`end` block).

If there are multiple return values, their computation should almost always be wrapped in a `@thunk`.

Do _not_ wrap _variables_ in a `@thunk`; wrap the _computations_ that fill those variables in `@thunk`:

```julia
# good:
∂A = @thunk(foo(x))
return ∂A

# bad:
∂A = foo(x)
return @thunk(∂A)
```
In the bad example `foo(x)` gets computed eagerly, and all that the thunk is doing is wrapping the already calculated result in a function that returns it.

### Be careful with using `adjoint` when you mean `transpose`

Remember for complex numbers `a'` (i.e. `adjoint(a)`) takes the complex conjugate.
Instead you probably want `transpose(a)`, unless you've already restricted `a` to be a `AbstractMatrix{<:Real}`.
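A quick demonstration of the difference for complex entries:

```julia
a = [1 + 2im, 3 - 1im]

# `'` (adjoint) conjugates the entries; `transpose` does not.
a'[1]              # 1 - 2im
transpose(a)[1]    # 1 + 2im
```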

### Code Style

Use named local functions for the `pushforward`/`pullback`:

```julia
# good:
function frule(::typeof(foo), x)
Y = foo(x)
function foo_pushforward(_, ẋ)
return bar(ẋ)
end
return Y, foo_pushforward
end
#== output
julia> frule(foo, 2)
(4, var"#foo_pushforward#11"())
==#

# bad:
function frule(::typeof(foo), x)
return foo(x), (_, ẋ) -> bar(ẋ)
end
#== output:
julia> frule(foo, 2)
(4, var"##9#10"())
==#
```

While this is more verbose, it ensures that if an error is thrown during the `pullback`/`pushforward` the [`gensym`](https://docs.julialang.org/en/v1/base/base/#Base.gensym) name of the local function will include the name you gave it.
This makes it a lot simpler to debug from the stacktrace.

### Write tests

There are fairly decent tools for writing tests based on [FiniteDifferences.jl](https://github.com/JuliaDiff/FiniteDifferences.jl).
They are in [`test/test_util.jl`](https://github.com/JuliaDiff/ChainRules.jl/blob/master/test/test_util.jl).
Take a look at the existing tests to see how it is done.

!!! warning
    Use finite differencing to test derivatives.
    Don't use analytical derivations for derivatives in the tests.
    Those are what you use to define the rules, and so they cannot be confidently used in the tests:
    if you misread or misunderstood them, then your tests and implementation will share the same mistake.
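The idea can be sketched with a hand-rolled central finite difference (the real test helpers build on FiniteDifferences.jl; `prim` and `d_analytic` are hypothetical names for a primal and the rule under test):

```julia
# Central difference: an independent numerical check on a derivative.
central_fdm(f, x; h=1e-6) = (f(x + h) - f(x - h)) / 2h

prim(x) = sin(x) * x
d_analytic(x) = cos(x) * x + sin(x)   # the analytical rule under test

x = 0.7
@assert isapprox(central_fdm(prim, x), d_analytic(x); atol=1e-6)
```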

### CAS systems are your friends

It is very easy to check gradients or derivatives with a computer algebra system (CAS) like [WolframAlpha](https://www.wolframalpha.com/input/?i=gradient+atan2%28x%2Cy%29).
