Skip to content

Commit

Permalink
Bugfixes and re-formatting.
Browse files Browse the repository at this point in the history
  • Loading branch information
yebai committed Apr 30, 2019
1 parent f6c560b commit 46696f9
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 26 deletions.
16 changes: 9 additions & 7 deletions src/inference/hmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,13 +206,15 @@ end


function sample(
model::Model, alg::Hamiltonian;
model::Model,
alg::Hamiltonian;
save_state=false, # flag for state saving
resume_from=nothing, # chain to continue
reuse_spl_n=0, # flag for spl re-using
adaptor = AHMCAdaptor(),
# adaptor = AHMCAdaptor(alg),
init_theta::Union{Nothing,Array{<:Any,1}}=nothing,
rng::AbstractRNG=GLOBAL_RNG,
kwargs...
)
# Create sampler
spl = reuse_spl_n > 0 ? resume_from.info[:spl] : Sampler(alg)
Expand Down Expand Up @@ -308,7 +310,7 @@ function step(
spl::Sampler{<:AdaptiveHamiltonian},
vi::VarInfo,
is_first::Val{true};
adaptor=AHMCAdaptor(),
# adaptor=AHMCAdaptor(),
kwargs...
)
spl.selector.tag != :default && link!(vi, spl)
Expand All @@ -329,7 +331,7 @@ function step(

spl.info[:h] = h
spl.info[:traj] = gen_traj(spl.alg, init_ϵ)
spl.info[:adaptor] = adaptor
spl.info[:adaptor] = AHMCAdaptor(spl.alg, metric)

spl.selector.tag != :default && invlink!(vi, spl)
return vi, true
Expand Down Expand Up @@ -576,9 +578,9 @@ observe(spl::Sampler{<:Hamiltonian},
#### Default HMC stepsize and mass matrix adaptor
####

function AHMCAdaptor()
function AHMCAdaptor(alg::AdaptiveHamiltonian, metric)
adaptor = AHMC.StanNUTSAdaptor(
spl.alg.n_adapts, AHMC.PreConditioner(metric),
AHMC.NesterovDualAveraging(spl.alg.δ, init_ϵ)
alg.n_adapts, AHMC.PreConditioner(metric),
AHMC.NesterovDualAveraging(alg.δ, alg.init_ϵ)
)
end
44 changes: 25 additions & 19 deletions src/utilities/stan-interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,34 +15,37 @@

# Ref
# http://goedman.github.io/Stan.jl/latest/index.html#Types-1

function sample(mf::T, ss::CmdStan.Sample) where T
return sample(mf, ss.num_samples, ss.num_warmup, ss.save_warmup, ss.thin, ss.adapt, ss.alg)
return sample(mf, ss.num_samples, ss.num_warmup,
ss.save_warmup, ss.thin, ss.adapt, ss.alg)
end
function sample( mf::T,
num_samples::Int,
num_warmup::Int,
save_warmup::Bool,
thin::Int,
ss::CmdStan.Sample
) where T

function sample(mf::T,
num_samples::Int,
num_warmup::Int,
save_warmup::Bool,
thin::Int,
ss::CmdStan.Sample
) where T
return sample(mf, num_samples, num_warmup, save_warmup, thin, ss.adapt, ss.alg)
end

function sample( mf::T,
num_samples::Int,
num_warmup::Int,
save_warmup::Bool,
thin::Int,
adapt::CmdStan.Adapt,
alg::CmdStan.Hmc
) where T
function sample(mf::T,
num_samples::Int,
num_warmup::Int,
save_warmup::Bool,
thin::Int,
adapt::CmdStan.Adapt,
alg::CmdStan.Hmc
) where T
if alg.stepsize_jitter != 0.0
@warn("[Turing.sample] Turing does not support adding noise to stepsize yet.")
end
if adapt.engaged == false
if isa(alg.engine, CmdStan.Static) # hmc
stepnum = Int(round(alg.engine.int_time / alg.stepsize))
sample(mf, HMC(num_samples, alg.stepsize, stepnum); adapt_conf=adapt)
sample(mf, HMC(num_samples, alg.stepsize, stepnum); adaptor=NUTSAdaptor(adapt))
elseif isa(alg.engine, CmdStan.Nuts) # error
error("[Turing.sample] CmdStan.Nuts cannot be used with adapt.engaged set as false")
end
Expand All @@ -52,15 +55,18 @@ function sample( mf::T,
adaptor=NUTSAdaptor(adapt))
elseif isa(alg.engine, CmdStan.Nuts) # nuts
if isa(alg.metric, CmdStan.diag_e)
sample(mf, NUTS(num_samples, num_warmup, adapt.delta); adaptor=NUTSAdaptor(adapt))
sample(mf, NUTS(num_samples, num_warmup, adapt.delta);
adaptor=NUTSAdaptor(adapt))
else # TODO: reove the following since Turing support this feature now.
@warn("[Turing.sample] Turing does not support full covariance matrix for pre-conditioning yet.")
end
end
end
end

function Sampler(alg::Hamiltonian, adaptor::CmdStanAdaptorType) where CmdStanAdaptorType
function Sampler(alg::Hamiltonian,
adaptor::CmdStanAdaptorType
) where {CmdStanAdaptorType}
_sampler(alg::Hamiltonian, AHMCAdaptor(adaptor))
end

Expand Down

0 comments on commit 46696f9

Please sign in to comment.