diff --git a/ext/PigeonsJuliaBUGSExt/interface.jl b/ext/PigeonsJuliaBUGSExt/interface.jl index d44efe5b5..2b77927d3 100644 --- a/ext/PigeonsJuliaBUGSExt/interface.jl +++ b/ext/PigeonsJuliaBUGSExt/interface.jl @@ -80,5 +80,18 @@ function Pigeons.sample_iid!(log_potential::JuliaBUGSLogPotential, replica, shar replica.state = _sample_iid(log_potential.private_model, replica.rng) end +# parameter names Pigeons.sample_names(::Vector, log_potential::JuliaBUGSLogPotential) = [(Symbol(string(vn)) for vn in log_potential.private_model.parameters)...,:log_density] + +# Parallelism invariance +Pigeons.recursive_equal(a::Union{JuliaBUGSPath,JuliaBUGSLogPotential}, b) = + Pigeons._recursive_equal(a,b) +function Pigeons.recursive_equal(a::T, b) where T <: JuliaBUGS.BUGSModel + included = (:transformed, :model_def, :data) + excluded = Tuple(setdiff(fieldnames(T), included)) + Pigeons._recursive_equal(a,b,excluded) +end +# just check the betas match, the model is already checked within path +Pigeons.recursive_equal(a::AbstractVector{<:JuliaBUGSLogPotential}, b) = + all(lp1.beta == lp2.beta for (lp1,lp2) in zip(a,b)) diff --git a/test/test_JuliaBUGS.jl b/test/test_JuliaBUGS.jl index 36263f48f..b0fad3a6e 100644 --- a/test/test_JuliaBUGS.jl +++ b/test/test_JuliaBUGS.jl @@ -63,14 +63,14 @@ end @test res.passed end -@testset "Model with mixed state types using MPI" begin +@testset "Parallelism invariance using MPI" begin target=incomplete_count_data() r = pigeons(; - target, + target=unid_target, n_rounds = 5, n_chains = 4, checkpoint = true, - # checked_round = 4, # NB: doesn't work yet, need a fine-tuned equality check for JuliaBUGS.BUGSModel + checked_round = 4, multithreaded = true, on = ChildProcess( n_local_mpi_processes = set_n_mpis_to_one_on_windows(2),