Support multinomial sampling for NUTS, and more #79

Merged
merged 41 commits into from
Jul 10, 2019
Changes from 10 commits
4c8cbc9
Initial design for new `Trajectory` and `Termination` types.
yebai May 28, 2019
8f962ea
fix some type definitions
xukai92 Jul 5, 2019
df44c2a
rename NUTS_Termination to NoUTurnTermination
xukai92 Jul 5, 2019
1dde236
introduce DoublingTree type
xukai92 Jul 5, 2019
820fc68
fix RFC typo
xukai92 Jul 5, 2019
1863da3
initial DoublingTree impl
xukai92 Jul 6, 2019
675d5c0
add a isUturn function
xukai92 Jul 6, 2019
a671fa8
add a comment
xukai92 Jul 6, 2019
c0ea831
improve NUTS type
xukai92 Jul 6, 2019
71694c1
improve abstraction
xukai92 Jul 6, 2019
fd4e79b
bugfix for MH step and add more comments (#75)
xukai92 Jul 6, 2019
3c9140e
draft a docstring for the sample function (#75)
xukai92 Jul 6, 2019
6cd458d
RFC find_good_eps (#27) and add some more comments (#75)
xukai92 Jul 6, 2019
c261e13
add some preparation code
xukai92 Jul 6, 2019
608e7be
indicate the numerical error location (#71)
xukai92 Jul 6, 2019
9e9e147
Merge branch 'kx/multinomial' of github.com:TuringLang/AdvancedHMC.jl…
xukai92 Jul 7, 2019
28bcc82
Re-organise code to improve readability - no functionality change.
yebai Jul 8, 2019
2a2f428
Re-organise code to improve readability - no functionality change.
yebai Jul 8, 2019
4a582e6
Merge branch 'master' into kx/multinomial
yebai Jul 8, 2019
fe2efcf
remove replace sampling with samplerType
xukai92 Jul 9, 2019
046abd7
revert naming
xukai92 Jul 9, 2019
c828a04
improve test script
xukai92 Jul 9, 2019
2a76526
not use find_good_eps for precondition only adaptation
xukai92 Jul 9, 2019
228e910
improve function names in tests
xukai92 Jul 10, 2019
4650d4c
breaking change: passed-in grad function is now supposed to return a t…
xukai92 Jul 10, 2019
2c4b46a
bugfix
xukai92 Jul 10, 2019
271af8f
actually use the cache mechanism; almost 2x speed-up
xukai92 Jul 10, 2019
4ab0ef7
add DiffResults to test deps
xukai92 Jul 10, 2019
db3faad
improve comments and naming
xukai92 Jul 10, 2019
2bf3943
update packages in env
xukai92 Jul 10, 2019
455be69
Code style updates - no functionality change.
yebai Jul 10, 2019
2b56945
Code style updates - no functionality change.
yebai Jul 10, 2019
e2e1e69
Unify `merge` function via kwargs - no functionality change.
yebai Jul 10, 2019
0222d10
Rename `sample` ==> `combine` - no functionality change.
yebai Jul 10, 2019
accbc55
Rename `merge` ==> `combine` - no functionality change.
yebai Jul 10, 2019
66d3fe4
Renamed `isUturn` ==> `isturn`, `iscontinued` ==> `isdivergent` - no …
yebai Jul 10, 2019
3c8c268
Make `∂H∂θ(h, θ)` return `DualValue` instead of `Tuple`.
yebai Jul 10, 2019
f23426c
make rng the first argument of combine if used
xukai92 Jul 10, 2019
62c0d78
update README.md
xukai92 Jul 10, 2019
4a842f6
double sample numbers in test
xukai92 Jul 10, 2019
202377f
Update README.md
yebai Jul 10, 2019
254 changes: 196 additions & 58 deletions src/trajectory.jl
@@ -13,6 +13,70 @@ transition(
z::PhasePoint
) where {I<:AbstractIntegrator} = transition(GLOBAL_RNG, τ, h, z)

#################################
### Hong's abstraction starts ###
#################################

###
### Create a `Termination` type for each `Trajectory` type, e.g. HMC, NUTS etc.
### Merge all `Trajectory` types, and make `transition` dispatch on `Termination`,
### such that we can overload `transition` for different HMC samplers.
### NOTE: stopping criteria, max_depth::Int, Δ_max::AbstractFloat, n_steps, λ
###

"""
Abstract type for termination.
"""
abstract type AbstractTermination end

# Termination type for HMC and HMCDA
struct StaticTermination{D<:AbstractFloat} <: AbstractTermination
n_steps :: Int
Δ_max :: D
end

# NoUTurnTermination
struct NoUTurnTermination{D<:AbstractFloat} <: AbstractTermination
max_depth :: Int
Δ_max :: D
# TODO: add other necessary fields for No-U-Turn stopping criteria.
end

struct Trajectory{I<:AbstractIntegrator} <: AbstractTrajectory{I}
integrator :: I
n_steps :: Int # Counter for total leapfrog steps already applied.
Δ :: AbstractFloat # Current Hamiltonian energy minus starting Hamiltonian energy
# TODO: replace all `*Trajectory` types with `Trajectory`.
# TODO: add turn statistic, divergent statistic, proposal statistic
end

isterminated(
x::StaticTermination,
τ::Trajectory
) = τ.n_steps >= x.n_steps || τ.Δ >= x.Δ_max

# Combine trajectories, e.g. those created by the build_tree algorithm.
# NOTE: combine proposal (via slice/multinomial sampling), combine turn statistic,
# and combine divergent statistic.
combine_trajectory(τ′::Trajectory, τ′′::Trajectory) = nothing # To-be-implemented.

## TODO: move slice variable `logu` into `Trajectory`?
combine_proposal(τ′::Trajectory, τ′′::Trajectory) = nothing # To-be-implemented.
combine_turn(τ′::Trajectory, τ′′::Trajectory) = nothing # To-be-implemented.
combine_divergence(τ′::Trajectory, τ′′::Trajectory) = nothing # To-be-implemented.

###############################
### Hong's abstraction ends ###
###############################

transition(
τ::Trajectory{I},
h::Hamiltonian,
z::PhasePoint,
t::T
) where {I<:AbstractIntegrator,T<:AbstractTermination} = nothing
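To illustrate the dispatch idea behind `isterminated`, here is a minimal, self-contained sketch. `TrajectoryState` is a hypothetical stand-in for the `Trajectory` type in this hunk (it drops the integrator field and uses a concrete float type); it is not part of the PR.

```julia
# Simplified stand-ins for the termination types introduced in this diff.
abstract type AbstractTermination end

struct StaticTermination{D<:AbstractFloat} <: AbstractTermination
    n_steps::Int   # maximum number of leapfrog steps
    Δ_max::D       # maximum tolerated energy error
end

# Hypothetical, reduced trajectory state (assumption: no integrator needed here).
struct TrajectoryState{D<:AbstractFloat}
    n_steps::Int   # leapfrog steps taken so far
    Δ::D           # current energy minus starting energy
end

# Stop once the step budget is used up or the energy error is too large.
isterminated(x::StaticTermination, τ::TrajectoryState) =
    τ.n_steps >= x.n_steps || τ.Δ >= x.Δ_max

term = StaticTermination(10, 1000.0)
isterminated(term, TrajectoryState(3, 0.5))      # false: within both limits
isterminated(term, TrajectoryState(10, 0.5))     # true: step budget reached
isterminated(term, TrajectoryState(3, 2000.0))   # true: energy error too large
```

A No-U-Turn variant would presumably add another `isterminated` method for `NoUTurnTermination`, leaving callers unchanged.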


###
### Standard HMC implementation with fixed leapfrog step numbers.
###
@@ -21,11 +85,6 @@ struct StaticTrajectory{I<:AbstractIntegrator} <: AbstractTrajectory{I}
n_steps :: Int
end

"""
Termination (i.e. no-U-turn).
"""
struct Termination end

"""
Create a `StaticTrajectory` with a new integrator
"""
@@ -82,121 +141,200 @@ end
### Advanced HMC implementation with (adaptive) dynamic trajectory length.
###

# Types for slice and multinomial sampling; dispatching on these types lets the compiler eliminate the branches between sampling methods
abstract type AbstractNUTSSampling end

struct SliceNUTSSampling <: AbstractNUTSSampling end
struct MultinomialNUTSSampling <: AbstractNUTSSampling end
const SUPPORTED_NUTS_SAMPLING = Dict(:slice => SliceNUTSSampling(), :multinomial => MultinomialNUTSSampling())

abstract type AbstractNUTSSampler end

struct SliceNUTSSampler{F<:AbstractFloat} <: AbstractNUTSSampler
logu :: F # slice variable in log space
n :: Int # number of acceptable candidates, i.e. those whose probability is larger than the slice variable u
end
struct MultinomialNUTSSampler{F<:AbstractFloat} <: AbstractNUTSSampler
w :: F # total energy for the given tree, i.e. sum of energy of all leaves
end

combine(s1::SliceNUTSSampler, s2::SliceNUTSSampler) = SliceNUTSSampler(s1.logu, s1.n + s2.n)
combine(s1::MultinomialNUTSSampler, s2::MultinomialNUTSSampler) = MultinomialNUTSSampler(s1.w + s2.w)

"""
Dynamic trajectory HMC using the no-U-turn termination criterion.
"""
struct NUTS{I<:AbstractIntegrator} <: DynamicTrajectory{I}
struct NUTS{I<:AbstractIntegrator,F<:AbstractFloat,S<:AbstractNUTSSampling} <: DynamicTrajectory{I}
integrator :: I
max_depth :: Int
Δ_max :: AbstractFloat
Δ_max :: F
sampling :: S
end


# Helper function to use default values
NUTS(integrator::AbstractIntegrator) = NUTS(integrator, 10, 1000.0)
function NUTS(
integrator::AbstractIntegrator,
max_depth::Int=10,
Δ_max::AbstractFloat=1000.0;
sampling::Symbol=:multinomial
)
@assert sampling in keys(SUPPORTED_NUTS_SAMPLING) "NUTS only supports the following sampling methods: $(keys(SUPPORTED_NUTS_SAMPLING))"
return NUTS(integrator, max_depth, Δ_max, SUPPORTED_NUTS_SAMPLING[sampling])
end
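The constructor above maps a `Symbol` keyword onto a singleton type so that later code can dispatch on the sampling method. A minimal sketch of that pattern follows. The names `SliceMethod`, `MultinomialMethod`, `SAMPLING_METHODS`, `sampling_method`, and `describe` are hypothetical, and throwing an `ArgumentError` instead of using `@assert` is a suggestion rather than what the PR does.

```julia
abstract type AbstractSamplingMethod end
struct SliceMethod <: AbstractSamplingMethod end
struct MultinomialMethod <: AbstractSamplingMethod end

# Lookup table from the user-facing keyword to the singleton instance.
const SAMPLING_METHODS = Dict{Symbol,AbstractSamplingMethod}(
    :slice => SliceMethod(),
    :multinomial => MultinomialMethod(),
)

function sampling_method(name::Symbol)
    haskey(SAMPLING_METHODS, name) || throw(ArgumentError(
        "unsupported sampling method :$name; expected one of $(collect(keys(SAMPLING_METHODS)))"))
    return SAMPLING_METHODS[name]
end

# Behaviour is selected by dispatch, with no runtime `if` on the keyword.
describe(::SliceMethod) = "slice"
describe(::MultinomialMethod) = "multinomial"

describe(sampling_method(:multinomial))   # "multinomial"
```

The Julia documentation notes that `@assert` may be disabled at some optimization levels, so an explicit `throw` is arguably the more robust way to validate a user-supplied keyword.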

"""
Create a `NUTS` with a new integrator
"""
function (snuts::NUTS)(integrator::AbstractIntegrator)
return NUTS(integrator, snuts.max_depth, snuts.Δ_max)
function (nuts::NUTS)(integrator::AbstractIntegrator)
return NUTS(integrator, nuts.max_depth, nuts.Δ_max, nuts.sampling)
end


###
### The doubling tree algorithm for expanding trajectory.
###

# TODO: implement a more efficient way to build the balanced tree
struct DoublingTree{S<:AbstractNUTSSampler}
zleft # leftmost leaf node
zright # rightmost leaf node
zcand # candidate leaf node
sampler::S # candidate sampler
s # termination stats, i.e. 0 means termination and 1 means continuation
α # MH stats, i.e. sum of MH accept prob for all leapfrog steps
nα # total # of leapfrog steps, i.e. phase points in a trajectory
end
# TODO: merge DoublingTree and Trajectory

function isUturn(h::Hamiltonian, zleft::PhasePoint, zright::PhasePoint)
θdiff = zright.θ - zleft.θ
return (dot(θdiff, ∂H∂r(h, zleft.r)) >= 0 ? 1 : 0) * (dot(θdiff, ∂H∂r(h, zright.r)) >= 0 ? 1 : 0)
end
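As written in this diff, `isUturn` returns `1` while both dot products are non-negative, i.e. while the trajectory is still moving apart and has not turned back. The sketch below reproduces that check on plain vectors. It assumes a unit mass matrix so that `∂H∂r` reduces to the momentum itself, and the name `is_not_turning` is hypothetical.

```julia
using LinearAlgebra: dot

# Assumption: unit mass matrix, so the velocity ∂H/∂r equals the momentum r.
function is_not_turning(θleft, rleft, θright, rright)
    θdiff = θright - θleft
    # Continue only while both ends still move away from each other.
    return dot(θdiff, rleft) >= 0 && dot(θdiff, rright) >= 0
end

# Both momenta point along the displacement: keep expanding.
is_not_turning([0.0, 0.0], [1.0, 0.0], [1.0, 0.0], [1.0, 0.0])    # true
# The right end now moves back toward the left end: stop.
is_not_turning([0.0, 0.0], [1.0, 0.0], [1.0, 0.0], [-1.0, 0.0])   # false
```

Returning a `Bool` rather than `0`/`1` is a stylistic choice here; the diff multiplies the integer flags together instead.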

function sample(rng::AbstractRNG, dtleft::DoublingTree{SliceNUTSSampler{F}}, dtright::DoublingTree{SliceNUTSSampler{F}}) where {F<:AbstractFloat}
return rand(rng) < dtleft.sampler.n / (dtleft.sampler.n + dtright.sampler.n) ? dtleft.zcand : dtright.zcand
end
function sample(rng::AbstractRNG, dtleft::DoublingTree{MultinomialNUTSSampler{F}}, dtright::DoublingTree{MultinomialNUTSSampler{F}}) where {F<:AbstractFloat}
return rand(rng) < dtleft.sampler.w / (dtleft.sampler.w + dtright.sampler.w) ? dtleft.zcand : dtright.zcand
end
sample(dtleft::DoublingTree, dtright::DoublingTree) = sample(GLOBAL_RNG, dtleft, dtright)
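Both `sample` methods above pick the left candidate with probability proportional to the left subtree's statistic (`n` for slice sampling, `w` for multinomial sampling). The following sketch isolates that rule. `pick_candidate` and the symbols used as candidates are hypothetical placeholders for the phase points in the PR.

```julia
using Random

# Choose the left candidate with probability w_left / (w_left + w_right).
function pick_candidate(rng::AbstractRNG, cand_left, w_left::Real, cand_right, w_right::Real)
    return rand(rng) < w_left / (w_left + w_right) ? cand_left : cand_right
end

rng = MersenneTwister(1)
draws = [pick_candidate(rng, :left, 3.0, :right, 1.0) for _ in 1:10_000]

# With weights 3 and 1, roughly three quarters of the draws should be :left.
frac_left = count(d -> d === :left, draws) / length(draws)
```

Passing the RNG explicitly, as the PR does, keeps the selection reproducible under a fixed seed.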

function merge(rng::AbstractRNG, h::Hamiltonian, dtleft::DoublingTree, dtright::DoublingTree)
zleft = dtleft.zleft
zright = dtright.zright
zcand = sample(rng, dtleft, dtright)
sampler = combine(dtleft.sampler, dtright.sampler)
s = dtleft.s * dtright.s * isUturn(h, zleft, zright)
return DoublingTree(zleft, zright, zcand, sampler, s, dtleft.α + dtright.α, dtleft.nα + dtright.nα)
end

"""
merge(h::Hamiltonian, dtleft::DoublingTree, dtright::DoublingTree)

Merge a left tree `dtleft` and a right tree `dtright` under given Hamiltonian `h`.
"""
merge(h::Hamiltonian, dtleft::DoublingTree, dtright::DoublingTree) = merge(GLOBAL_RNG, h, dtleft, dtright)
iscontinued(s::SliceNUTSSampler, nt::NUTS, H::AbstractFloat, H′::AbstractFloat) = (s.logu < nt.Δ_max + -H′) ? 1 : 0
# REVIEW: @Hong can you please double check if the implementation below is correct
iscontinued(s::MultinomialNUTSSampler, nt::NUTS, H::AbstractFloat, H′::AbstractFloat) = (-H < nt.Δ_max + -H′) ? 1 : 0

BaseSampler(s::SliceNUTSSampler, H::AbstractFloat) = SliceNUTSSampler(s.logu, (s.logu <= -H) ? 1 : 0)
BaseSampler(s::MultinomialNUTSSampler, H::AbstractFloat) = MultinomialNUTSSampler(H)

function build_tree(
rng::AbstractRNG,
nt::DynamicTrajectory{I},
nt::NUTS{I,F,S},
h::Hamiltonian,
z::PhasePoint,
logu::AbstractFloat,
sampler::AbstractNUTSSampler,
v::Int,
j::Int,
H::AbstractFloat
) where {I<:AbstractIntegrator,T<:Real}
) where {I<:AbstractIntegrator,F<:AbstractFloat,S<:AbstractNUTSSampling}
if j == 0
# Base case - take one leapfrog step in the direction v.
z′ = step(nt.integrator, h, z, v)
H′ = -neg_energy(z′)
n′ = (logu <= -H′) ? 1 : 0
s′ = (logu < nt.Δ_max + -H′) ? 1 : 0
sampler = BaseSampler(sampler, H′)
s′ = iscontinued(sampler, nt, H, H′)
α′ = exp(min(0, H - H′))

return z′, z′, z′, n′, s′, α′, 1
return DoublingTree(z′, z′, z′, sampler, s′, α′, 1)
else
# Recursion - build the left and right subtrees.
zm, zp, z′, n′, s′, α′, n′α = build_tree(rng, nt, h, z, logu, v, j - 1, H)

if s′ == 1
dt′ = build_tree(rng, nt, h, z, sampler, v, j - 1, H)
# Expand tree if not terminated
if dt′.s == 1
# Expand left
if v == -1
zm, _, z′′, n′′, s′′, α′′, n′′α = build_tree(rng, nt, h, zm, logu, v, j - 1, H)
else
_, zp, z′′, n′′, s′′, α′′, n′′α = build_tree(rng, nt, h, zp, logu, v, j - 1, H)
end
if rand(rng) < n′′ / (n′ + n′′)
z′ = z′′
dt′′ = build_tree(rng, nt, h, dt′.zleft, sampler, v, j - 1, H) # left tree
dtleft, dtright = dt′′, dt′
# Expand right
else
dt′′ = build_tree(rng, nt, h, dt′.zright, sampler, v, j - 1, H) # right tree
dtleft, dtright = dt′, dt′′
end
α′ = α′ + α′′
n′α = n′α + n′′α
s′ = s′′ * (dot(zp.θ - zm.θ, ∂H∂r(h, zm.r)) >= 0 ? 1 : 0) * (dot(zp.θ - zm.θ, ∂H∂r(h, zp.r)) >= 0 ? 1 : 0)
n′ = n′ + n′′
dt′ = merge(rng, h, dtleft, dtright)
end

# s: termination stats
# α: MH stats, i.e. sum of MH accept prob for all leapfrog steps
# nα: total # of leapfrog steps, i.e. phase points in a trajectory
# n: # of acceptable candidates, i.e. prob is larger than slice variable u
return zm, zp, z′, n′, s′, α′, n′α
return dt′
end
end
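The base case of `build_tree` records `α′ = exp(min(0, H - H′))`, the Metropolis-Hastings acceptance probability of a single leapfrog step relative to the starting energy `H`. A small sketch of that statistic and of the averaged ratio of `α` to `nα` follows. The function name `accept_prob` and the energy values are made up for illustration.

```julia
# Acceptance probability of a step from energy H to energy H′, capped at 1.
accept_prob(H, H′) = exp(min(0, H - H′))

accept_prob(1.0, 1.0)   # 1.0: energy unchanged
accept_prob(1.0, 0.5)   # 1.0: energy decreased, probability is capped at 1

# Hypothetical energies at three leapfrog steps of one trajectory.
H0 = 1.0
energies = [1.0, 0.5, 2.0]

α  = sum(accept_prob(H0, H′) for H′ in energies)   # summed acceptance probabilities
nα = length(energies)                              # number of leapfrog steps
mean_accept = α / nα                               # average acceptance statistic
```

The averaged value is the kind of statistic a step-size adaptation scheme would typically consume, although how this PR uses it downstream is not shown in this hunk.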

"""
Recursively build a tree for a given depth `j`.
"""
build_tree(
nt::DynamicTrajectory{I},
nt::NUTS{I,F,S},
h::Hamiltonian,
z::PhasePoint,
logu::AbstractFloat,
sampler::AbstractNUTSSampler,
v::Int,
j::Int,
H::AbstractFloat
) where {I<:AbstractIntegrator,T<:Real} = build_tree(GLOBAL_RNG, nt, h, z, logu, v, j, H)
) where {I<:AbstractIntegrator,F<:AbstractFloat,S<:AbstractNUTSSampling} = build_tree(GLOBAL_RNG, nt, h, z, sampler, v, j, H)

mh_accept(rng::AbstractRNG, s::SliceNUTSSampler, s′::SliceNUTSSampler) = mh_accept(rng, s.n, s′.n)
mh_accept(rng::AbstractRNG, s::MultinomialNUTSSampler, s′::MultinomialNUTSSampler) = mh_accept(rng, s.w, s′.w)
mh_accept(s::AbstractNUTSSampler, s′::AbstractNUTSSampler) = mh_accept(GLOBAL_RNG, s, s′)

InitSampler(rng::AbstractRNG, ::SliceNUTSSampling, H::AbstractFloat) = SliceNUTSSampler(log(rand(rng)) - H, 1)
InitSampler(rng::AbstractRNG, ::MultinomialNUTSSampling, H::AbstractFloat) = MultinomialNUTSSampler(H)
InitSampler(s::AbstractNUTSSampling, H::AbstractFloat) = InitSampler(GLOBAL_RNG, s, H)
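`InitSampler` for the slice variant draws the slice variable in log space as `log(rand(rng)) - H` and starts the candidate count at 1. The sketch below shows why the starting point always counts as acceptable under the test `logu <= -H′` used in `BaseSampler`. The helper `is_acceptable` and the value of `H` are hypothetical.

```julia
using Random

rng = MersenneTwister(42)
H = 2.5                          # hypothetical energy of the starting point

# log(u) for u drawn uniformly below exp(-H), computed without leaving log space.
logu = log(rand(rng)) - H

# A phase point with energy H′ lies inside the slice when logu <= -H′.
is_acceptable(logu, H′) = logu <= -H′

# log(rand(rng)) is never positive, so the starting point is always inside the slice.
is_acceptable(logu, H)           # true
```

Working in log space avoids forming `exp(-H)` directly, which can underflow for large energies.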

function transition(
rng::AbstractRNG,
nt::DynamicTrajectory{I},
nt::NUTS{I,F,S},
h::Hamiltonian,
z::PhasePoint
) where {I<:AbstractIntegrator,T<:Real}
θ, r = z.θ, z.r
) where {I<:AbstractIntegrator,F<:AbstractFloat,S<:AbstractNUTSSampling}
H = -neg_energy(z)
logu = log(rand(rng)) - H

zm = z; zp = z; z_new = z; j = 0; n = 1; s = 1
zleft = z; zright = z; zcand = z; j = 0; s = 1; sampler = InitSampler(rng, nt.sampling, H)

local α, nα
local dt
while s == 1 && j <= nt.max_depth
# Sample a direction; `-1` means left and `1` means right
v = rand(rng, [-1, 1])
if v == -1
zm, _, z′, n′, s′, α, nα = build_tree(rng, nt, h, zm, logu, v, j, H)
# Create a tree with depth `j` on the left
dt = build_tree(rng, nt, h, zleft, sampler, v, j, H)
zleft = dt.zleft
else
zm, _, z′, n′, s′, α, nα = build_tree(rng, nt, h, zm, logu, v, j, H)
# Create a tree with depth `j` on the right
dt = build_tree(rng, nt, h, zright, sampler, v, j, H)
zright = dt.zright
end

if s′ == 1
if rand(rng) < min(1, n′ / n)
z_new = z′
end
# Perform a MH step if not terminated
if dt.s == 1 && mh_accept(rng, sampler, dt.sampler)[1]
zcand = dt.zcand
end

n = n + n′
s = s′ * (dot(zp.θ - zm.θ, ∂H∂r(h, zm.r)) >= 0 ? 1 : 0) * (dot(zp.θ - zm.θ, ∂H∂r(h, zp.r)) >= 0 ? 1 : 0)
# Combine the sampler from the proposed tree and the current tree
sampler = combine(sampler, dt.sampler)
# Detect termination
s = s * dt.s * isUturn(h, zleft, zright)
# Increment tree depth
j = j + 1
end

return z_new, α / nα
return zcand, dt.α / dt.nα
end

###
3 changes: 2 additions & 1 deletion test/hmc.jl
@@ -21,7 +21,8 @@ n_adapts = 2_000
@testset "$(typeof(τ))" for τ in [
StaticTrajectory(Leapfrog(ϵ), n_steps),
HMCDA(Leapfrog(ϵ), ϵ * n_steps),
NUTS(Leapfrog(find_good_eps(h, θ_init))),
NUTS(Leapfrog(find_good_eps(h, θ_init)); sampling=:multinomial),
NUTS(Leapfrog(find_good_eps(h, θ_init)); sampling=:slice),
]
samples = sample(h, τ, θ_init, n_samples; verbose=false, progress=PROGRESS)
@test mean(samples[n_adapts+1:end]) ≈ zeros(D) atol=RNDATOL