Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Transformed hmc #455

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions ext/BATAdvancedHMCExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ using BAT: MeasureLike, BATMeasure

using BAT: get_context, get_adselector, _NoADSelected
using BAT: getproposal, mcmc_target
using BAT: MCMCChainState, HMCState, HamiltonianMC, HMCProposalState, MCMCChainStateInfo, MCMCChainPoolInit, MCMCMultiCycleBurnin, MCMCProposalTunerState, MCMCTransformTunerState, NoMCMCTempering
using BAT: _current_sample_idx, _proposed_sample_idx, _cleanup_samples
using BAT: MCMCChainState, HMCState, HMCProposalState, MCMCChainStateInfo, MCMCChainPoolInit, MCMCMultiCycleBurnin, MCMCProposalTunerState, MCMCTransformTunerState
using BAT: NoMCMCTempering, NoMCMCTransformTuning, HMCTrajectoryTuning, RAMTuning
using BAT: _current_sample_idx, _proposed_sample_idx, _current_sample_z_idx, _proposed_sample_z_idx, _cleanup_samples
using BAT: AbstractTransformTarget, NoAdaptiveTransform
using BAT: RNGPartition, get_rng, set_rng!
using BAT: mcmc_step!!, nsamples, nsteps, samples_available, eff_acceptance_ratio
Expand All @@ -28,11 +29,13 @@ using BAT: CURRENT_SAMPLE, PROPOSED_SAMPLE, INVALID_SAMPLE, ACCEPTED_SAMPLE, REJ
using BAT: HamiltonianMC
using BAT: AHMCSampleID, AHMCSampleIDVector
using BAT: HMCMetric, DiagEuclideanMetric, UnitEuclideanMetric, DenseEuclideanMetric
using BAT: HMCTuning, MassMatrixAdaptor, StepSizeAdaptor, NaiveHMCTuning, StanHMCTuning
using BAT: HMCTuning, MassMatrixAdaptor, HMCTrajectoryTuning, NaiveHMCTuning, StanHMCTuning

using ChangesOfVariables: with_logabsdet_jacobian

using ValueShapes: varshape

using Accessors: @set
using Accessors: @set, @reset


