Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Saving log_density; add MCMCChains ext; improve control over what gets saved to traces #178

Merged
merged 12 commits into from
Nov 22, 2023
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,12 @@ ZipFile = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
[weakdeps]
BridgeStan = "c88b6f0a-829e-4b0b-94b7-f06ab5908f5a"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"

[extensions]
PigeonsBridgeStanExt = "BridgeStan"
PigeonsDynamicPPLExt = "DynamicPPL"
PigeonsMCMCChainsExt = "MCMCChains"

[compat]
BridgeStan = "2"
Expand Down Expand Up @@ -81,3 +83,4 @@ julia = "1.8"
[extras]
BridgeStan = "c88b6f0a-829e-4b0b-94b7-f06ab5908f5a"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
1 change: 1 addition & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ makedocs(;
"Outputs" => [
"Outputs overview" => "output-overview.md",
"Quick reports" => "output-reports.md",
"Traces" => "output-traces.md",
"Plots" => "output-plotting.md",
"log(Z)" => "output-normalization.md",
"Numerical" => "output-numerical.md",
Expand Down
2 changes: 1 addition & 1 deletion docs/src/input-julia.md
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ pt = pigeons(
reference = MyLogPotential(0, 0),
explorer = AutoMALA(default_autodiff_backend = :ForwardDiff),
record = [traces])
samples = Chains(sample_array(pt), variable_names(pt))
samples = Chains(pt)
my_plot = StatsPlots.plot(samples)
StatsPlots.savefig(my_plot, "julia_posterior_densities_and_traces.html");
Expand Down
8 changes: 7 additions & 1 deletion docs/src/input-nonjulian.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,10 @@ nothing # hide
```

As shown above, create a [`StreamTarget`](@ref) amounts to specifying which command will
be used to create a child process.
be used to create a child process.

To terminate the child processes associated with a stream target, use:

```@example blang
Pigeons.kill_child_processes(pt)
```
2 changes: 1 addition & 1 deletion docs/src/input-stan.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ pt = pigeons(
target = stan_unid(100, 50),
reference = stan_unid(0, 0),
record = [traces])
samples = Chains(sample_array(pt), variable_names(pt))
samples = Chains(pt)
my_plot = StatsPlots.plot(samples)
StatsPlots.savefig(my_plot, "stan_posterior_densities_and_traces.html");
Expand Down
2 changes: 1 addition & 1 deletion docs/src/input-turing.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ plotlyjs()
pt = pigeons(
target = TuringLogPotential(my_turing_model(100, 50)),
record = [traces])
samples = Chains(sample_array(pt), variable_names(pt))
samples = Chains(pt)
my_plot = StatsPlots.plot(samples)
StatsPlots.savefig(my_plot, "turing_posterior_densities_and_traces.html");

Expand Down
2 changes: 1 addition & 1 deletion docs/src/output-extended.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ pt = pigeons(target = an_unidentifiable_model,

# collect the statistics and convert to MCMCChains' Chains
# to have axes labels matching variable names in Turing and Stan
samples = Chains(sample_array(pt), variable_names(pt))
samples = Chains(pt)

# create the trace plots
my_plot = StatsPlots.plot(samples)
Expand Down
2 changes: 1 addition & 1 deletion docs/src/output-mpi-postprocessing.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ pt = load(pt_result)
# collect the statistics and convert to MCMCChains' Chains
# to have axes labels matching variable names in Turing and Stan
samples = Chains(sample_array(pt), variable_names(pt))
samples = Chains(pt)
# create the trace plots
my_plot = StatsPlots.plot(samples)
Expand Down
2 changes: 1 addition & 1 deletion docs/src/output-numerical.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ pt = pigeons(
# collect the statistics and convert to MCMCChains' Chains
# to have axes labels matching variable names in Turing and Stan
samples = Chains(sample_array(pt), variable_names(pt))
samples = Chains(pt)
samples
```
Expand Down
11 changes: 10 additions & 1 deletion docs/src/output-online.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,17 @@ using Statistics
mean(pt)
```

To be more precise, the online statistics are computed on the
result of calling [`extract_sample()`](@ref).
Use [`sample_names()`](@ref) to obtain the description of each
coordinate:

## Adding other online statistics
```@example online
sample_names(pt)
```


## Including other online statistics

The computation of online statistics makes use of
[OnlineStats.jl](https://joshday.github.io/OnlineStats.jl/latest/).
Expand Down
1 change: 1 addition & 0 deletions docs/src/output-overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ methods using either [the disk](@ref output-off-memory) or
[constant-memory statistics](@ref output-online).

- [Interpreting pigeons' standard output](@ref output-reports)
- [Working with traces](@ref output-traces)
- [Creating plots.](@ref output-plotting)
- [Approximation of the normalization constant.](@ref output-normalization)
- [Numerical summaries and diagnostics.](@ref output-numerical)
Expand Down
41 changes: 38 additions & 3 deletions docs/src/output-plotting.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ CurrentModule = Pigeons

Use [`sample_array()`](@ref) to convert target chain
samples into a format that can then be consumed by
third party packages such as
third party plotting packages such as
[MCMCChains.jl](https://github.com/TuringLang/MCMCChains.jl)
and [PairPlots.jl](https://sefffal.github.io/PairPlots.jl/).

Expand Down Expand Up @@ -40,7 +40,11 @@ pt = pigeons(target = an_unidentifiable_model,

# collect the statistics and convert to MCMCChains' Chains
# to have axes labels matching variable names in Turing and Stan
samples = Chains(sample_array(pt), variable_names(pt))
samples = Chains(sample_array(pt), sample_names(pt))

# since the above line is frequently needed, Pigeons includes
# an MCMCChains extension allowinging you to use the shorter form:
samples = Chains(pt)

# create the trace plots
my_plot = StatsPlots.plot(samples)
Expand All @@ -52,6 +56,33 @@ nothing # hide
<iframe src="../posterior_densities_and_traces.html" style="height:500px;width:100%;"></iframe>
```

## Monitoring the log density

The value of the log density is appended to each sample. Continuing the
above example, this can be seen
from the variable names indexing the flattened vector created by
[`sample_array()`](@ref):

```@example traces
sample_names(pt)
```

When using the `Chains(pt)` constructor as shown above, the
un-normalized log density is stored inside MCMCChains' "internal"
storage so will not appear in plots by default. To show it, use the following:

```@example traces
params, internals = MCMCChains.get_sections(samples)

my_plot = StatsPlots.plot(internals)
StatsPlots.savefig(my_plot, "logdensity.html");
nothing # hide
```

```@raw html
<iframe src="../logdensity.html" style="height:500px;width:100%;"></iframe>
```

## Posterior pair plots

!!! note
Expand Down Expand Up @@ -80,11 +111,15 @@ pt = pigeons(target = an_unidentifiable_model,
# make sure to record the trace:
record = [traces; round_trip; record_default()])

samples = Chains(sample_array(pt), variable_names(pt))
samples = Chains(pt)
# Warning: the line below only works for Julia 1.9
# see https://sefffal.github.io/PairPlots.jl/dev/chains/ for a workaround
my_plot = PairPlots.pairplot(samples)

CairoMakie.save("pair_plot.svg", my_plot)
nothing # hide
```

```@raw html
<iframe src="../pair_plot.svg" style="height:500px;width:100%;"></iframe>
```
101 changes: 101 additions & 0 deletions docs/src/output-traces.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
```@meta
CurrentModule = Pigeons
```

# [Saving traces](@id output-traces)

The `traces` refer to the list of samples ``X_1, X_2, \dots, X_n``
from which we can approximate expectations of the form
``E[f(X)]``, where ``X \sim \pi`` via
a Monte Carlo average of the form ``\sum_i f(X_i) / n``.

To indicate that the traces should be saved, use

```@example record-traces
using Pigeons

target = Pigeons.toy_turing_unid_target(100, 50)

pt = pigeons(; target,
n_rounds = 3,
# make sure to record the trace:
record = [traces; round_trip; record_default()])
```

Note that there are more memory efficient alternatives
to saving the full traces: see
[online (constant-memory) statistics](@ref output-online) and
[off-memory processing.](@ref output-off-memory)


## Accessing traces

Use [`get_sample`](@ref) to access the list of samples:

```@example record-traces
get_sample(pt)
```

In the special case where each state is a vector, use
[`sample_names`](@ref) to obtain description of the
vector components:

```@example record-traces
sample_names(pt)
```

Still in the special case where each state is a vector,
it is often convenient to organize the result into a single
array, this is done using [`sample_array`](@ref):

```@example record-traces
sample_array(pt)
```


## Customizing what is saved in the traces

You may want to save only some statistics of interest, or a subset of the dimensions to
take up less memory.

We show here an example saving only the
value of the first coordinate:

```@example record-traces
struct OnlyFirstExtractor end

Pigeons.extract_sample(state, log_potential, extractor::OnlyFirstExtractor) =
Pigeons.extract_sample(state, log_potential)[1:1]


pt = pigeons(; target,
n_rounds = 3,
# custom method to extract samples:
extractor = OnlyFirstExtractor(),
# make sure to record the trace:
record = [traces; round_trip; record_default()])

sample_array(pt)
```

Optionally, it is a good idea to also adjust the behaviour
of [`sample_names`](@ref) accordingly. For example, `variables_names` gets called
when creating MCMCChains object so that e.g. plots are labelled correctly.

```@example record-traces
Pigeons.sample_names(state, log_potential, extractor::OnlyFirstExtractor) =
Pigeons.sample_names(state, log_potential)[1:1]
```

To keep only the value of the log potential, you can use the following built-in [`LogPotentialExtractor`](@ref):

```@example record-traces
pt = pigeons(; target,
n_rounds = 3,
# custom method to extract samples:
extractor = Pigeons.LogPotentialExtractor(),
# make sure to record the trace:
record = [traces; round_trip; record_default()])

sample_array(pt)
```
4 changes: 2 additions & 2 deletions docs/src/unidentifiable-example.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ pt = pigeons(
record = [traces])
# collect the statistics and convert to MCMCChains' Chains
samples = Chains(sample_array(pt), variable_names(pt))
samples = Chains(pt)
# create the trace plots
my_plot = StatsPlots.plot(samples)
StatsPlots.savefig(my_plot, "no_pt_posterior_densities_and_traces.html");
Expand Down Expand Up @@ -74,7 +74,7 @@ pt = pigeons(
record = [traces, round_trip])
# collect the statistics and convert to MCMCChains' Chains
samples = Chains(sample_array(pt), variable_names(pt))
samples = Chains(pt)
# create the trace plots
my_plot = StatsPlots.plot(samples)
StatsPlots.savefig(my_plot, "with_pt_posterior_densities_and_traces.html");
Expand Down
5 changes: 0 additions & 5 deletions examples/turing-galaxy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,3 @@ end
pt = pigeons(
target = TuringLogPotential(GalaxyTuring())
)

# using StatsPlots
# samples = sample_array(pt);
# plot(Chains(samples, ["par_$i" for i in 1:size(samples)[2]])) # TODO: variable_names should detect when variables are vectors
# nothing
11 changes: 9 additions & 2 deletions ext/PigeonsBridgeStanExt/state.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ Pigeons.continuous_variables(state::Pigeons.StanState) = Pigeons.SINGLETON_VAR #
Pigeons.discrete_variables(state::Pigeons.StanState) = []

Pigeons.extract_sample(state::Pigeons.StanState, log_potential) =
BridgeStan.param_constrain(Pigeons.stan_model(log_potential), state.unconstrained_parameters; include_tp = true, include_gq = true, rng = state.rng)
[
BridgeStan.param_constrain(Pigeons.stan_model(log_potential), state.unconstrained_parameters; include_tp = true, include_gq = true, rng = state.rng);
log_potential(state)
]


function Pigeons.update_state!(state::Pigeons.StanState, name::Symbol, index, value)
Expand All @@ -22,7 +25,11 @@ end
Pigeons.step!(explorer::Pigeons.HamiltonianSampler, replica, shared, state::Pigeons.StanState) =
Pigeons.step!(explorer, replica, shared, state.unconstrained_parameters)

Pigeons.variable_names(::Pigeons.StanState, log_potential) = BridgeStan.param_names(Pigeons.stan_model(log_potential); include_tp = true, include_gq = true)
Pigeons.sample_names(::Pigeons.StanState, log_potential) =
[
BridgeStan.param_names(Pigeons.stan_model(log_potential); include_tp = true, include_gq = true);
:log_density
]


function Pigeons.slice_sample!(h::SliceSampler, state::Pigeons.StanState, log_potential, cached_lp, replica)
Expand Down
6 changes: 4 additions & 2 deletions ext/PigeonsDynamicPPLExt/state.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@ function Pigeons.extract_sample(state::DynamicPPL.TypedVarInfo, log_potential)
DynamicPPL.invlink!!(state, Pigeons.turing_model(log_potential))
result = DynamicPPL.getall(state)
DynamicPPL.link!!(state, DynamicPPL.SampleFromPrior(), Pigeons.turing_model(log_potential))
push!(result, log_potential(state))
return result
end

function Pigeons.variable_names(state::DynamicPPL.TypedVarInfo, _)
function Pigeons.sample_names(state::DynamicPPL.TypedVarInfo, _)
result = Symbol[]
all_names = fieldnames(typeof(state.metadata))
for var_name in all_names
Expand All @@ -45,6 +46,7 @@ function Pigeons.variable_names(state::DynamicPPL.TypedVarInfo, _)
error("don't know how to handle var `$var_name` of type $(typeof(var))")
end
end
push!(result, :log_density)
return result
end

Expand All @@ -70,7 +72,7 @@ end
Pigeons.recursive_equal(a::DynamicPPL.TypedVarInfo, b::DynamicPPL.TypedVarInfo) =
# as of Nov 2023, DynamicPPL does not supply == for TypedVarInfo
length(a.metadata) == length(b.metadata) &&
variable_names(a,1) == variable_names(b,1) && # second argument is not used
sample_names(a,1) == sample_names(b,1) && # second argument is not used
DynamicPPL.getall(a) == DynamicPPL.getall(b)


Expand Down
14 changes: 14 additions & 0 deletions ext/PigeonsMCMCChainsExt/PigeonsMCMCChainsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
module PigeonsMCMCChainsExt

using Pigeons
if isdefined(Base, :get_extension)
using DocStringExtensions
using MCMCChains
else
using ..DocStringExtensions
using ..MCMCChains
end

MCMCChains.Chains(pt::PT) = Chains(sample_array(pt), sample_names(pt), Dict(:internals => [:log_density]))

end
Loading