Bug-fixing in dual averaging for HMC #709

Merged: 10 commits, Mar 12, 2019
14 changes: 7 additions & 7 deletions src/core/ad.jl
@@ -64,8 +64,8 @@ gradient_logp(
sampler::AbstractSampler=SampleFromPrior(),
)

Computes the value of the log joint of `θ` and its gradient for the model
specified by `(vi, sampler, model)` using whichever automatic differentiation
Computes the value of the log joint of `θ` and its gradient for the model
specified by `(vi, sampler, model)` using whichever automatic differentiation
tool is currently active.
"""
function gradient_logp(
@@ -76,7 +76,7 @@ function gradient_logp(
) where {TS <: Sampler}

ad_type = getADtype(TS)
if ad_type <: ForwardDiffAD
if ad_type <: ForwardDiffAD
return gradient_logp_forward(θ, vi, model, sampler)
    elseif ad_type <: FluxTrackerAD
return gradient_logp_reverse(θ, vi, model, sampler)
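
As a side note, the dispatch on `ad_type` above chooses between forward-mode AD (ForwardDiff.jl) and reverse-mode AD (Flux's Tracker). A hypothetical, self-contained illustration of that split on a toy log joint, not Turing's actual API, assuming ForwardDiff.jl and Tracker.jl are installed:

import ForwardDiff
import Tracker

logjoint(θ) = -0.5 * sum(abs2, θ)   # toy log joint: standard normal, up to a constant

θ = randn(3)
g_forward = ForwardDiff.gradient(logjoint, θ)                  # forward-mode AD
g_reverse = Tracker.data(Tracker.gradient(logjoint, θ)[1])     # reverse-mode AD
@assert g_forward ≈ -θ
@assert g_reverse ≈ -θ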
@@ -91,7 +91,7 @@ gradient_logp_forward(
spl::AbstractSampler=SampleFromPrior(),
)

Computes the value of the log joint of `θ` and its gradient for the model
Computes the value of the log joint of `θ` and its gradient for the model
specified by `(vi, spl, model)` using forwards-mode AD from ForwardDiff.jl.
"""
function gradient_logp_forward(
@@ -133,7 +133,7 @@ gradient_logp_reverse(
sampler::AbstractSampler=SampleFromPrior(),
)

Computes the value of the log joint of `θ` and its gradient for the model
Computes the value of the log joint of `θ` and its gradient for the model
specified by `(vi, sampler, model)` using reverse-mode AD from Flux.jl.
"""
function gradient_logp_reverse(
@@ -166,7 +166,7 @@ end

function verifygrad(grad::AbstractVector{<:Real})
if any(isnan, grad) || any(isinf, grad)
@warn("Numerical error has been found in gradients.")
@warn("Numerical error in gradients. Rejecting current proposal...")
@warn("grad = $(grad)")
return false
else
@@ -183,7 +183,7 @@ end

import StatsFuns: nbinomlogpdf
# Note the definition of NegativeBinomial in Julia is not the same as Wikipedia's.
# Check the docstring of NegativeBinomial, r is the number of successes and
# Check the docstring of NegativeBinomial, r is the number of successes and
# k is the number of failures
_nbinomlogpdf_grad_1(r, p, k) = k == 0 ? log(p) : sum(1 / (k + r - i) for i in 1:k) + log(p)
_nbinomlogpdf_grad_2(r, p, k) = -k / (1 - p) + r / p
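
As a sanity check on the two gradient expressions above, here is a hypothetical snippet (not part of the PR) comparing them against central finite differences of StatsFuns.nbinomlogpdf; the helper definitions simply repeat the ones in the diff.

using StatsFuns: nbinomlogpdf

_nbinomlogpdf_grad_1(r, p, k) = k == 0 ? log(p) : sum(1 / (k + r - i) for i in 1:k) + log(p)
_nbinomlogpdf_grad_2(r, p, k) = -k / (1 - p) + r / p

r, p, k = 5.0, 0.3, 7
h = 1e-6
fd_r = (nbinomlogpdf(r + h, p, k) - nbinomlogpdf(r - h, p, k)) / 2h  # ∂/∂r
fd_p = (nbinomlogpdf(r, p + h, k) - nbinomlogpdf(r, p - h, k)) / 2h  # ∂/∂p
@assert isapprox(fd_r, _nbinomlogpdf_grad_1(r, p, k); atol=1e-5)
@assert isapprox(fd_p, _nbinomlogpdf_grad_2(r, p, k); atol=1e-5)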
1 change: 0 additions & 1 deletion src/inference/adapt/adapt.jl
@@ -86,7 +86,6 @@ function adapt!(tp::ThreePhaseAdapter, stats::Real, θ; adapt_ϵ=false, adapt_M=
if tp.state.n == tp.n_adapts
if adapt_ϵ
ϵ = exp(tp.ssa.state.x_bar)
tp.ssa.state.ϵ = min(one(ϵ), ϵ)
end
        @info " Adapted ϵ = $(getss(tp)), std = $(string(tp.pc)); $(tp.state.n) iterations were used for adaptation."
else
7 changes: 4 additions & 3 deletions src/inference/adapt/stepsize.jl
@@ -92,10 +92,11 @@ function adapt_stepsize!(da::DualAveraging, stats::Real)

if isnan(ϵ) || isinf(ϵ)
@warn "Incorrect ϵ = $ϵ; ϵ_previous = $(da.state.ϵ) is used instead."
else
ϵ > 5*one(ϵ) && @warn "$ϵ exceeds 5.0; capped to 5.0 for numerical stability"
da.state.ϵ = min(5*one(ϵ), ϵ)
ϵ = da.state.ϵ
x_bar = da.state.x_bar
H_bar = da.state.H_bar
end
da.state.ϵ = ϵ
da.state.x_bar = x_bar
da.state.H_bar = H_bar
end
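
The change above replaces the old post-adaptation cap of ϵ at 1.0 (the line removed from adapt.jl) with a cap at 5.0 applied inside the dual-averaging update itself, together with guarding against NaN/Inf step sizes. Below is a hypothetical, standalone sketch of a Nesterov dual-averaging step-size update in the style of Hoffman & Gelman (2014) with that 5.0 cap; the struct, names, and defaults are illustrative, not Turing's actual implementation.

mutable struct DAState
    m::Int          # iteration counter
    ϵ::Float64      # current step size
    μ::Float64      # shrinkage target, typically log(10 * ϵ0)
    x_bar::Float64  # running average of log ϵ
    H_bar::Float64  # running average of (δ - accept rate)
end

function da_update!(s::DAState, stats::Real; δ=0.8, γ=0.05, t0=10, κ=0.75)
    s.m += 1
    η = 1 / (s.m + t0)
    s.H_bar = (1 - η) * s.H_bar + η * (δ - stats)
    x = s.μ - sqrt(s.m) / γ * s.H_bar                 # proposed log ϵ
    s.x_bar = s.m^(-κ) * x + (1 - s.m^(-κ)) * s.x_bar
    ϵ = exp(x)
    if isnan(ϵ) || isinf(ϵ)
        @warn "Incorrect ϵ = $ϵ; ϵ_previous = $(s.ϵ) is used instead."
    else
        ϵ > 5 && @warn "$ϵ exceeds 5.0; capped to 5.0 for numerical stability"
        s.ϵ = min(5.0, ϵ)                             # the cap this PR introduces
    end
    return s.ϵ
end

state = DAState(0, 0.1, log(10 * 0.1), 0.0, 0.0)
da_update!(state, 0.65)   # feed in an observed accept rate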
60 changes: 37 additions & 23 deletions src/inference/support/hmc_core.jl
@@ -1,5 +1,7 @@
# Ref: https://github.com/stan-dev/stan/blob/develop/src/stan/mcmc/hmc/hamiltonians/diag_e_metric.hpp

using Statistics: middle

"""
gen_grad_func(vi::VarInfo, sampler::Sampler, model)

@@ -232,53 +234,65 @@ function find_good_eps(model, spl::Sampler{T}, vi::VarInfo) where T
H_func = gen_H_func()
θ = vi[spl]
ϵ = _find_good_eps(θ, lj_func, grad_func, H_func, momentum_sampler)
@info "\r[$T] found initial ϵ: $ϵ"
@info "[Turing] found initial ϵ: $ϵ"
return ϵ
end

function _find_good_eps(θ, lj_func, grad_func, H_func, momentum_sampler; max_num_iters=12)
##
## Heuristically find optimal ϵ
##
function _find_good_eps(θ, lj_func, grad_func, H_func, momentum_sampler; max_num_iters=100)
@info "[Turing] looking for good initial eps..."
ϵ = 0.1
ϵ_prime = ϵ = 0.1
a_min, a_cross, a_max = 0.25, 0.5, 0.75 # minimal, crossing, maximal accept ratio
d = 2.0

p = momentum_sampler()
H0 = H_func(θ, p, lj_func(θ))

θ_prime, p_prime, τ = _leapfrog(θ, p, 1, ϵ, grad_func)
h = τ == 0 ? Inf : H_func(θ_prime, p_prime, lj_func(θ_prime))

delta_H = H0 - h
direction = delta_H > log(0.8) ? 1 : -1

iter_num = 1

# Heuristically find optimal ϵ
while (iter_num <= max_num_iters)
θ = θ_prime
    delta_H = H0 - h # logp(θ′) - logp(θ)
direction = delta_H > log(a_cross) ? 1 : -1
[Review comment from a project member: Nice!]
p = momentum_sampler()
H0 = H_func(θ, p, lj_func(θ))

θ_prime, p_prime, τ = _leapfrog(θ, p, 1, ϵ, grad_func)
    # Crossing step: increase/decrease ϵ until the accept ratio crosses a_cross.
for _ = 1:max_num_iters
ϵ_prime = direction == 1 ? d * ϵ : 1/d * ϵ
θ_prime, p_prime, τ = _leapfrog(θ, p, 1, ϵ_prime, grad_func)
h = τ == 0 ? Inf : H_func(θ_prime, p_prime, lj_func(θ_prime))
Turing.DEBUG && @debug "direction = $direction, h = $h"

delta_H = H0 - h

if ((direction == 1) && !(delta_H > log(0.8)))
Turing.DEBUG && @debug "[Turing] ϵ = $ϵ_prime, accept ratio a = $(min(1,(exp(delta_H))))"
if ((direction == 1) && !(delta_H > log(a_cross)))
break
elseif ((direction == -1) && !(delta_H < log(0.8)))
elseif ((direction == -1) && !(delta_H < log(a_cross)))
break
else
ϵ = direction == 1 ? 2.0 * ϵ : 0.5 * ϵ
ϵ = ϵ_prime
end

iter_num += 1
end

while h == Inf # revert if the last change is too big
ϵ = ϵ / 2 # safe is more important than large
θ_prime, p_prime, τ = _leapfrog(θ, p, 1, ϵ, grad_func)
# Bisection step: ensure final accept ratio: a_min < a < a_max.
# See https://en.wikipedia.org/wiki/Bisection_method
ϵ, ϵ_prime = ϵ < ϵ_prime ? (ϵ, ϵ_prime) : (ϵ_prime, ϵ) # Ensure ϵ < ϵ_prime
for _ = 1:max_num_iters
ϵ_mid = middle(ϵ, ϵ_prime)
θ_prime, p_prime, τ = _leapfrog(θ, p, 1, ϵ_mid, grad_func)
h = τ == 0 ? Inf : H_func(θ_prime, p_prime, lj_func(θ_prime))

delta_H = H0 - h

Turing.DEBUG && @debug "[Turing] ϵ = $ϵ_mid, accept ratio a = $(min(1,(exp(delta_H))))"
if (exp(delta_H) > a_max)
ϵ = ϵ_mid
elseif (exp(delta_H) < a_min)
ϵ_prime = ϵ_mid
else
ϵ = ϵ_mid; break
end
end

return ϵ
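
To make the two-phase search concrete, here is a hypothetical, self-contained sketch on a 1-D standard normal target (plain Julia, not Turing's actual functions): a crossing phase that doubles or halves ϵ until the accept ratio crosses a_cross, followed by bisection until the ratio falls inside (a_min, a_max).

using Statistics: middle

logp(θ) = -0.5 * θ^2                    # log density of N(0, 1), up to a constant
∇logp(θ) = -θ
H(θ, p) = -logp(θ) + 0.5 * p^2          # Hamiltonian with unit mass

function leapfrog(θ, p, ϵ)
    p += 0.5ϵ * ∇logp(θ)
    θ += ϵ * p
    p += 0.5ϵ * ∇logp(θ)
    return θ, p
end

function find_good_eps_demo(θ0; a_min=0.25, a_cross=0.5, a_max=0.75, d=2.0, max_iters=100)
    ϵ = ϵ_prime = 0.1
    p = randn()
    H0 = H(θ0, p)
    θ1, p1 = leapfrog(θ0, p, ϵ)
    direction = H0 - H(θ1, p1) > log(a_cross) ? 1 : -1

    # Crossing phase: grow or shrink ϵ until the accept ratio crosses a_cross.
    for _ in 1:max_iters
        ϵ_prime = direction == 1 ? d * ϵ : ϵ / d
        θ1, p1 = leapfrog(θ0, p, ϵ_prime)
        ΔH = H0 - H(θ1, p1)
        (direction == 1 && ΔH <= log(a_cross)) && break
        (direction == -1 && ΔH >= log(a_cross)) && break
        ϵ = ϵ_prime
    end

    # Bisection phase: shrink [ϵ, ϵ_prime] until a_min < accept ratio < a_max.
    ϵ, ϵ_prime = minmax(ϵ, ϵ_prime)
    for _ in 1:max_iters
        ϵ_mid = middle(ϵ, ϵ_prime)
        θ1, p1 = leapfrog(θ0, p, ϵ_mid)
        a = exp(H0 - H(θ1, p1))
        if a > a_max
            ϵ = ϵ_mid          # accept ratio too high: step size can grow
        elseif a < a_min
            ϵ_prime = ϵ_mid    # accept ratio too low: step size must shrink
        else
            return ϵ_mid
        end
    end
    return ϵ
end

find_good_eps_demo(1.0)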