diff --git a/Project.toml b/Project.toml index a8570d226..a9bffe6cf 100644 --- a/Project.toml +++ b/Project.toml @@ -37,11 +37,14 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea" [weakdeps] +AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" +Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" BridgeStan = "c88b6f0a-829e-4b0b-94b7-f06ab5908f5a" DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5" +JuliaBUGS = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" @@ -51,10 +54,13 @@ PigeonsDynamicPPLExt = "DynamicPPL" PigeonsEnzymeExt = "Enzyme" PigeonsForwardDiffExt = "ForwardDiff" PigeonsHypothesisTestsExt = "HypothesisTests" +PigeonsJuliaBUGSExt = ["JuliaBUGS", "AbstractPPL", "Bijectors"] PigeonsMCMCChainsExt = "MCMCChains" PigeonsReverseDiffExt = "ReverseDiff" [compat] +AbstractPPL = "0.8.4, 0.9" +Bijectors = "0.13, 0.14" BridgeStan = "2" DataFrames = "1" Distributions = "0.25" @@ -68,6 +74,7 @@ Graphs = "1" HypothesisTests = "0.11" Interpolations = "0.14, 0.15" JSON = "0.21" +JuliaBUGS = "0.8" LogDensityProblems = "2" LogDensityProblemsAD = "1" LogExpFunctions = "0.3" @@ -90,10 +97,13 @@ ZipFile = "0.10" julia = "1.8" [extras] +AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" +Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" BridgeStan = "c88b6f0a-829e-4b0b-94b7-f06ab5908f5a" DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5" +JuliaBUGS = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" diff --git a/examples/JuliaBUGS.jl b/examples/JuliaBUGS.jl new file mode 100644 index 000000000..65d29a7f1 --- /dev/null +++ b/examples/JuliaBUGS.jl @@ -0,0 +1,24 @@ +function incomplete_count_data_model(;tau::Real=4) + model_def = @bugs("model{ + for (i in 1:n) { + r[i] ~ dbern(pr[i]) + pr[i] <- ilogit(y[i] * alpha1 + alpha0) + y[i] ~ dpois(mu) + } + mu ~ dgamma(1,1) + alpha0 ~ dnorm(0, 0.1) + alpha1 ~ dnorm(0, tau) + }",false,false + ) + data = ( + y = [ + 6,missing,missing,missing,missing,missing,missing,5,1,missing,1,missing, + missing,missing,2,missing,missing,0,missing,1,2,1,7,4,6,missing,missing, + missing,5,missing + ], + r = [1,0,0,0,0,0,0,1,1,0,1,0,0,0,1,0,0,1,0,1,1,1,1,1,1,0,0,0,1,0], + n = 30, tau = tau + ) + return compile(model_def, data) +end +incomplete_count_data(;kwargs...) = JuliaBUGSPath(incomplete_count_data_model(;kwargs...)) diff --git a/ext/PigeonsDynamicPPLExt/invariance_test.jl b/ext/PigeonsDynamicPPLExt/invariance_test.jl index 6bb7d9a48..0168bf5a6 100644 --- a/ext/PigeonsDynamicPPLExt/invariance_test.jl +++ b/ext/PigeonsDynamicPPLExt/invariance_test.jl @@ -10,9 +10,8 @@ initial and final states. """ function Pigeons.forward_sample_condition_and_explore( model::DynamicPPL.Model, - explorer, rng::SplittableRandom; - run_explorer::Bool = true, + explorer = nothing, condition_on::NTuple{N,Symbol} ) where {N} # forward simulation @@ -37,7 +36,7 @@ function Pigeons.forward_sample_condition_and_explore( DynamicPPL.link!!(state, DynamicPPL.SampleFromPrior(), conditioned_model) # maybe take a step with explorer - if run_explorer + if !isnothing(explorer) state = Pigeons.explorer_step(rng, TuringLogPotential(conditioned_model), explorer, state) end diff --git a/ext/PigeonsHypothesisTestsExt/PigeonsHypothesisTestsExt.jl b/ext/PigeonsHypothesisTestsExt/PigeonsHypothesisTestsExt.jl index 18bf303e2..6444a04bf 100644 --- a/ext/PigeonsHypothesisTestsExt/PigeonsHypothesisTestsExt.jl +++ b/ext/PigeonsHypothesisTestsExt/PigeonsHypothesisTestsExt.jl @@ -34,9 +34,9 @@ function Pigeons.invariance_test( # iterate iid samples for n in eachindex(initial_values) initial_values[n] = Pigeons.forward_sample_condition_and_explore( - target, explorer, rng; run_explorer=false, simulator_kwargs...) + target, rng; simulator_kwargs...) final_values[n] = Pigeons.forward_sample_condition_and_explore( - target, explorer, rng; simulator_kwargs...) + target, rng; explorer, simulator_kwargs...) end # transform vector of vectors to matrices so that iterating dimensions == iterating columns => faster diff --git a/ext/PigeonsJuliaBUGSExt/PigeonsJuliaBUGSExt.jl b/ext/PigeonsJuliaBUGSExt/PigeonsJuliaBUGSExt.jl new file mode 100644 index 000000000..d797e36e5 --- /dev/null +++ b/ext/PigeonsJuliaBUGSExt/PigeonsJuliaBUGSExt.jl @@ -0,0 +1,24 @@ +module PigeonsJuliaBUGSExt + +using Pigeons +if isdefined(Base, :get_extension) + import JuliaBUGS + using AbstractPPL # only need because we rewrite JuliaBUGS.getparams + using Bijectors # only need because we rewrite JuliaBUGS.getparams + using DocStringExtensions + using SplittableRandoms: SplittableRandom, split + using Random +else + import ..JuliaBUGS + using ..AbstractPPL # only need because we rewrite JuliaBUGS.getparams + using ..Bijectors # only need because we rewrite JuliaBUGS.getparams + using ..DocStringExtensions + using ..SplittableRandoms: SplittableRandom, split + using ..Random +end + +include(joinpath(@__DIR__, "utils.jl")) +include(joinpath(@__DIR__, "interface.jl")) +include(joinpath(@__DIR__, "invariance_test.jl")) + +end diff --git a/ext/PigeonsJuliaBUGSExt/interface.jl b/ext/PigeonsJuliaBUGSExt/interface.jl new file mode 100644 index 000000000..2b77927d3 --- /dev/null +++ b/ext/PigeonsJuliaBUGSExt/interface.jl @@ -0,0 +1,97 @@ +####################################### +# Path interface +####################################### + +# Initialization and iid sampling +function evaluate_and_initialize(model::JuliaBUGS.BUGSModel, rng::AbstractRNG) + new_env = first(JuliaBUGS.evaluate!!(rng, model)) # sample a new evaluation environment + return JuliaBUGS.initialize!(model, new_env) # set the private_model's environment to the newly created one +end + +# used for both initializing and iid sampling +# Note: state is a flattened vector of the parameters +# Also, the vector is **concretely typed**. This means that if the evaluation +# environment contains floats and integers, the latter will be cast to float. +_sample_iid(model::JuliaBUGS.BUGSModel, rng::AbstractRNG) = + getparams(evaluate_and_initialize(model, rng)) # flatten the unobserved parameters in the model's eval environment and return + +# Note: JuliaBUGS.getparams creates a new vector on each call, so it is safe +# to call _sample_iid during initialization (**sequentially**, as done as of time +# of writing) for different Replicas (i.e., they won't share the same state). +Pigeons.initialization(target::JuliaBUGSPath, rng::AbstractRNG, _::Int64) = + _sample_iid(target.model, rng) + +# target is already a Path +Pigeons.create_path(target::JuliaBUGSPath, ::Inputs) = target + +####################################### +# Log-potential interface +####################################### + +""" +$SIGNATURES + +A log-potential built from a [`JuliaBUGSPath`](@ref) for a specific inverse +temperature parameter. + +$FIELDS +""" +struct JuliaBUGSLogPotential{TMod<:JuliaBUGS.BUGSModel, TF<:AbstractFloat} + """ + A deep-enough copy of the original model that allows evaluation while + avoiding race conditions between different Replicas. + """ + private_model::TMod + + """ + Tempering parameter. + """ + beta::TF +end + +# make a log-potential by creating a new model with independent graph and +# evaluation environment. Both of these could be modified during density +# evaluations and/or during Gibbs sampling +function Pigeons.interpolate(path::JuliaBUGSPath, beta) + model = path.model + private_model = make_private_model_copy(model) + JuliaBUGSLogPotential(private_model, beta) +end + +# log_potential evaluation +(log_potential::JuliaBUGSLogPotential)(flattened_values) = + try + log_prior, _, tempered_log_joint = last( + JuliaBUGS._tempered_evaluate!!( + log_potential.private_model, + flattened_values; + temperature=log_potential.beta + ) + ) + # avoid potential 0*Inf (= NaN) + return iszero(log_potential.beta) ? log_prior : tempered_log_joint + catch e + (isa(e, DomainError) || isa(e, BoundsError)) && return -Inf + rethrow(e) + end + +# iid sampling +function Pigeons.sample_iid!(log_potential::JuliaBUGSLogPotential, replica, shared) + replica.state = _sample_iid(log_potential.private_model, replica.rng) +end + +# parameter names +Pigeons.sample_names(::Vector, log_potential::JuliaBUGSLogPotential) = + [(Symbol(string(vn)) for vn in log_potential.private_model.parameters)...,:log_density] + +# Parallelism invariance +Pigeons.recursive_equal(a::Union{JuliaBUGSPath,JuliaBUGSLogPotential}, b) = + Pigeons._recursive_equal(a,b) +function Pigeons.recursive_equal(a::T, b) where T <: JuliaBUGS.BUGSModel + included = (:transformed, :model_def, :data) + excluded = Tuple(setdiff(fieldnames(T), included)) + Pigeons._recursive_equal(a,b,excluded) +end +# just check the betas match, the model is already checked within path +Pigeons.recursive_equal(a::AbstractVector{<:JuliaBUGSLogPotential}, b) = + all(lp1.beta == lp2.beta for (lp1,lp2) in zip(a,b)) diff --git a/ext/PigeonsJuliaBUGSExt/invariance_test.jl b/ext/PigeonsJuliaBUGSExt/invariance_test.jl new file mode 100644 index 000000000..fa44a0731 --- /dev/null +++ b/ext/PigeonsJuliaBUGSExt/invariance_test.jl @@ -0,0 +1,35 @@ +""" +$SIGNATURES + +Implements `Pigeons.forward_sample_condition_and_explore` for running invariance +tests using a [`JuliaBUGSPath`](@ref) as target. +""" +function Pigeons.forward_sample_condition_and_explore( + model::JuliaBUGS.BUGSModel, + rng::SplittableRandom; + explorer = nothing, + condition_on = () + ) + # forward simulation (new values stored in model.evaluation_env) + model = evaluate_and_initialize(model, rng) + + # maybe condition the model using the sampled observations + conditioned_model = if length(condition_on) > 0 + var_group = [JuliaBUGS.VarName{sym}() for sym in condition_on] # transform Symbols into VarNames + JuliaBUGS.condition(model, var_group) + else + model + end + + # maybe take a step with explorer + state = getparams(conditioned_model) + return if !isnothing(explorer) + Pigeons.explorer_step(rng, JuliaBUGSPath(conditioned_model), explorer, state) + else + state + end +end + +Pigeons.forward_sample_condition_and_explore(target::JuliaBUGSPath, args...; kwargs...) = + Pigeons.forward_sample_condition_and_explore(target.model, args...; kwargs...) + \ No newline at end of file diff --git a/ext/PigeonsJuliaBUGSExt/utils.jl b/ext/PigeonsJuliaBUGSExt/utils.jl new file mode 100644 index 000000000..0a1ebedc1 --- /dev/null +++ b/ext/PigeonsJuliaBUGSExt/utils.jl @@ -0,0 +1,63 @@ +#= +Tweak of JuliaBUGS.getparams to allow for flattened vectors of mixed type +=# +type_join_eval_env(env) = typejoin(Set(eltype(v) for v in env)...) +function getparams(model::JuliaBUGS.BUGSModel) + param_length = if model.transformed + model.transformed_param_length + else + model.untransformed_param_length + end + + # search for an umbrella type for all parameters in the model to avoid + # promotion of e.g. ints to floats. For models with a unique parameter + # type T, it holds that TMix=T. + TMix = type_join_eval_env(model.evaluation_env) + param_vals = Vector{TMix}(undef, param_length) + pos = 1 + for v in model.parameters + if !model.transformed + val = AbstractPPL.get(model.evaluation_env, v) + len = model.untransformed_var_lengths[v] + if val isa AbstractArray + param_vals[pos:(pos + len - 1)] .= vec(val) + else + param_vals[pos] = val + end + else + (; node_function, loop_vars) = model.g[v] + dist = node_function(model.evaluation_env, loop_vars) + transformed_value = Bijectors.transform( + Bijectors.bijector(dist), AbstractPPL.get(model.evaluation_env, v) + ) + len = model.transformed_var_lengths[v] + if transformed_value isa AbstractArray + param_vals[pos:(pos + len - 1)] .= vec(transformed_value) + else + param_vals[pos] = transformed_value + end + end + pos += len + end + return param_vals +end + +function make_private_model_copy(model::JuliaBUGS.BUGSModel) + g = deepcopy(model.g) + parameters = model.parameters + sorted_nodes = model.flattened_graph_node_data.sorted_nodes + return JuliaBUGS.BUGSModel( + model.transformed, + sum(model.untransformed_var_lengths[v] for v in parameters), + sum(model.transformed_var_lengths[v] for v in parameters), + model.untransformed_var_lengths, + model.transformed_var_lengths, + deepcopy(model.evaluation_env), + parameters, + JuliaBUGS.FlattenedGraphNodeData(g, sorted_nodes), + g, + nothing, + model.model_def, + model.data + ) +end diff --git a/src/Pigeons.jl b/src/Pigeons.jl index fb748d912..8cf17b753 100644 --- a/src/Pigeons.jl +++ b/src/Pigeons.jl @@ -52,10 +52,8 @@ include("includes.jl") export pigeons, Inputs, PT, # for running jobs: ChildProcess, MPIProcesses, - # references: - DistributionLogPotential, # targets: - TuringLogPotential, StanLogPotential, + TuringLogPotential, StanLogPotential, DistributionLogPotential, JuliaBUGSPath, # some examples toy_mvn_target, toy_stan_target, # post-processing helpers @@ -89,6 +87,7 @@ end @require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" include(joinpath(@__DIR__, "../ext/PigeonsEnzymeExt/PigeonsEnzymeExt.jl")) @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include(joinpath(@__DIR__, "../ext/PigeonsForwardDiffExt/PigeonsForwardDiffExt.jl")) @require HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5" include(joinpath(@__DIR__, "../ext/PigeonsHypothesisTestsExt/PigeonsHypothesisTestsExt.jl")) + @require JuliaBUGS = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf" include(joinpath(@__DIR__, "../ext/PigeonsJuliaBUGSExt/PigeonsJuliaBUGSExt.jl")) @require MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" include(joinpath(@__DIR__, "../ext/PigeonsMCMCChainsExt/PigeonsMCMCChainsExt.jl")) @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include(joinpath(@__DIR__, "../ext/PigeonsReverseDiffExt/PigeonsReverseDiffExt.jl")) end diff --git a/src/explorers/invariance_test.jl b/src/explorers/invariance_test.jl index 1465bc010..5b75a2061 100644 --- a/src/explorers/invariance_test.jl +++ b/src/explorers/invariance_test.jl @@ -65,12 +65,11 @@ allows direct iid sampling from the target, conditioning is not necessary. """ function forward_sample_condition_and_explore( target::ScaledPrecisionNormalPath, - explorer, rng::SplittableRandom; - run_explorer::Bool = true + explorer = nothing ) state = initialization(target, rng, 1) # forward simulation - if run_explorer + if !isnothing(explorer) state = explorer_step(rng, target, explorer, state) end return state diff --git a/src/includes.jl b/src/includes.jl index 225a98af9..5600eb669 100644 --- a/src/includes.jl +++ b/src/includes.jl @@ -87,3 +87,4 @@ include("explorers/AAPS.jl") include("explorers/GradientBasedSampler.jl") include("evidence/stepping_stone.jl") include("api.jl") +include("paths/JuliaBUGSPath.jl") diff --git a/src/log_potentials/log_potentials.jl b/src/log_potentials/log_potentials.jl index 28226df05..e5a049010 100644 --- a/src/log_potentials/log_potentials.jl +++ b/src/log_potentials/log_potentials.jl @@ -40,5 +40,12 @@ Assumes the input `log_potentials` is a vector where each element is a [`log_pot This default implementation is sufficient in most cases, but in less standard scenarios, e.g. where the state space is infinite dimensional, this can be overridden. """ -log_unnormalized_ratio(log_potentials::AbstractVector, numerator::Int, denominator::Int, state) = - log_potentials[numerator](state) - log_potentials[denominator](state) +function log_unnormalized_ratio(log_potentials::AbstractVector, numerator::Int, denominator::Int, state) + lp_num = log_potentials[numerator](state) + lp_den = log_potentials[denominator](state) + ans = lp_num-lp_den + if isnan(ans) + error("Got NaN log-unnormalized ratio; Dumping information:\n\tlp_num=$lp_num\n\tlp_den=$lp_den\n\tState=$state") + end + return ans +end diff --git a/src/paths/JuliaBUGSPath.jl b/src/paths/JuliaBUGSPath.jl new file mode 100644 index 000000000..1c67ae68d --- /dev/null +++ b/src/paths/JuliaBUGSPath.jl @@ -0,0 +1,14 @@ +""" +$SIGNATURES + +A thin wrapper around a `JuliaBUGS.BUGSModel` to provide a prior-posterior path. +To work with Pigeons, `JuliaBUGS` needs to be imported into the current session. + +$FIELDS +""" +@auto struct JuliaBUGSPath + """ + A `JuliaBUGS.BUGSModel`. + """ + model +end diff --git a/test/Project.toml b/test/Project.toml index c38b35c9c..a09c32e67 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -15,6 +15,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" +JuliaBUGS = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearRegression = "92481ed7-9fb7-40fd-80f2-46fd0f076581" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" diff --git a/test/test_JuliaBUGS.jl b/test/test_JuliaBUGS.jl new file mode 100644 index 000000000..b0fad3a6e --- /dev/null +++ b/test/test_JuliaBUGS.jl @@ -0,0 +1,91 @@ +using JuliaBUGS + +include("supporting/analytic_solutions.jl") +include("supporting/mpi_test_utils.jl") +include("../examples/JuliaBUGS.jl") + +# good ol' toy unidentifiable model for testing purposes +unid_model_def = @bugs begin + for i in 1:2 + p[i] ~ dunif(0,1) + end + p_prod = p[1]*p[2] + n_heads ~ dbin(p_prod, n_flips) +end +unid_target_model = compile(unid_model_def, (; n_heads=50000, n_flips=100000)) +unid_target = JuliaBUGSPath(unid_target_model) +unid_target_constrained = JuliaBUGSPath(JuliaBUGS.settrans(unid_target_model)) +struct IdentityExplorer end +function Pigeons.step!(::IdentityExplorer, replica, shared) end + +@testset "Basic sampling via independent MH from the prior" begin + pt = pigeons( + target = unid_target_constrained, + n_chains = 2, + explorer = IdentityExplorer(), + record = [traces], + extended_traces = true + ) + # check sample_iid! + @test isapprox(0.5, mean(v[1] for (k,v) in pt.reduced_recorders.traces if first(k)==1), rtol=0.05) + @test isapprox(0.5, mean(v[2] for (k,v) in pt.reduced_recorders.traces if first(k)==1), rtol=0.05) + + # check log_potential evaluation with constrained version (easier, no Jacobian) + @test all(v for (k,v) in pt.reduced_recorders.traces if first(k)==2) do v + last(v) == logpdf( + Binomial(unid_target_model.evaluation_env.n_flips, v[1]*v[2]), + unid_target_model.evaluation_env.n_heads + ) + end +end + +@testset "SliceSampler on constrained and unconstrained versions" begin + exact_logZ = unid_target_exact_logZ( + unid_target_model.evaluation_env.n_flips, + unid_target_model.evaluation_env.n_heads + ) + for target in (unid_target, unid_target_constrained) + @show target.model + pt = pigeons(; + target, + explorer = SliceSampler(), + n_chains=7, + n_rounds=5 + ) + @test isapprox(Pigeons.stepping_stone(pt), exact_logZ, rtol=0.1) + end +end + +@testset "Invariance test" begin + uncond_target = JuliaBUGSPath(compile(unid_model_def, (;n_flips=100000))) + res = Pigeons.invariance_test(uncond_target, SliceSampler(); condition_on=(:n_heads,)) + @show res.pvalues + @test res.passed +end + +@testset "Parallelism invariance using MPI" begin + target=incomplete_count_data() + r = pigeons(; + target=unid_target, + n_rounds = 5, + n_chains = 4, + checkpoint = true, + checked_round = 4, + multithreaded = true, + on = ChildProcess( + n_local_mpi_processes = set_n_mpis_to_one_on_windows(2), + n_threads = 2, + mpiexec_args = extra_mpi_args(), + dependencies = [JuliaBUGS] + ) + ) + pt = Pigeons.load(r) + @test true +end + +@testset "Check no NaN log potentials" begin # https://github.com/Julia-Tempering/Pigeons.jl/pull/303#issuecomment-2547306248 + target=incomplete_count_data(tau=0.01) + pt = pigeons(target = target, n_rounds = 4, n_chains = 4, record=[traces]) + chns = Chains(pt) + @test first(names(chns)) != Symbol("param_1") # check we're not using the default array-state name builder +end