Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

variable naming / destructuring #2465

Merged
merged 2 commits into from
Jan 14, 2025
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 33 additions & 33 deletions src/mcmc/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -429,49 +429,49 @@
and a tuple of initial states for all component samplers.
"""
function gibbs_initialstep_recursive(
rng, model, varnames, samplers, vi, states=(); initial_params=nothing, kwargs...
rng, model, varname_tuples, samplers, vi, states=(); initial_params=nothing, kwargs...
penelopeysm marked this conversation as resolved.
Show resolved Hide resolved
)
# End recursion
if isempty(varnames) && isempty(samplers)
if isempty(varname_tuples) && isempty(samplers)
return vi, states
end

varnames_local = first(varnames)
sampler_local = first(samplers)
varnames, varname_tuples_tail... = varname_tuples
sampler, samplers_tail... = samplers

# Get the initial values for this component sampler.
initial_params_local = if initial_params === nothing
nothing
else
DynamicPPL.subset(vi, varnames_local)[:]
DynamicPPL.subset(vi, varnames)[:]

Check warning on line 446 in src/mcmc/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/gibbs.jl#L446

Added line #L446 was not covered by tests
end

# Construct the conditioned model.
model_local, context_local = make_conditional(model, varnames_local, vi)
conditioned_model, context = make_conditional(model, varnames, vi)

# Take initial step.
_, new_state_local = AbstractMCMC.step(
# Take initial step with the current sampler.
_, new_state = AbstractMCMC.step(
rng,
model_local,
sampler_local;
conditioned_model,
sampler;
# FIXME: This will cause issues if the sampler expects initial params in unconstrained space.
# This is not the case for any samplers in Turing.jl, but will be for external samplers, etc.
initial_params=initial_params_local,
kwargs...,
)
new_vi_local = varinfo(new_state_local)
new_vi_local = varinfo(new_state)
# Merge in any new variables that were introduced during the step, but that
# were not in the domain of the current sampler.
vi = merge(vi, get_global_varinfo(context_local))
vi = merge(vi, get_global_varinfo(context))
# Merge the new values for all the variables sampled by the current sampler.
vi = merge(vi, new_vi_local)

states = (states..., new_state_local)
states = (states..., new_state)
return gibbs_initialstep_recursive(
rng,
model,
varnames[2:end],
samplers[2:end],
varname_tuples_tail,
samplers_tail,
vi,
states;
initial_params=initial_params,
Expand Down Expand Up @@ -624,26 +624,26 @@
function gibbs_step_recursive(
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
varnames,
varname_tuples,
samplers,
states,
global_vi,
new_states=();
kwargs...,
)
# End recursion.
if isempty(varnames) && isempty(samplers) && isempty(states)
if isempty(varname_tuples) && isempty(samplers) && isempty(states)
return global_vi, new_states
end

varnames_local = first(varnames)
sampler_local = first(samplers)
state_local = first(states)
varnames, varname_tuples_tail... = varname_tuples
sampler, samplers_tail... = samplers
state, states_tail... = states

# Construct the conditional model and the varinfo that this sampler should use.
model_local, context_local = make_conditional(model, varnames_local, global_vi)
varinfo_local = subset(global_vi, varnames_local)
varinfo_local = match_linking!!(varinfo_local, state_local, model)
conditioned_model, context = make_conditional(model, varnames, global_vi)
vi = subset(global_vi, varnames)
vi = match_linking!!(vi, state, model)

# TODO(mhauru) The below may be overkill. If the varnames for this sampler are not
# sampled by other samplers, we don't need to `setparams`, but could rather simply
Expand All @@ -654,27 +654,27 @@
# going to be a significant expense anyway.
# Set the state of the current sampler, accounting for any changes made by other
# samplers.
state_local = setparams_varinfo!!(
model_local, sampler_local, state_local, varinfo_local
state = setparams_varinfo!!(
conditioned_model, sampler, state, vi
)

# Take a step with the local sampler.
new_state_local = last(
AbstractMCMC.step(rng, model_local, sampler_local, state_local; kwargs...)
new_state = last(
AbstractMCMC.step(rng, conditioned_model, sampler, state; kwargs...)
)

new_vi_local = varinfo(new_state_local)
new_vi_local = varinfo(new_state)
# Merge the latest values for all the variables in the current sampler.
new_global_vi = merge(get_global_varinfo(context_local), new_vi_local)
new_global_vi = merge(get_global_varinfo(context), new_vi_local)
new_global_vi = setlogp!!(new_global_vi, getlogp(new_vi_local))

new_states = (new_states..., new_state_local)
new_states = (new_states..., new_state)
return gibbs_step_recursive(
rng,
model,
varnames[2:end],
samplers[2:end],
states[2:end],
varname_tuples_tail,
samplers_tail,
states_tail,
new_global_vi,
new_states;
kwargs...,
Expand Down
Loading