Skip to content

Commit

Permalink
Merge pull request #178 from Julia-Tempering/mcmcchains-improvements
Browse files Browse the repository at this point in the history
Saving log_density; add MCMCChains ext; improve control over what gets saved to traces
  • Loading branch information
alexandrebouchard authored Nov 22, 2023
2 parents 225ffbc + 71fe190 commit 1df9a57
Show file tree
Hide file tree
Showing 40 changed files with 340 additions and 107 deletions.
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,12 @@ ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
[weakdeps]
BridgeStan = "c88b6f0a-829e-4b0b-94b7-f06ab5908f5a"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"

[extensions]
PigeonsBridgeStanExt = "BridgeStan"
PigeonsDynamicPPLExt = "DynamicPPL"
PigeonsMCMCChainsExt = "MCMCChains"

[compat]
BridgeStan = "2"
Expand Down Expand Up @@ -81,3 +83,4 @@ julia = "1.8"
[extras]
BridgeStan = "c88b6f0a-829e-4b0b-94b7-f06ab5908f5a"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
1 change: 1 addition & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ makedocs(;
"Outputs" => [
"Outputs overview" => "output-overview.md",
"Quick reports" => "output-reports.md",
"Traces" => "output-traces.md",
"Plots" => "output-plotting.md",
"log(Z)" => "output-normalization.md",
"Numerical" => "output-numerical.md",
Expand Down
2 changes: 1 addition & 1 deletion docs/src/input-julia.md
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ pt = pigeons(
reference = MyLogPotential(0, 0),
explorer = AutoMALA(default_autodiff_backend = :ForwardDiff),
record = [traces])
samples = Chains(sample_array(pt), variable_names(pt))
samples = Chains(pt)
my_plot = StatsPlots.plot(samples)
StatsPlots.savefig(my_plot, "julia_posterior_densities_and_traces.html");
Expand Down
8 changes: 7 additions & 1 deletion docs/src/input-nonjulian.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,10 @@ nothing # hide
```

As shown above, create a [`StreamTarget`](@ref) amounts to specifying which command will
be used to create a child process.
be used to create a child process.

To terminate the child processes associated with a stream target, use:

```@example blang
Pigeons.kill_child_processes(pt)
```
2 changes: 1 addition & 1 deletion docs/src/input-stan.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ pt = pigeons(
target = stan_unid(100, 50),
reference = stan_unid(0, 0),
record = [traces])
samples = Chains(sample_array(pt), variable_names(pt))
samples = Chains(pt)
my_plot = StatsPlots.plot(samples)
StatsPlots.savefig(my_plot, "stan_posterior_densities_and_traces.html");
Expand Down
2 changes: 1 addition & 1 deletion docs/src/input-turing.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ plotlyjs()
pt = pigeons(
target = TuringLogPotential(my_turing_model(100, 50)),
record = [traces])
samples = Chains(sample_array(pt), variable_names(pt))
samples = Chains(pt)
my_plot = StatsPlots.plot(samples)
StatsPlots.savefig(my_plot, "turing_posterior_densities_and_traces.html");
Expand Down
2 changes: 1 addition & 1 deletion docs/src/output-extended.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ pt = pigeons(target = an_unidentifiable_model,
# collect the statistics and convert to MCMCChains' Chains
# to have axes labels matching variable names in Turing and Stan
samples = Chains(sample_array(pt), variable_names(pt))
samples = Chains(pt)
# create the trace plots
my_plot = StatsPlots.plot(samples)
Expand Down
2 changes: 1 addition & 1 deletion docs/src/output-mpi-postprocessing.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ pt = load(pt_result)
# collect the statistics and convert to MCMCChains' Chains
# to have axes labels matching variable names in Turing and Stan
samples = Chains(sample_array(pt), variable_names(pt))
samples = Chains(pt)
# create the trace plots
my_plot = StatsPlots.plot(samples)
Expand Down
2 changes: 1 addition & 1 deletion docs/src/output-numerical.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ pt = pigeons(
# collect the statistics and convert to MCMCChains' Chains
# to have axes labels matching variable names in Turing and Stan
samples = Chains(sample_array(pt), variable_names(pt))
samples = Chains(pt)
samples
```
Expand Down
11 changes: 10 additions & 1 deletion docs/src/output-online.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,17 @@ using Statistics
mean(pt)
```

To be more precise, the online statistics are computed on the
result of calling [`extract_sample()`](@ref).
Use [`sample_names()`](@ref) to obtain the description of each
coordinate:

## Adding other online statistics
```@example online
sample_names(pt)
```


## Including other online statistics

The computation of online statistics makes use of
[OnlineStats.jl](https://joshday.github.io/OnlineStats.jl/latest/).
Expand Down
1 change: 1 addition & 0 deletions docs/src/output-overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ methods using either [the disk](@ref output-off-memory) or
[constant-memory statistics](@ref output-online).

- [Interpreting pigeons' standard output](@ref output-reports)
- [Working with traces](@ref output-traces)
- [Creating plots.](@ref output-plotting)
- [Approximation of the normalization constant.](@ref output-normalization)
- [Numerical summaries and diagnostics.](@ref output-numerical)
Expand Down
41 changes: 38 additions & 3 deletions docs/src/output-plotting.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ CurrentModule = Pigeons

Use [`sample_array()`](@ref) to convert target chain
samples into a format that can then be consumed by
third party packages such as
third party plotting packages such as
[MCMCChains.jl](https://github.com/TuringLang/MCMCChains.jl)
and [PairPlots.jl](https://sefffal.github.io/PairPlots.jl/).

Expand Down Expand Up @@ -40,7 +40,11 @@ pt = pigeons(target = an_unidentifiable_model,
# collect the statistics and convert to MCMCChains' Chains
# to have axes labels matching variable names in Turing and Stan
samples = Chains(sample_array(pt), variable_names(pt))
samples = Chains(sample_array(pt), sample_names(pt))
# since the above line is frequently needed, Pigeons includes
# an MCMCChains extension allowinging you to use the shorter form:
samples = Chains(pt)
# create the trace plots
my_plot = StatsPlots.plot(samples)
Expand All @@ -52,6 +56,33 @@ nothing # hide
<iframe src="../posterior_densities_and_traces.html" style="height:500px;width:100%;"></iframe>
```

## Monitoring the log density

The value of the log density is appended to each sample. Continuing the
above example, this can be seen
from the variable names indexing the flattened vector created by
[`sample_array()`](@ref):

```@example traces
sample_names(pt)
```

When using the `Chains(pt)` constructor as shown above, the
un-normalized log density is stored inside MCMCChains' "internal"
storage so will not appear in plots by default. To show it, use the following:

```@example traces
params, internals = MCMCChains.get_sections(samples)
my_plot = StatsPlots.plot(internals)
StatsPlots.savefig(my_plot, "logdensity.html");
nothing # hide
```

```@raw html
<iframe src="../logdensity.html" style="height:500px;width:100%;"></iframe>
```

## Posterior pair plots

!!! note
Expand Down Expand Up @@ -80,11 +111,15 @@ pt = pigeons(target = an_unidentifiable_model,
# make sure to record the trace:
record = [traces; round_trip; record_default()])
samples = Chains(sample_array(pt), variable_names(pt))
samples = Chains(pt)
# Warning: the line below only works for Julia 1.9
# see https://sefffal.github.io/PairPlots.jl/dev/chains/ for a workaround
my_plot = PairPlots.pairplot(samples)
CairoMakie.save("pair_plot.svg", my_plot)
nothing # hide
```

```@raw html
<iframe src="../pair_plot.svg" style="height:500px;width:100%;"></iframe>
```
101 changes: 101 additions & 0 deletions docs/src/output-traces.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
```@meta
CurrentModule = Pigeons
```

# [Saving traces](@id output-traces)

The `traces` refer to the list of samples ``X_1, X_2, \dots, X_n``
from which we can approximate expectations of the form
``E[f(X)]``, where ``X \sim \pi`` via
a Monte Carlo average of the form ``\sum_i f(X_i) / n``.

To indicate that the traces should be saved, use

```@example record-traces
using Pigeons
target = Pigeons.toy_turing_unid_target(100, 50)
pt = pigeons(; target,
n_rounds = 3,
# make sure to record the trace:
record = [traces; round_trip; record_default()])
```

Note that there are more memory efficient alternatives
to saving the full traces: see
[online (constant-memory) statistics](@ref output-online) and
[off-memory processing.](@ref output-off-memory)


## Accessing traces

Use [`get_sample`](@ref) to access the list of samples:

```@example record-traces
get_sample(pt)
```

In the special case where each state is a vector, use
[`sample_names`](@ref) to obtain description of the
vector components:

```@example record-traces
sample_names(pt)
```

Still in the special case where each state is a vector,
it is often convenient to organize the result into a single
array, this is done using [`sample_array`](@ref):

```@example record-traces
sample_array(pt)
```


## Customizing what is saved in the traces

You may want to save only some statistics of interest, or a subset of the dimensions to
take up less memory.

We show here an example saving only the
value of the first coordinate:

```@example record-traces
struct OnlyFirstExtractor end
Pigeons.extract_sample(state, log_potential, extractor::OnlyFirstExtractor) =
Pigeons.extract_sample(state, log_potential)[1:1]
pt = pigeons(; target,
n_rounds = 3,
# custom method to extract samples:
extractor = OnlyFirstExtractor(),
# make sure to record the trace:
record = [traces; round_trip; record_default()])
sample_array(pt)
```

Optionally, it is a good idea to also adjust the behaviour
of [`sample_names`](@ref) accordingly. For example, `variables_names` gets called
when creating MCMCChains object so that e.g. plots are labelled correctly.

```@example record-traces
Pigeons.sample_names(state, log_potential, extractor::OnlyFirstExtractor) =
Pigeons.sample_names(state, log_potential)[1:1]
```

To keep only the value of the log potential, you can use the following built-in [`LogPotentialExtractor`](@ref):

```@example record-traces
pt = pigeons(; target,
n_rounds = 3,
# custom method to extract samples:
extractor = Pigeons.LogPotentialExtractor(),
# make sure to record the trace:
record = [traces; round_trip; record_default()])
sample_array(pt)
```
4 changes: 2 additions & 2 deletions docs/src/unidentifiable-example.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ pt = pigeons(
record = [traces])
# collect the statistics and convert to MCMCChains' Chains
samples = Chains(sample_array(pt), variable_names(pt))
samples = Chains(pt)
# create the trace plots
my_plot = StatsPlots.plot(samples)
StatsPlots.savefig(my_plot, "no_pt_posterior_densities_and_traces.html");
Expand Down Expand Up @@ -74,7 +74,7 @@ pt = pigeons(
record = [traces, round_trip])
# collect the statistics and convert to MCMCChains' Chains
samples = Chains(sample_array(pt), variable_names(pt))
samples = Chains(pt)
# create the trace plots
my_plot = StatsPlots.plot(samples)
StatsPlots.savefig(my_plot, "with_pt_posterior_densities_and_traces.html");
Expand Down
5 changes: 0 additions & 5 deletions examples/turing-galaxy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,3 @@ end
pt = pigeons(
target = TuringLogPotential(GalaxyTuring())
)

# using StatsPlots
# samples = sample_array(pt);
# plot(Chains(samples, ["par_$i" for i in 1:size(samples)[2]])) # TODO: variable_names should detect when variables are vectors
# nothing
11 changes: 9 additions & 2 deletions ext/PigeonsBridgeStanExt/state.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ Pigeons.continuous_variables(state::Pigeons.StanState) = Pigeons.SINGLETON_VAR #
Pigeons.discrete_variables(state::Pigeons.StanState) = []

Pigeons.extract_sample(state::Pigeons.StanState, log_potential) =
BridgeStan.param_constrain(Pigeons.stan_model(log_potential), state.unconstrained_parameters; include_tp = true, include_gq = true, rng = state.rng)
[
BridgeStan.param_constrain(Pigeons.stan_model(log_potential), state.unconstrained_parameters; include_tp = true, include_gq = true, rng = state.rng);
log_potential(state)
]


function Pigeons.update_state!(state::Pigeons.StanState, name::Symbol, index, value)
Expand All @@ -22,7 +25,11 @@ end
Pigeons.step!(explorer::Pigeons.HamiltonianSampler, replica, shared, state::Pigeons.StanState) =
Pigeons.step!(explorer, replica, shared, state.unconstrained_parameters)

Pigeons.variable_names(::Pigeons.StanState, log_potential) = BridgeStan.param_names(Pigeons.stan_model(log_potential); include_tp = true, include_gq = true)
Pigeons.sample_names(::Pigeons.StanState, log_potential) =
[
BridgeStan.param_names(Pigeons.stan_model(log_potential); include_tp = true, include_gq = true);
:log_density
]


function Pigeons.slice_sample!(h::SliceSampler, state::Pigeons.StanState, log_potential, cached_lp, replica)
Expand Down
6 changes: 4 additions & 2 deletions ext/PigeonsDynamicPPLExt/state.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@ function Pigeons.extract_sample(state::DynamicPPL.TypedVarInfo, log_potential)
DynamicPPL.invlink!!(state, Pigeons.turing_model(log_potential))
result = DynamicPPL.getall(state)
DynamicPPL.link!!(state, DynamicPPL.SampleFromPrior(), Pigeons.turing_model(log_potential))
push!(result, log_potential(state))
return result
end

function Pigeons.variable_names(state::DynamicPPL.TypedVarInfo, _)
function Pigeons.sample_names(state::DynamicPPL.TypedVarInfo, _)
result = Symbol[]
all_names = fieldnames(typeof(state.metadata))
for var_name in all_names
Expand All @@ -45,6 +46,7 @@ function Pigeons.variable_names(state::DynamicPPL.TypedVarInfo, _)
error("don't know how to handle var `$var_name` of type $(typeof(var))")
end
end
push!(result, :log_density)
return result
end

Expand All @@ -70,7 +72,7 @@ end
Pigeons.recursive_equal(a::DynamicPPL.TypedVarInfo, b::DynamicPPL.TypedVarInfo) =
# as of Nov 2023, DynamicPPL does not supply == for TypedVarInfo
length(a.metadata) == length(b.metadata) &&
variable_names(a,1) == variable_names(b,1) && # second argument is not used
sample_names(a,1) == sample_names(b,1) && # second argument is not used
DynamicPPL.getall(a) == DynamicPPL.getall(b)


Expand Down
14 changes: 14 additions & 0 deletions ext/PigeonsMCMCChainsExt/PigeonsMCMCChainsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
module PigeonsMCMCChainsExt

using Pigeons
if isdefined(Base, :get_extension)
using DocStringExtensions
using MCMCChains
else
using ..DocStringExtensions
using ..MCMCChains
end

MCMCChains.Chains(pt::PT) = Chains(sample_array(pt), sample_names(pt), Dict(:internals => [:log_density]))

end
Loading

1 comment on commit 1df9a57

@nikola-sur
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yay :)

Please sign in to comment.