diff --git a/ext/BATAdvancedHMCExt.jl b/ext/BATAdvancedHMCExt.jl index 323d90c6f..65c0eb7f5 100644 --- a/ext/BATAdvancedHMCExt.jl +++ b/ext/BATAdvancedHMCExt.jl @@ -15,8 +15,9 @@ using BAT: MeasureLike, BATMeasure using BAT: get_context, get_adselector, _NoADSelected using BAT: getproposal, mcmc_target -using BAT: MCMCChainState, HMCState, HamiltonianMC, HMCProposalState, MCMCChainStateInfo, MCMCChainPoolInit, MCMCMultiCycleBurnin, MCMCProposalTunerState, MCMCTransformTunerState, NoMCMCTempering -using BAT: _current_sample_idx, _proposed_sample_idx, _cleanup_samples +using BAT: MCMCChainState, HMCState, HMCProposalState, MCMCChainStateInfo, MCMCChainPoolInit, MCMCMultiCycleBurnin, MCMCProposalTunerState, MCMCTransformTunerState +using BAT: NoMCMCTempering, NoMCMCTransformTuning, HMCTrajectoryTuning, RAMTuning +using BAT: _current_sample_idx, _proposed_sample_idx, _current_sample_z_idx, _proposed_sample_z_idx, _cleanup_samples using BAT: AbstractTransformTarget, NoAdaptiveTransform using BAT: RNGPartition, get_rng, set_rng! using BAT: mcmc_step!!, nsamples, nsteps, samples_available, eff_acceptance_ratio @@ -28,11 +29,13 @@ using BAT: CURRENT_SAMPLE, PROPOSED_SAMPLE, INVALID_SAMPLE, ACCEPTED_SAMPLE, REJ using BAT: HamiltonianMC using BAT: AHMCSampleID, AHMCSampleIDVector using BAT: HMCMetric, DiagEuclideanMetric, UnitEuclideanMetric, DenseEuclideanMetric -using BAT: HMCTuning, MassMatrixAdaptor, StepSizeAdaptor, NaiveHMCTuning, StanHMCTuning +using BAT: HMCTuning, MassMatrixAdaptor, HMCTrajectoryTuning, NaiveHMCTuning, StanHMCTuning + +using ChangesOfVariables: with_logabsdet_jacobian using ValueShapes: varshape -using Accessors: @set +using Accessors: @set, @reset BAT.ext_default(::BAT.PackageExtension{:AdvancedHMC}, ::Val{:DEFAULT_INTEGRATOR}) = AdvancedHMC.Leapfrog(NaN) diff --git a/ext/ahmc_impl/ahmc_config_impl.jl b/ext/ahmc_impl/ahmc_config_impl.jl index 642ffdf58..e290b118c 100644 --- a/ext/ahmc_impl/ahmc_config_impl.jl +++ b/ext/ahmc_impl/ahmc_config_impl.jl @@ -51,7 +51,7 @@ function ahmc_adaptor( end function ahmc_adaptor( - tuning::StepSizeAdaptor, + tuning::HMCTrajectoryTuning, metric::AdvancedHMC.AbstractMetric, integrator::AdvancedHMC.AbstractIntegrator, θ_init::AbstractVector{<:Real} diff --git a/ext/ahmc_impl/ahmc_sampler_impl.jl b/ext/ahmc_impl/ahmc_sampler_impl.jl index a452551ab..f190524c4 100644 --- a/ext/ahmc_impl/ahmc_sampler_impl.jl +++ b/ext/ahmc_impl/ahmc_sampler_impl.jl @@ -3,7 +3,9 @@ BAT.bat_default(::Type{TransformedMCMC}, ::Val{:pre_transform}, proposal::HamiltonianMC) = PriorToGaussian() -BAT.bat_default(::Type{TransformedMCMC}, ::Val{:proposal_tuning}, proposal::HamiltonianMC) = StanHMCTuning() +BAT.bat_default(::Type{TransformedMCMC}, ::Val{:proposal_tuning}, proposal::HamiltonianMC) = HMCTrajectoryTuning() + +bat_default(::Type{TransformedMCMC}, ::Val{:transform_tuning}, proposal::HamiltonianMC) = NoMCMCTransformTuning() BAT.bat_default(::Type{TransformedMCMC}, ::Val{:adaptive_transform}, proposal::HamiltonianMC) = NoAdaptiveTransform() @@ -40,16 +42,21 @@ function BAT._create_proposal_state( hamiltonian, init_transition = AdvancedHMC.sample_init(rng, init_hamiltonian, params_vec) integrator = _ahmc_set_step_size(proposal.integrator, hamiltonian, params_vec) termination = _ahmc_convert_termination(proposal.termination, params_vec) - kernel = HMCKernel(Trajectory{MultinomialTS}(integrator, termination)) + τ = Trajectory{MultinomialTS}(integrator, termination) + + # TODO: MD, remove, for debugging + init_rng = deepcopy(rng) + + z = AdvancedHMC.phasepoint(hamiltonian, init_transition.z.θ, rand(init_rng, hamiltonian.metric, hamiltonian.kinetic)) # Perform a dummy step to get type-stable transition value: - transition = AdvancedHMC.transition(deepcopy(rng), deepcopy(hamiltonian), deepcopy(kernel), init_transition.z) + transition = AdvancedHMC.transition(init_rng, deepcopy(τ), deepcopy(hamiltonian), z) HMCProposalState( integrator, termination, hamiltonian, - kernel, + τ, transition ) end @@ -59,76 +66,84 @@ function BAT._get_sample_id(proposal::HMCProposalState, id::Int32, cycle::Int32, return AHMCSampleID(id, cycle, stepno, sample_type, 0.0, 0, false, 0.0), AHMCSampleID end -function BAT.next_cycle!(mc_state::HMCState) - _cleanup_samples(mc_state) +function BAT.next_cycle!(chain_state::HMCState) + _cleanup_samples(chain_state) - mc_state.info = MCMCChainStateInfo(mc_state.info, cycle = mc_state.info.cycle + 1) - mc_state.nsamples = 0 - mc_state.stepno = 0 + chain_state.info = MCMCChainStateInfo(chain_state.info, cycle = chain_state.info.cycle + 1) + chain_state.nsamples = 0 + chain_state.stepno = 0 - reset_rng_counters!(mc_state) + reset_rng_counters!(chain_state) - resize!(mc_state.samples, 1) + resize!(chain_state.samples, 1) - i = _proposed_sample_idx(mc_state) - @assert mc_state.samples.info[i].sampletype == CURRENT_SAMPLE - mc_state.samples.weight[i] = 1 + i = _proposed_sample_idx(chain_state) + @assert chain_state.samples.info[i].sampletype == CURRENT_SAMPLE + chain_state.samples.weight[i] = 1 - t_stat = mc_state.proposal.transition.stat + t_stat = chain_state.proposal.transition.stat - mc_state.samples.info[i] = AHMCSampleID( - mc_state.info.id, mc_state.info.cycle, mc_state.stepno, CURRENT_SAMPLE, + chain_state.samples.info[i] = AHMCSampleID( + chain_state.info.id, chain_state.info.cycle, chain_state.stepno, CURRENT_SAMPLE, t_stat.hamiltonian_energy, t_stat.tree_depth, t_stat.numerical_error, t_stat.step_size ) - mc_state + chain_state end -# TODO: MD, should this be a !! function? -function BAT.mcmc_propose!!(mc_state::HMCState) - # @unpack target, proposal, f_transform, samples, context = mc_state - target = mc_state.target - proposal = mc_state.proposal - f_transform = mc_state.f_transform - samples = mc_state.samples - context = mc_state.context +# TODO: MD, Make Properly !! +function BAT.mcmc_propose!!(chain_state::HMCState) + target = chain_state.target + proposal = chain_state.proposal + f_transform = chain_state.f_transform + samples = chain_state.samples + context = chain_state.context rng = get_rng(context) - current = _current_sample_idx(mc_state) - proposed = _proposed_sample_idx(mc_state) + current = _current_sample_idx(chain_state) + proposed = _proposed_sample_idx(chain_state) x_current = samples.v[current] x_proposed = samples.v[proposed] current_log_posterior = samples.logd[current] + - proposal.transition = AdvancedHMC.transition(rng, proposal.hamiltonian, proposal.kernel, proposal.transition.z) - x_proposed[:] = proposal.transition.z.θ - - proposed_log_posterior = logdensityof(target, x_proposed) + τ = deepcopy(proposal.τ) + @reset τ.integrator = AdvancedHMC.jitter(rng, τ.integrator) - samples.logd[proposed] = proposed_log_posterior + hamiltonian = proposal.hamiltonian + z = AdvancedHMC.phasepoint(hamiltonian, Vector(x_current), rand(rng, hamiltonian.metric, hamiltonian.kinetic)) + trans = AdvancedHMC.transition(rng, τ, hamiltonian, z) + tstat = AdvancedHMC.stat(trans) + p_accept = tstat.acceptance_rate + x_proposed[:] = trans.z.θ + + logd_x_proposed = logdensityof(target, x_proposed) + + samples.logd[proposed] = logd_x_proposed + accepted = x_current != x_proposed - # TODO: Setting p_accept to 1 or 0 for now. - # Use AdvancedHMC.stat(transition).acceptance_rate in the future? - p_accept = Float64(accepted) + chain_state_new = @set chain_state.samples.v[proposed] = x_proposed + chain_state_new = @set chain_state.samples.logd[proposed] = logd_x_proposed + chain_state_new.proposal.transition = trans + # chain_state_new = @set chain_state.proposal.transition = trans # For some reason this doesn't save the new transition. Do @reset for each field of trans? - return mc_state, accepted, p_accept + return chain_state_new, accepted, p_accept end -function BAT._accept_reject!(mc_state::HMCState, accepted::Bool, p_accept::Float64, current::Integer, proposed::Integer) - # @unpack samples, proposal = mc_state - samples = mc_state.samples - proposal = mc_state.proposal +function BAT._accept_reject!(chain_state::HMCState, accepted::Bool, p_accept::Float64, current::Integer, proposed::Integer) + samples = chain_state.samples + proposal = chain_state.proposal if accepted samples.info.sampletype[current] = ACCEPTED_SAMPLE samples.info.sampletype[proposed] = CURRENT_SAMPLE - mc_state.nsamples += 1 + chain_state.nsamples += 1 tstat = AdvancedHMC.stat(proposal.transition) samples.info.hamiltonian_energy[proposed] = tstat.hamiltonian_energy @@ -140,11 +155,10 @@ function BAT._accept_reject!(mc_state::HMCState, accepted::Bool, p_accept::Float samples.info.sampletype[proposed] = REJECTED_SAMPLE end - delta_w_current, w_proposed = BAT.mcmc_weight_values(mc_state.weighting, p_accept, accepted) + delta_w_current, w_proposed = BAT.mcmc_weight_values(chain_state.weighting, p_accept, accepted) samples.weight[current] += delta_w_current samples.weight[proposed] = w_proposed end - BAT.eff_acceptance_ratio(mc_state::HMCState) = nsamples(mc_state) / nsteps(mc_state) diff --git a/ext/ahmc_impl/ahmc_tuner_impl.jl b/ext/ahmc_impl/ahmc_tuner_impl.jl index b7b58e504..8a7029ad4 100644 --- a/ext/ahmc_impl/ahmc_tuner_impl.jl +++ b/ext/ahmc_impl/ahmc_tuner_impl.jl @@ -9,7 +9,7 @@ end function HMCProposalTunerState(tuning::HMCTuning, chain_state::MCMCChainState) θ = first(chain_state.samples).v - adaptor = ahmc_adaptor(tuning, chain_state.proposal.hamiltonian.metric, chain_state.proposal.kernel.τ.integrator, θ) + adaptor = ahmc_adaptor(tuning, chain_state.proposal.hamiltonian.metric, chain_state.proposal.τ.integrator, θ) HMCProposalTunerState(tuning, tuning.target_acceptance, adaptor) end @@ -50,7 +50,7 @@ function BAT.mcmc_tuning_finalize!!(tuner::HMCProposalTunerState, chain_state::H proposal = chain_state.proposal AdvancedHMC.finalize!(adaptor) proposal.hamiltonian = AdvancedHMC.update(proposal.hamiltonian, adaptor) - proposal.kernel = AdvancedHMC.update(proposal.kernel, adaptor) + proposal.τ = AdvancedHMC.update(proposal.τ, adaptor) nothing end @@ -67,7 +67,7 @@ function BAT.mcmc_tune_post_step!!( AdvancedHMC.adapt!(adaptor, proposal_new.transition.z.θ, tstat.acceptance_rate) proposal_new.hamiltonian = AdvancedHMC.update(proposal_new.hamiltonian, adaptor) - proposal_new.kernel = AdvancedHMC.update(proposal_new.kernel, adaptor) + proposal_new.τ = AdvancedHMC.update(proposal_new.τ, adaptor) tstat = merge(tstat, (is_adapt =true,)) chain_state_tmp = @set chain_state.proposal.transition.stat = tstat diff --git a/src/extdefs/ahmc_defs/ahmc_alg.jl b/src/extdefs/ahmc_defs/ahmc_alg.jl index 3772987b4..68fd95c1a 100644 --- a/src/extdefs/ahmc_defs/ahmc_alg.jl +++ b/src/extdefs/ahmc_defs/ahmc_alg.jl @@ -46,14 +46,14 @@ export HamiltonianMC mutable struct HMCProposalState{ IT, TC, - HA,#<:AdvancedHMC.Hamiltonian, - KRNL,#<:AdvancedHMC.HMCKernel - TR# <:AdvancedHMC.Transition + HA, #<:AdvancedHMC.Hamiltonian, + TRJ,#<:AdvancedHMC.Trajectory + TR #<:AdvancedHMC.Transition } <: MCMCProposalState integrator::IT termination::TC hamiltonian::HA - kernel::KRNL + τ::TRJ transition::TR end diff --git a/src/extdefs/ahmc_defs/ahmc_config.jl b/src/extdefs/ahmc_defs/ahmc_config.jl index 653ca0b9d..4478fb906 100644 --- a/src/extdefs/ahmc_defs/ahmc_config.jl +++ b/src/extdefs/ahmc_defs/ahmc_config.jl @@ -20,7 +20,7 @@ abstract type HMCTuning <: MCMCProposalTuning end target_acceptance::Float64 = 0.8 end -@with_kw struct StepSizeAdaptor <: HMCTuning +@with_kw struct HMCTrajectoryTuning <: HMCTuning target_acceptance::Float64 = 0.8 end diff --git a/src/samplers/mcmc/mcmc_algorithm.jl b/src/samplers/mcmc/mcmc_algorithm.jl index 6e21d638a..6c6632cd2 100644 --- a/src/samplers/mcmc/mcmc_algorithm.jl +++ b/src/samplers/mcmc/mcmc_algorithm.jl @@ -1,7 +1,5 @@ # This file is a part of BAT.jl, licensed under the MIT License (MIT). - - """ abstract type MCMCAlgorithm diff --git a/src/samplers/mcmc/mcmc_tuning/mcmc_adaptive_mh_tuner.jl b/src/samplers/mcmc/mcmc_tuning/mcmc_adaptive_mh_tuner.jl index 09badd13e..26d261c78 100644 --- a/src/samplers/mcmc/mcmc_tuning/mcmc_adaptive_mh_tuner.jl +++ b/src/samplers/mcmc/mcmc_tuning/mcmc_adaptive_mh_tuner.jl @@ -96,7 +96,8 @@ function mcmc_tune_post_cycle!!(tuner::AdaptiveAffineTuningState, chain_state::M stats_reweight_factor = tuning.r reweight_relative!(stats, stats_reweight_factor) append!(stats, samples) - b = chain_state.f_transform.b + # TODO: MD, How to tune the shift? + b = chain_state.f_transform.b # Is left untouched at the moment α_min = minimum(tuning.α) α_max = maximum(tuning.α) diff --git a/src/samplers/mcmc/mcmc_tuning/mcmc_ram_tuner.jl b/src/samplers/mcmc/mcmc_tuning/mcmc_ram_tuner.jl index 28bbe57de..f12486752 100644 --- a/src/samplers/mcmc/mcmc_tuning/mcmc_ram_tuner.jl +++ b/src/samplers/mcmc/mcmc_tuning/mcmc_ram_tuner.jl @@ -71,7 +71,6 @@ end mcmc_tuning_finalize!!(tuner::RAMTrafoTunerState, chain::MCMCChainState) = nothing -# Return mc_state instead of f_transform function mcmc_tune_post_step!!( tuner_state::RAMTrafoTunerState, mc_state::MCMCChainState, diff --git a/src/samplers/mcmc/mh_sampler.jl b/src/samplers/mcmc/mh_sampler.jl index b6acb0ef4..517d6d056 100644 --- a/src/samplers/mcmc/mh_sampler.jl +++ b/src/samplers/mcmc/mh_sampler.jl @@ -73,7 +73,9 @@ function mcmc_propose!!(mc_state::MHChainState) z_current, logd_z_current = sample_z_current.v, sample_z_current.logd n_dims = size(z_current, 1) + z_proposed = z_current + rand(rng, proposal.proposaldist, n_dims) #TODO: check if proposal is symmetric? otherwise need additional factor? + x_proposed, ladj = with_logabsdet_jacobian(f_transform, z_proposed) logd_x_proposed = BAT.checked_logdensityof(target, x_proposed) logd_z_proposed = logd_x_proposed + ladj @@ -84,6 +86,7 @@ function mcmc_propose!!(mc_state::MHChainState) mc_state.sample_z[2] = DensitySample(z_proposed, logd_z_proposed, 0, _get_sample_id(proposal, mc_state.info.id, mc_state.info.cycle, mc_state.stepno, PROPOSED_SAMPLE)[1], nothing) # TODO: MD, should we check for symmetriy of proposal distribution? + # TODO: MD, Make more robust p_accept = clamp(exp(logd_z_proposed - logd_z_current), 0, 1)