diff --git a/src/inference/hmc.jl b/src/inference/hmc.jl index 6bffb390c..d1513ceed 100644 --- a/src/inference/hmc.jl +++ b/src/inference/hmc.jl @@ -206,13 +206,15 @@ end function sample( - model::Model, alg::Hamiltonian; + model::Model, + alg::Hamiltonian; save_state=false, # flag for state saving resume_from=nothing, # chain to continue reuse_spl_n=0, # flag for spl re-using - adaptor = AHMCAdaptor(), + # adaptor = AHMCAdaptor(alg), init_theta::Union{Nothing,Array{<:Any,1}}=nothing, rng::AbstractRNG=GLOBAL_RNG, + kwargs... ) # Create sampler spl = reuse_spl_n > 0 ? resume_from.info[:spl] : Sampler(alg) @@ -308,7 +310,7 @@ function step( spl::Sampler{<:AdaptiveHamiltonian}, vi::VarInfo, is_first::Val{true}; - adaptor=AHMCAdaptor(), + # adaptor=AHMCAdaptor(), kwargs... ) spl.selector.tag != :default && link!(vi, spl) @@ -329,7 +331,7 @@ function step( spl.info[:h] = h spl.info[:traj] = gen_traj(spl.alg, init_ϵ) - spl.info[:adaptor] = adaptor + spl.info[:adaptor] = AHMCAdaptor(spl.alg, metric) spl.selector.tag != :default && invlink!(vi, spl) return vi, true @@ -576,9 +578,9 @@ observe(spl::Sampler{<:Hamiltonian}, #### Default HMC stepsize and mass matrix adaptor #### -function AHMCAdaptor() +function AHMCAdaptor(alg::AdaptiveHamiltonian, metric) adaptor = AHMC.StanNUTSAdaptor( - spl.alg.n_adapts, AHMC.PreConditioner(metric), - AHMC.NesterovDualAveraging(spl.alg.δ, init_ϵ) + alg.n_adapts, AHMC.PreConditioner(metric), + AHMC.NesterovDualAveraging(alg.δ, alg.init_ϵ) ) end diff --git a/src/utilities/stan-interface.jl b/src/utilities/stan-interface.jl index 2a7ed0b55..a06b220b4 100644 --- a/src/utilities/stan-interface.jl +++ b/src/utilities/stan-interface.jl @@ -15,34 +15,37 @@ # Ref # http://goedman.github.io/Stan.jl/latest/index.html#Types-1 + function sample(mf::T, ss::CmdStan.Sample) where T - return sample(mf, ss.num_samples, ss.num_warmup, ss.save_warmup, ss.thin, ss.adapt, ss.alg) + return sample(mf, ss.num_samples, ss.num_warmup, + ss.save_warmup, ss.thin, ss.adapt, ss.alg) end -function sample( mf::T, - num_samples::Int, - num_warmup::Int, - save_warmup::Bool, - thin::Int, - ss::CmdStan.Sample - ) where T + +function sample(mf::T, + num_samples::Int, + num_warmup::Int, + save_warmup::Bool, + thin::Int, + ss::CmdStan.Sample +) where T return sample(mf, num_samples, num_warmup, save_warmup, thin, ss.adapt, ss.alg) end -function sample( mf::T, - num_samples::Int, - num_warmup::Int, - save_warmup::Bool, - thin::Int, - adapt::CmdStan.Adapt, - alg::CmdStan.Hmc - ) where T +function sample(mf::T, + num_samples::Int, + num_warmup::Int, + save_warmup::Bool, + thin::Int, + adapt::CmdStan.Adapt, + alg::CmdStan.Hmc +) where T if alg.stepsize_jitter != 0.0 @warn("[Turing.sample] Turing does not support adding noise to stepsize yet.") end if adapt.engaged == false if isa(alg.engine, CmdStan.Static) # hmc stepnum = Int(round(alg.engine.int_time / alg.stepsize)) - sample(mf, HMC(num_samples, alg.stepsize, stepnum); adapt_conf=adapt) + sample(mf, HMC(num_samples, alg.stepsize, stepnum); adaptor=NUTSAdaptor(adapt)) elseif isa(alg.engine, CmdStan.Nuts) # error error("[Turing.sample] CmdStan.Nuts cannot be used with adapt.engaged set as false") end @@ -52,7 +55,8 @@ function sample( mf::T, adaptor=NUTSAdaptor(adapt)) elseif isa(alg.engine, CmdStan.Nuts) # nuts if isa(alg.metric, CmdStan.diag_e) - sample(mf, NUTS(num_samples, num_warmup, adapt.delta); adaptor=NUTSAdaptor(adapt)) + sample(mf, NUTS(num_samples, num_warmup, adapt.delta); + adaptor=NUTSAdaptor(adapt)) else # TODO: reove the following since Turing support this feature now. @warn("[Turing.sample] Turing does not support full covariance matrix for pre-conditioning yet.") end @@ -60,7 +64,9 @@ function sample( mf::T, end end -function Sampler(alg::Hamiltonian, adaptor::CmdStanAdaptorType) where CmdStanAdaptorType +function Sampler(alg::Hamiltonian, + adaptor::CmdStanAdaptorType +) where {CmdStanAdaptorType} _sampler(alg::Hamiltonian, AHMCAdaptor(adaptor)) end