Support multinomial sampling for NUTS, and more #79

Merged
merged 41 commits on Jul 10, 2019
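For context on the title: multinomial sampling for NUTS draws the next state from among all states in the expanded trajectory with probability proportional to exp(-H), rather than slice sampling a state as in the original NUTS paper. The sketch below illustrates only that core draw with standalone, hypothetical names; the actual implementation lives in the `Trajectory`/`DoublingTree` changes in this PR.

```julia
# Minimal, self-contained sketch of multinomial trajectory-state sampling --
# illustrative only, not the implementation added by this PR.
using Random: MersenneTwister, AbstractRNG

# Suppose these are the negative energies -H(θᵢ, rᵢ) of the states collected
# while building a trajectory.
neg_energies = [-0.3, -1.2, 0.1, -0.7]

# Draw one state index with probability ∝ exp(-Hᵢ), computed stably via the
# log-sum-exp trick.
function multinomial_pick(rng::AbstractRNG, logw::AbstractVector{<:Real})
    p = exp.(logw .- maximum(logw))   # unnormalised weights, numerically stable
    p ./= sum(p)                      # normalise to probabilities
    u, c = rand(rng), 0.0
    for (i, pᵢ) in enumerate(p)
        c += pᵢ
        u <= c && return i
    end
    return length(p)                  # guard against floating-point round-off
end

idx = multinomial_pick(MersenneTwister(42), neg_energies)
```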
Changes from all commits (41 commits)
4c8cbc9
Initial design for new `Trajectory` and `Termination` types.
yebai May 28, 2019
8f962ea
fix some type definitions
xukai92 Jul 5, 2019
df44c2a
rename NUTS_Termination to NoUTurnTermination
xukai92 Jul 5, 2019
1dde236
introduce DoublingTree type
xukai92 Jul 5, 2019
820fc68
fix RFC typo
xukai92 Jul 5, 2019
1863da3
initial DoublingTree impl
xukai92 Jul 6, 2019
675d5c0
add a isUturn function
xukai92 Jul 6, 2019
a671fa8
add a comment
xukai92 Jul 6, 2019
c0ea831
improve NUTS type
xukai92 Jul 6, 2019
71694c1
improve abstraction
xukai92 Jul 6, 2019
fd4e79b
bugfix for MH step and add more comments (#75)
xukai92 Jul 6, 2019
3c9140e
draft a docstring for the sample function (#75)
xukai92 Jul 6, 2019
6cd458d
RFC find_good_eps (#27) and add some more comments (#75)
xukai92 Jul 6, 2019
c261e13
add some preparation code
xukai92 Jul 6, 2019
608e7be
indicate the numerical error location (#71)
xukai92 Jul 6, 2019
9e9e147
Merge branch 'kx/multinomial' of github.com:TuringLang/AdvancedHMC.jl…
xukai92 Jul 7, 2019
28bcc82
Re-organise code to improve readability - no functionality change.
yebai Jul 8, 2019
2a2f428
Re-organise code to improve readability - no functionality change.
yebai Jul 8, 2019
4a582e6
Merge branch 'master' into kx/multinomial
yebai Jul 8, 2019
fe2efcf
remove replace sampling with samplerType
xukai92 Jul 9, 2019
046abd7
revert naming
xukai92 Jul 9, 2019
c828a04
improve test script
xukai92 Jul 9, 2019
2a76526
not use find_good_eps for precondition only adaptation
xukai92 Jul 9, 2019
228e910
improve function names in tests
xukai92 Jul 10, 2019
4650d4c
breaking change: passed-in grad function is now supposed to return a t…
xukai92 Jul 10, 2019
2c4b46a
bugfix
xukai92 Jul 10, 2019
271af8f
actually use the cache mechanism; almost 2x speed-up
xukai92 Jul 10, 2019
4ab0ef7
add DiffResults to test deps
xukai92 Jul 10, 2019
db3faad
improve comments and naming
xukai92 Jul 10, 2019
2bf3943
update packages in env
xukai92 Jul 10, 2019
455be69
Code style updates - no functionality change.
yebai Jul 10, 2019
2b56945
Code style updates - no functionality change.
yebai Jul 10, 2019
e2e1e69
Unify `merge` function via kwargs - no functionality change.
yebai Jul 10, 2019
0222d10
Rename `sample` ==> `combine` - no functionality change.
yebai Jul 10, 2019
accbc55
Rename `merge` ==> `combine` - no functionality change.
yebai Jul 10, 2019
66d3fe4
Renamed `isUturn` ==> `isturn`, `iscontinued` ==> `isdivergent` - no …
yebai Jul 10, 2019
3c8c268
Make `∂H∂θ(h, θ)` return `DualValue` instead of `Tuple`.
yebai Jul 10, 2019
f23426c
make rng the first argument of combine if used
xukai92 Jul 10, 2019
62c0d78
update README.md
xukai92 Jul 10, 2019
4a842f6
double sample numbers in test
xukai92 Jul 10, 2019
202377f
Update README.md
yebai Jul 10, 2019
41 changes: 20 additions & 21 deletions Manifest.toml
@@ -10,10 +10,10 @@ version = "1.0.1"
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

[[CSTParser]]
deps = ["LibGit2", "Test", "Tokenize"]
git-tree-sha1 = "437c93bc191cd55957b3f8dee7794b6131997c56"
deps = ["Tokenize"]
git-tree-sha1 = "376a39f1862000442011390f1edf5e7f4dcc7142"
uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f"
version = "0.5.2"
version = "0.6.0"

[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
@@ -22,10 +22,10 @@ uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "2.1.0"

[[DataStructures]]
deps = ["InteractiveUtils", "OrderedCollections", "Random", "Serialization", "Test"]
git-tree-sha1 = "ca971f03e146cf144a9e2f2ce59674f5bf0e8038"
deps = ["InteractiveUtils", "OrderedCollections"]
git-tree-sha1 = "0809951a1774dc724da22d26e4289bbaab77809a"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.15.0"
version = "0.17.0"

[[Dates]]
deps = ["Printf"]
@@ -41,9 +41,9 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"

[[FillArrays]]
deps = ["LinearAlgebra", "Random", "SparseArrays", "Test"]
git-tree-sha1 = "4c9269860074f5e9cf3f9078c49228bf13e4b33b"
git-tree-sha1 = "9ab8f76758cbabba8d7f103c51dce7f73fcf8e92"
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "0.6.1"
version = "0.6.3"

[[InplaceOps]]
deps = ["LinearAlgebra", "Test"]
@@ -57,9 +57,9 @@ uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"

[[LazyArrays]]
deps = ["FillArrays", "LinearAlgebra", "MacroTools", "StaticArrays", "Test"]
git-tree-sha1 = "78f4dd7e19b21d82f3541a59a887ab92f0ab00b6"
git-tree-sha1 = "5eec856c454496abe8f4504227fcc187205a502a"
uuid = "5078a376-72f3-5289-bfd5-ec5146d43c02"
version = "0.8.1"
version = "0.9.0"

[[LibGit2]]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
@@ -114,10 +114,10 @@ deps = ["Unicode"]
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"

[[ProgressMeter]]
deps = ["Distributed", "Printf", "Random", "Test"]
git-tree-sha1 = "48058bc11607676e5bbc0b974af79106c6200787"
deps = ["Distributed", "Printf"]
git-tree-sha1 = "0f08e0e74e5b160ca20d3962a2620038b75881c7"
uuid = "92933f4c-e287-5a05-a399-4b506db050ca"
version = "0.9.0"
version = "1.0.0"

[[REPL]]
deps = ["InteractiveUtils", "Markdown", "Sockets"]
@@ -151,30 +151,29 @@ deps = ["LinearAlgebra", "Random"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[[StaticArrays]]
deps = ["InteractiveUtils", "LinearAlgebra", "Random", "Statistics", "Test"]
git-tree-sha1 = "3841b39ed5f047db1162627bf5f80a9cd3e39ae2"
deps = ["LinearAlgebra", "Random", "Statistics"]
git-tree-sha1 = "db23bbf50064c582b6f2b9b043c8e7e98ea8c0c6"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "0.10.3"
version = "0.11.0"

[[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[[StatsBase]]
deps = ["DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"]
git-tree-sha1 = "8a0f4b09c7426478ab677245ab2b0b68552143c7"
git-tree-sha1 = "2b6ca97be7ddfad5d9f16a13fe277d29f3d11c23"
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
version = "0.30.0"
version = "0.31.0"

[[Test]]
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[[Tokenize]]
deps = ["Printf", "Test"]
git-tree-sha1 = "3e83f60b74911d3042d3550884ca2776386a02b8"
git-tree-sha1 = "0de343efc07da00cd449d5b04e959ebaeeb3305d"
uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624"
version = "0.5.3"
version = "0.5.4"

[[UUIDs]]
deps = ["Random", "SHA"]
9 changes: 5 additions & 4 deletions Project.toml
@@ -13,15 +13,16 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[compat]
julia = ">= 1.0"

[extras]
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

[targets]
test = ["Test", "Distributions", "ForwardDiff", "Turing", "Distributed"]

[compat]
julia = ">= 1.0"
test = ["Test", "Distributions", "ForwardDiff", "Turing", "Distributed", "DiffResults"]
34 changes: 22 additions & 12 deletions README.md
@@ -2,36 +2,46 @@

[![Build Status](https://travis-ci.org/TuringLang/AdvancedHMC.jl.svg?branch=master)](https://travis-ci.org/TuringLang/AdvancedHMC.jl) [![Coverage Status](https://coveralls.io/repos/github/TuringLang/AdvancedHMC.jl/badge.svg?branch=kx%2Fbug-fix)](https://coveralls.io/github/TuringLang/AdvancedHMC.jl?branch=kx%2Fbug-fix)

**The code from this repository is used to implement HMC in [Turing.jl](https://github.com/yebai/Turing.jl). Try it out when it's available!**
**The code from this repository is used to implement HMC samplers in [Turing.jl](https://github.com/yebai/Turing.jl).**

**UPDATE**: The gradient function passed in to `Hamiltonian` is supposed to return a value-gradient tuple now!

## Minimal examples - sampling from a multivariate Gaussian using NUTS

```julia
using Distributions: MvNormal, logpdf
using ForwardDiff: gradient
using AdvancedHMC
### Define the target distribution and its gradient
using Distributions: logpdf, MvNormal
using DiffResults: GradientResult, value, gradient
using ForwardDiff: gradient!

# Define the target distribution and its gradient
const D = 10
const target = MvNormal(zeros(D), ones(D))
logπ(θ::AbstractVector{<:Real}) = logpdf(target, θ)
∂logπ∂θ(θ::AbstractVector{<:Real}) = gradient(logπ, θ)
ℓπ(θ) = logpdf(target, θ)

function ∂ℓπ∂θ(θ)
res = GradientResult(θ)
gradient!(res, ℓπ, θ)
return (value(res), gradient(res))
end

### Build up a HMC sampler to draw samples
using AdvancedHMC

# Sampling parameter settings
n_samples = 100_000
n_adapts = 2_000

# Initial points
# Draw a random starting point
θ_init = randn(D)

# Define metric space, Hamiltonian and sampling method
metric = DenseEuclideanMetric(D)
h = Hamiltonian(metric, logπ, ∂logπ∂θ)
# Define metric space, Hamiltonian, sampling method and adaptor
metric = DiagEuclideanMetric(D)
h = Hamiltonian(metric, ℓπ, ∂ℓπ∂θ)
prop = NUTS(Leapfrog(find_good_eps(h, θ_init)))
adaptor = StanHMCAdaptor(n_adapts, Preconditioner(metric), NesterovDualAveraging(0.8, prop.integrator.ϵ))

# Sampling
samples = sample(h, prop, θ_init, n_samples, adaptor, n_adapts)
samples = sample(h, prop, θ_init, n_samples, adaptor, n_adapts; progress=true)
```
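A quick sanity check one might run after the example above (a sketch, assuming `samples` is the returned vector of parameter vectors): the empirical mean and per-dimension variance should be close to the target's zero mean and unit variances.

```julia
# Sanity-check sketch for the example above -- assumes `samples` is a
# Vector of D-dimensional parameter vectors returned by `sample`.
using Statistics: mean, var

θ̄ = mean(samples)                   # elementwise mean over draws, ≈ zeros(D)
s² = var(hcat(samples...); dims=2)   # per-dimension variance, ≈ ones(D)
```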

## Reference
28 changes: 16 additions & 12 deletions src/hamiltonian.jl
@@ -10,29 +10,33 @@ end
# Create a `Hamiltonian` with a new `M⁻¹`
(h::Hamiltonian)(M⁻¹) = Hamiltonian(h.metric(M⁻¹), h.ℓπ, h.∂ℓπ∂θ)

∂H∂θ(h::Hamiltonian, θ::AbstractVector) = -h.∂ℓπ∂θ(θ)

∂H∂r(h::Hamiltonian{<:UnitEuclideanMetric}, r::AbstractVector) = copy(r)
∂H∂r(h::Hamiltonian{<:DiagEuclideanMetric}, r::AbstractVector) = h.metric.M⁻¹ .* r
∂H∂r(h::Hamiltonian{<:DenseEuclideanMetric}, r::AbstractVector) = h.metric.M⁻¹ * r

struct DualValue{Tv<:AbstractFloat, Tg<:AbstractVector{Tv}}
value::Tv # Cached value, e.g. logπ(θ).
gradient::Tg # Cached gradient, e.g. ∇logπ(θ).
end

# `∂H∂θ` now returns `(logprob, -∂ℓπ∂θ)`
function ∂H∂θ(h::Hamiltonian, θ::AbstractVector)
res = h.∂ℓπ∂θ(θ)
return DualValue(res[1], -res[2])
end

∂H∂r(h::Hamiltonian{<:UnitEuclideanMetric}, r::AbstractVector) = copy(r)
∂H∂r(h::Hamiltonian{<:DiagEuclideanMetric}, r::AbstractVector) = h.metric.M⁻¹ .* r
∂H∂r(h::Hamiltonian{<:DenseEuclideanMetric}, r::AbstractVector) = h.metric.M⁻¹ * r

struct PhasePoint{T<:AbstractVector, V<:DualValue}
θ::T # Position variables / model parameters.
r::T # Momentum variables
ℓπ::V # Cached neg potential energy for the current θ.
ℓκ::V # Cached neg kinetic energy for the current r.
function PhasePoint(θ::T, r::T, ℓπ::V, ℓκ::V) where {T,V}
@argcheck length(θ) == length(r) == length(ℓπ.gradient) == length(ℓκ.gradient)
# if !(all(isfinite, θ) && all(isfinite, r) && all(isfinite, ℓπ) && all(isfinite, ℓκ))
if !(isfinite(θ) && isfinite(r) && isfinite(ℓπ) && isfinite(ℓκ))
@warn "The current proposal will be rejected (due to numerical error(s))..."
ℓκ = DualValue(-Inf, ℓκ.gradient)
isfiniteθ, isfiniter, isfiniteℓπ, isfiniteℓκ = isfinite(θ), isfinite(r), isfinite(ℓπ), isfinite(ℓκ)
if !(isfiniteθ && isfiniter && isfiniteℓπ && isfiniteℓκ)
@warn "The current proposal will be rejected due to numerical error(s)." isfiniteθ isfiniter isfiniteℓπ isfiniteℓκ
ℓπ = DualValue(-Inf, ℓπ.gradient)
ℓκ = DualValue(-Inf, ℓκ.gradient)
end
new{T,V}(θ, r, ℓπ, ℓκ)
end
@@ -42,8 +46,8 @@ phasepoint(
h::Hamiltonian,
θ::T,
r::T;
ℓπ = DualValue(neg_energy(h, r, θ), ∂H∂θ(h, θ)),
ℓκ = DualValue(neg_energy(h, θ), ∂H∂r(h, r))
ℓπ=∂H∂θ(h, θ),
ℓκ=DualValue(neg_energy(h, r, θ), ∂H∂r(h, r))
) where {T<:AbstractVector} = PhasePoint(θ, r, ℓπ, ℓκ)


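To make the value/gradient caching contract above concrete, here is a compact standalone sketch (simplified types and hypothetical `_demo` names; the real definitions are in the diff above): the user-supplied gradient function returns a value–gradient tuple, and `∂H∂θ` wraps it in a `DualValue` with the gradient negated, since the potential energy is `-ℓπ(θ)`.

```julia
# Standalone sketch of the value-gradient caching contract -- illustrative only.
struct DualValueDemo{Tv<:AbstractFloat, Tg<:AbstractVector{Tv}}
    value::Tv      # cached scalar, e.g. ℓπ(θ)
    gradient::Tg   # cached gradient, e.g. ∇ℓπ(θ)
end

# Toy target: standard normal, so ℓπ(θ) = -θ'θ/2 up to an additive constant.
ℓπ_demo(θ) = -sum(abs2, θ) / 2
∂ℓπ∂θ_demo(θ) = (ℓπ_demo(θ), -θ)   # the new value-gradient tuple contract

# ∂H∂θ caches the log-density and negates the gradient.
function ∂H∂θ_demo(θ)
    v, g = ∂ℓπ∂θ_demo(θ)
    return DualValueDemo(v, -g)
end

dv = ∂H∂θ_demo(randn(3))
dv.value, dv.gradient
```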
10 changes: 5 additions & 5 deletions src/integrator.jl
@@ -25,14 +25,14 @@ function step(
@unpack θ, r = z
ϵ = fwd ? lf.ϵ : -lf.ϵ

∇θ = ∂H∂θ(h, θ)
@unpack value, gradient = ∂H∂θ(h, θ)
for i = 1:abs(n_steps)
r = r - ϵ/2 * ∇θ # Take a half leapfrog step for momentum variable
r = r - ϵ/2 * gradient # Take a half leapfrog step for momentum variable
∇r = ∂H∂r(h, r)
θ = θ + ϵ * ∇r # Take a full leapfrog step for position variable
∇θ = ∂H∂θ(h, θ)
r = r - ϵ/2 * ∇θ # Take a half leapfrog step for momentum variable
z = phasepoint(h, θ, r)
@unpack value, gradient = ∂H∂θ(h, θ)
r = r - ϵ/2 * gradient # Take a half leapfrog step for momentum variable
z = phasepoint(h, θ, r; ℓπ=DualValue(value, gradient))
!isfinite(z) && break
end
return z
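The diff above reuses the freshly computed value and gradient when building the new phase point instead of recomputing the log-density inside `phasepoint` (the "almost 2x speed-up" commit). Below is a standalone leapfrog with the same caching pattern, assuming a unit metric and hypothetical `_demo` names.

```julia
# Standalone leapfrog sketch with gradient caching -- unit metric (M⁻¹ = I),
# hypothetical names, illustrative only.
function leapfrog_demo(θ, r, ϵ, n_steps, ∂ℓπ∂θ)
    _, g = ∂ℓπ∂θ(θ)              # compute the gradient once up front
    for _ in 1:n_steps
        r = r + (ϵ / 2) .* g     # half step for momentum (∂H∂θ = -∇ℓπ)
        θ = θ + ϵ .* r           # full step for position (∂H∂r = r under a unit metric)
        _, g = ∂ℓπ∂θ(θ)          # recompute once; cached for the next half step
        r = r + (ϵ / 2) .* g     # second half step for momentum
    end
    return θ, r
end

ℓπ_demo(θ) = -sum(abs2, θ) / 2
∂ℓπ∂θ_demo(θ) = (ℓπ_demo(θ), -θ)
θ_new, r_new = leapfrog_demo(randn(3), randn(3), 0.05, 10, ∂ℓπ∂θ_demo)
```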
21 changes: 21 additions & 0 deletions src/sampler.jl
@@ -13,6 +13,27 @@ sample(
progress::Bool=false
) where {T<:Real} = sample(GLOBAL_RNG, h, τ, θ, n_samples, adaptor, n_adapts; verbose=verbose, progress=progress)

"""
sample(
rng::AbstractRNG,
h::Hamiltonian,
τ::AbstractProposal,
θ::AbstractVector{T},
n_samples::Int,
adaptor::Adaptation.AbstractAdaptor=Adaptation.NoAdaptation(),
n_adapts::Int=min(div(n_samples, 10), 1_000);
verbose::Bool=true,
progress::Bool=false
)

Sample `n_samples` samples using the proposal `τ` under Hamiltonian `h`.
- the initial point is given by `θ`
- the randomness is controlled by `rng`
- the adaptor is set by `adaptor`, for which the default is no adaptation
- it will perform `n_adapts` steps of adaptation, for which the default is the minimum of `1_000` and 10% of `n_samples`
- the verbosity is controlled by the boolean variable `verbose` and
- the visibility of the progress meter is controlled by the boolean variable `progress`
"""
function sample(
rng::AbstractRNG,
h::Hamiltonian,
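A usage sketch of the documented signature with an explicit RNG for reproducibility, reusing `h`, `prop`, `θ_init`, `n_samples`, `adaptor`, and `n_adapts` as defined in the README example earlier:

```julia
# Usage sketch of the documented `sample` signature; assumes the objects from
# the README example (h, prop, θ_init, n_samples, adaptor, n_adapts) are defined.
using Random: MersenneTwister

rng = MersenneTwister(1)
samples = sample(rng, h, prop, θ_init, n_samples, adaptor, n_adapts;
                 verbose=true, progress=false)
```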