BAT.ext_default(::BAT.PackageExtension{:AdvancedHMC}, ::Val{:DEFAULT_INTEGRATOR}) = AdvancedHMC.Leapfrog(NaN)
Expand Down
2 changes: 1 addition & 1 deletion ext/ahmc_impl/ahmc_config_impl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ function ahmc_adaptor(
end

function ahmc_adaptor(
tuning::StepSizeAdaptor,
tuning::HMCTrajectoryTuning,
metric::AdvancedHMC.AbstractMetric,
integrator::AdvancedHMC.AbstractIntegrator,
θ_init::AbstractVector{<:Real}
Expand Down
102 changes: 58 additions & 44 deletions ext/ahmc_impl/ahmc_sampler_impl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

BAT.bat_default(::Type{TransformedMCMC}, ::Val{:pre_transform}, proposal::HamiltonianMC) = PriorToGaussian()

BAT.bat_default(::Type{TransformedMCMC}, ::Val{:proposal_tuning}, proposal::HamiltonianMC) = StanHMCTuning()
BAT.bat_default(::Type{TransformedMCMC}, ::Val{:proposal_tuning}, proposal::HamiltonianMC) = HMCTrajectoryTuning()

bat_default(::Type{TransformedMCMC}, ::Val{:transform_tuning}, proposal::HamiltonianMC) = NoMCMCTransformTuning()

BAT.bat_default(::Type{TransformedMCMC}, ::Val{:adaptive_transform}, proposal::HamiltonianMC) = NoAdaptiveTransform()

Expand Down Expand Up @@ -40,16 +42,21 @@ function BAT._create_proposal_state(
hamiltonian, init_transition = AdvancedHMC.sample_init(rng, init_hamiltonian, params_vec)
integrator = _ahmc_set_step_size(proposal.integrator, hamiltonian, params_vec)
termination = _ahmc_convert_termination(proposal.termination, params_vec)
kernel = HMCKernel(Trajectory{MultinomialTS}(integrator, termination))
τ = Trajectory{MultinomialTS}(integrator, termination)

# TODO: MD, remove, for debugging
init_rng = deepcopy(rng)

z = AdvancedHMC.phasepoint(hamiltonian, init_transition.z.θ, rand(init_rng, hamiltonian.metric, hamiltonian.kinetic))

# Perform a dummy step to get type-stable transition value:
transition = AdvancedHMC.transition(deepcopy(rng), deepcopy(hamiltonian), deepcopy(kernel), init_transition.z)
transition = AdvancedHMC.transition(init_rng, deepcopy(τ), deepcopy(hamiltonian), z)

HMCProposalState(
integrator,
termination,
hamiltonian,
kernel,
τ,
transition
)
end
Expand All @@ -59,76 +66,84 @@ function BAT._get_sample_id(proposal::HMCProposalState, id::Int32, cycle::Int32,
return AHMCSampleID(id, cycle, stepno, sample_type, 0.0, 0, false, 0.0), AHMCSampleID
end

function BAT.next_cycle!(mc_state::HMCState)
_cleanup_samples(mc_state)
function BAT.next_cycle!(chain_state::HMCState)
_cleanup_samples(chain_state)

mc_state.info = MCMCChainStateInfo(mc_state.info, cycle = mc_state.info.cycle + 1)
mc_state.nsamples = 0
mc_state.stepno = 0
chain_state.info = MCMCChainStateInfo(chain_state.info, cycle = chain_state.info.cycle + 1)
chain_state.nsamples = 0
chain_state.stepno = 0

reset_rng_counters!(mc_state)
reset_rng_counters!(chain_state)

resize!(mc_state.samples, 1)
resize!(chain_state.samples, 1)

i = _proposed_sample_idx(mc_state)
@assert mc_state.samples.info[i].sampletype == CURRENT_SAMPLE
mc_state.samples.weight[i] = 1
i = _proposed_sample_idx(chain_state)
@assert chain_state.samples.info[i].sampletype == CURRENT_SAMPLE
chain_state.samples.weight[i] = 1

t_stat = mc_state.proposal.transition.stat
t_stat = chain_state.proposal.transition.stat

mc_state.samples.info[i] = AHMCSampleID(
mc_state.info.id, mc_state.info.cycle, mc_state.stepno, CURRENT_SAMPLE,
chain_state.samples.info[i] = AHMCSampleID(
chain_state.info.id, chain_state.info.cycle, chain_state.stepno, CURRENT_SAMPLE,
t_stat.hamiltonian_energy, t_stat.tree_depth,
t_stat.numerical_error, t_stat.step_size
)

mc_state
chain_state
end

# TODO: MD, should this be a !! function?
function BAT.mcmc_propose!!(mc_state::HMCState)
# @unpack target, proposal, f_transform, samples, context = mc_state
target = mc_state.target
proposal = mc_state.proposal
f_transform = mc_state.f_transform
samples = mc_state.samples
context = mc_state.context
# TODO: MD, Make Properly !!
function BAT.mcmc_propose!!(chain_state::HMCState)
target = chain_state.target
proposal = chain_state.proposal
f_transform = chain_state.f_transform
samples = chain_state.samples
context = chain_state.context

rng = get_rng(context)

current = _current_sample_idx(mc_state)
proposed = _proposed_sample_idx(mc_state)
current = _current_sample_idx(chain_state)
proposed = _proposed_sample_idx(chain_state)

x_current = samples.v[current]
x_proposed = samples.v[proposed]
current_log_posterior = samples.logd[current]



proposal.transition = AdvancedHMC.transition(rng, proposal.hamiltonian, proposal.kernel, proposal.transition.z)
x_proposed[:] = proposal.transition.z.θ

proposed_log_posterior = logdensityof(target, x_proposed)
τ = deepcopy(proposal.τ)
@reset τ.integrator = AdvancedHMC.jitter(rng, τ.integrator)

samples.logd[proposed] = proposed_log_posterior
hamiltonian = proposal.hamiltonian
z = AdvancedHMC.phasepoint(hamiltonian, Vector(x_current), rand(rng, hamiltonian.metric, hamiltonian.kinetic))

trans = AdvancedHMC.transition(rng, τ, hamiltonian, z)
tstat = AdvancedHMC.stat(trans)
p_accept = tstat.acceptance_rate
x_proposed[:] = trans.z.θ

logd_x_proposed = logdensityof(target, x_proposed)

samples.logd[proposed] = logd_x_proposed

accepted = x_current != x_proposed

# TODO: Setting p_accept to 1 or 0 for now.
# Use AdvancedHMC.stat(transition).acceptance_rate in the future?
p_accept = Float64(accepted)
chain_state_new = @set chain_state.samples.v[proposed] = x_proposed
chain_state_new = @set chain_state.samples.logd[proposed] = logd_x_proposed
chain_state_new.proposal.transition = trans
# chain_state_new = @set chain_state.proposal.transition = trans # For some reason this doesn't save the new transition. Do @reset for each field of trans?

return mc_state, accepted, p_accept
return chain_state_new, accepted, p_accept
end

function BAT._accept_reject!(mc_state::HMCState, accepted::Bool, p_accept::Float64, current::Integer, proposed::Integer)
# @unpack samples, proposal = mc_state
samples = mc_state.samples
proposal = mc_state.proposal
function BAT._accept_reject!(chain_state::HMCState, accepted::Bool, p_accept::Float64, current::Integer, proposed::Integer)
samples = chain_state.samples
proposal = chain_state.proposal

if accepted
samples.info.sampletype[current] = ACCEPTED_SAMPLE
samples.info.sampletype[proposed] = CURRENT_SAMPLE
mc_state.nsamples += 1
chain_state.nsamples += 1

tstat = AdvancedHMC.stat(proposal.transition)
samples.info.hamiltonian_energy[proposed] = tstat.hamiltonian_energy
Expand All @@ -140,11 +155,10 @@ function BAT._accept_reject!(mc_state::HMCState, accepted::Bool, p_accept::Float
samples.info.sampletype[proposed] = REJECTED_SAMPLE
end

delta_w_current, w_proposed = BAT.mcmc_weight_values(mc_state.weighting, p_accept, accepted)
delta_w_current, w_proposed = BAT.mcmc_weight_values(chain_state.weighting, p_accept, accepted)

samples.weight[current] += delta_w_current
samples.weight[proposed] = w_proposed
end


BAT.eff_acceptance_ratio(mc_state::HMCState) = nsamples(mc_state) / nsteps(mc_state)
6 changes: 3 additions & 3 deletions ext/ahmc_impl/ahmc_tuner_impl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ end

function HMCProposalTunerState(tuning::HMCTuning, chain_state::MCMCChainState)
θ = first(chain_state.samples).v
adaptor = ahmc_adaptor(tuning, chain_state.proposal.hamiltonian.metric, chain_state.proposal.kernel.τ.integrator, θ)
adaptor = ahmc_adaptor(tuning, chain_state.proposal.hamiltonian.metric, chain_state.proposal.τ.integrator, θ)
HMCProposalTunerState(tuning, tuning.target_acceptance, adaptor)
end

Expand Down Expand Up @@ -50,7 +50,7 @@ function BAT.mcmc_tuning_finalize!!(tuner::HMCProposalTunerState, chain_state::H
proposal = chain_state.proposal
AdvancedHMC.finalize!(adaptor)
proposal.hamiltonian = AdvancedHMC.update(proposal.hamiltonian, adaptor)
proposal.kernel = AdvancedHMC.update(proposal.kernel, adaptor)
proposal.τ = AdvancedHMC.update(proposal.τ, adaptor)
nothing
end

Expand All @@ -67,7 +67,7 @@ function BAT.mcmc_tune_post_step!!(

AdvancedHMC.adapt!(adaptor, proposal_new.transition.z.θ, tstat.acceptance_rate)
proposal_new.hamiltonian = AdvancedHMC.update(proposal_new.hamiltonian, adaptor)
proposal_new.kernel = AdvancedHMC.update(proposal_new.kernel, adaptor)
proposal_new.τ = AdvancedHMC.update(proposal_new.τ, adaptor)
tstat = merge(tstat, (is_adapt =true,))

chain_state_tmp = @set chain_state.proposal.transition.stat = tstat
Expand Down
8 changes: 4 additions & 4 deletions src/extdefs/ahmc_defs/ahmc_alg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,14 @@ export HamiltonianMC
mutable struct HMCProposalState{
IT,
TC,
HA,#<:AdvancedHMC.Hamiltonian,
KRNL,#<:AdvancedHMC.HMCKernel
TR# <:AdvancedHMC.Transition
HA, #<:AdvancedHMC.Hamiltonian,
TRJ,#<:AdvancedHMC.Trajectory
TR #<:AdvancedHMC.Transition
} <: MCMCProposalState
integrator::IT
termination::TC
hamiltonian::HA
kernel::KRNL
τ::TRJ
transition::TR
end

Expand Down
2 changes: 1 addition & 1 deletion src/extdefs/ahmc_defs/ahmc_config.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ abstract type HMCTuning <: MCMCProposalTuning end
target_acceptance::Float64 = 0.8
end

@with_kw struct StepSizeAdaptor <: HMCTuning
@with_kw struct HMCTrajectoryTuning <: HMCTuning
target_acceptance::Float64 = 0.8
end

Expand Down
2 changes: 0 additions & 2 deletions src/samplers/mcmc/mcmc_algorithm.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
# This file is a part of BAT.jl, licensed under the MIT License (MIT).



"""
abstract type MCMCAlgorithm

Expand Down
3 changes: 2 additions & 1 deletion src/samplers/mcmc/mcmc_tuning/mcmc_adaptive_mh_tuner.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ function mcmc_tune_post_cycle!!(tuner::AdaptiveAffineTuningState, chain_state::M
stats_reweight_factor = tuning.r
reweight_relative!(stats, stats_reweight_factor)
append!(stats, samples)
b = chain_state.f_transform.b
# TODO: MD, How to tune the shift?
b = chain_state.f_transform.b # Is left untouched at the moment

α_min = minimum(tuning.α)
α_max = maximum(tuning.α)
Expand Down
1 change: 0 additions & 1 deletion src/samplers/mcmc/mcmc_tuning/mcmc_ram_tuner.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ end

mcmc_tuning_finalize!!(tuner::RAMTrafoTunerState, chain::MCMCChainState) = nothing

# Return mc_state instead of f_transform
function mcmc_tune_post_step!!(
tuner_state::RAMTrafoTunerState,
mc_state::MCMCChainState,
Expand Down
3 changes: 3 additions & 0 deletions src/samplers/mcmc/mh_sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ function mcmc_propose!!(mc_state::MHChainState)
z_current, logd_z_current = sample_z_current.v, sample_z_current.logd

n_dims = size(z_current, 1)

z_proposed = z_current + rand(rng, proposal.proposaldist, n_dims) #TODO: check if proposal is symmetric? otherwise need additional factor?

x_proposed, ladj = with_logabsdet_jacobian(f_transform, z_proposed)
logd_x_proposed = BAT.checked_logdensityof(target, x_proposed)
logd_z_proposed = logd_x_proposed + ladj
Expand All @@ -84,6 +86,7 @@ function mcmc_propose!!(mc_state::MHChainState)
mc_state.sample_z[2] = DensitySample(z_proposed, logd_z_proposed, 0, _get_sample_id(proposal, mc_state.info.id, mc_state.info.cycle, mc_state.stepno, PROPOSED_SAMPLE)[1], nothing)

# TODO: MD, should we check for symmetriy of proposal distribution?
# TODO: MD, Make more robust
p_accept = clamp(exp(logd_z_proposed - logd_z_current), 0, 1)


Expand Down