Update sampleposterior for multiple data sets; add simple method for assess(est::PosteriorEstimator)
msainsburydale committed Feb 24, 2025
1 parent abd6248 commit 00b6e79
Showing 6 changed files with 97 additions and 13 deletions.
1 change: 1 addition & 0 deletions ext/NeuralEstimatorsPlotExt.jl
@@ -50,6 +50,7 @@ function plot(assessment::Assessment; grid::Bool = false)

linkyaxes=:none
if "estimate" names(df) #TODO only want this for point estimates
#TODO fix the vertical axis to have the same limits as the horizontal axis
if num_estimators > 1
colors = [unique(df.estimator)[i] => ColorSchemes.Set1_4.colors[i] for i ∈ 1:num_estimators]
if grid
4 changes: 2 additions & 2 deletions src/ApproximateDistributions.jl
@@ -356,7 +356,7 @@ end

function sampleposterior(flow::NormalisingFlow, TZ::AbstractMatrix, N::Integer; use_gpu::Bool = true)

@assert size(TZ, 2) == 1

# Sample from the base distribution (standard Gaussian) and repeat TZ to match the desired sample size
U = randn(Float32, flow.d, N)
@@ -373,4 +373,4 @@ function sampleposterior(flow::NormalisingFlow, TZ::AbstractMatrix, N::Integer;

return cpu(θ)
end
sampleposterior(flow::NormalisingFlow, TZ::AbstractVector, N::Integer; kwargs...) = sampleposterior(flow, reshape(TZ, :, 1), N; kwargs...)
3 changes: 1 addition & 2 deletions src/Estimators.jl
@@ -453,7 +453,7 @@ estimator = PosteriorEstimator(q, network)
estimator = train(estimator, sample, simulate, m = m)
# Inference with observed data
θ = [0.8f0; 0.1f0]
θ = [0.8f0 0.1f0]'
Z = simulate(θ, m)
sampleposterior(estimator, Z) # posterior draws
posteriormean(estimator, Z) # point estimate
@@ -466,7 +466,6 @@ end
numdistributionalparams(estimator::PosteriorEstimator) = numdistributionalparams(estimator.q)
logdensity(estimator::PosteriorEstimator, θ, Z) = logdensity(estimator.q, θ, estimator.network(Z))
(estimator::PosteriorEstimator)(Zθ::Tuple) = logdensity(estimator, Zθ[2], Zθ[1]) # internal method only used during training # TODO not ideal that we assume an ordering here
sampleposterior(estimator::PosteriorEstimator, Z, N::Integer = 1000) = sampleposterior(estimator.q, estimator.network(Z), N)

## Alternatively, to use a Gaussian approximate distribution:
# q = GaussianDistribution(d)
64 changes: 64 additions & 0 deletions src/assess.jl
@@ -185,6 +185,70 @@ function assess(
return Assessment(df, runtime)
end

function assess(
estimator::PosteriorEstimator,
θ::P, Z;
parameter_names::Vector{String} = ["θ$i" for i ∈ 1:size(θ, 1)],
estimator_name::Union{Nothing, String} = nothing,
estimator_names::Union{Nothing, String} = nothing,
N::Integer = 1000,
kwargs...
) where {P <: Union{AbstractMatrix, ParameterConfigurations}}

# Extract the matrix of parameters
θ = _extractθ(θ)
p, K = size(θ)

# Check the size of the test data conforms with θ
m = numberreplicates(Z)
if !(typeof(m) <: Vector{Int}) # a vector of vectors has been given, attempt to convert Z to the correct format
Z = reduce(vcat, Z)
m = numberreplicates(Z)
end
KJ = length(m) # NB this can be different to length(Z) when we have set-level information, in which case length(Z) = 2
@assert KJ % K == 0 "The number of data sets in `Z` must be a multiple of the number of parameter vectors in `θ`"
J = KJ ÷ K
if J > 1
θ = repeat(θ, outer = (1, J))
end
if θ isa NamedMatrix
parameter_names = names(θ, 1)
end
@assert length(parameter_names) == p

estimate_names = parameter_names

runtime = @elapsed θ̂ = posteriormedian(estimator, Z, N; kwargs...)

# Convert to DataFrame and add information
runtime = DataFrame(runtime = runtime)
θ̂ = DataFrame(θ̂', estimate_names)
θ̂[!, "m"] = m
θ̂[!, "k"] = repeat(1:K, J)
θ̂[!, "j"] = repeat(1:J, inner = K)

# Add estimator name if it was provided
if !isnothing(estimator_names) estimator_name = estimator_names end # deprecation coercion
if !isnothing(estimator_name)
θ̂[!, "estimator"] .= estimator_name
runtime[!, "estimator"] .= estimator_name
end

# Dataframe containing the true parameters
θ = convert(Matrix, θ)
θ = DataFrame(θ', parameter_names)
# Replicate θ to match the number of rows in θ̂. Note that the parameter
# configuration, k, is the fastest running variable in θ̂, so we repeat θ
# in an outer fashion.
θ = repeat(θ, outer = nrow(θ̂) ÷ nrow(θ))
θ = stack(θ, variable_name = :parameter, value_name = :truth) # transform to long form

# Merge true parameters and estimates
df = _merge(θ, θ̂)

return Assessment(df, runtime)
end
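A hedged usage sketch of this new method (not part of the diff): it assumes a trained `estimator::PosteriorEstimator` and the `simulate` function and sample size `m` from the docstring example in src/Estimators.jl; the number of test parameter vectors and the estimator name are illustrative, and the `.df` field access is assumed from the `Assessment(df, runtime)` constructor above.

θ_test = rand(Float32, 2, 100)                    # 2 × 100 matrix: 100 test parameter vectors
Z_test = simulate(θ_test, m)                      # one data set per parameter vector
assessment = assess(estimator, θ_test, Z_test; N = 2000, estimator_name = "NPE")
assessment.df                                     # long-form DataFrame pairing true parameters with posterior medians
# plot(assessment; grid = true)                   # via the plotting extension (NeuralEstimatorsPlotExt)

If J > 1 independent data sets are supplied per parameter vector (for example, `reduce(vcat, [simulate(θ_test, m) for _ in 1:J])`, or a vector of such vectors, which the method flattens), the replicate index is recorded in the `j` column of the resulting data frame.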


function assess(
estimator::Union{IntervalEstimator, Ensemble{<:IntervalEstimator}},
35 changes: 27 additions & 8 deletions src/inference.jl
@@ -62,6 +62,7 @@ Computes the posterior mean $_doc_string
See also [`posteriormedian()`](@ref), [`posteriormode()`](@ref).
"""
posteriormean(θ::AbstractMatrix) = mean(θ; dims = 2)
posteriormean(θ::AbstractVector{<:AbstractMatrix}) = reduce(hcat, posteriormean.(θ))
posteriormean(estimator::Union{PosteriorEstimator, RatioEstimator}, Z, N::Integer = 1000; kwargs...) = posteriormean(sampleposterior(estimator, Z, N; kwargs...))

"""
@@ -72,6 +73,7 @@ Computes the vector of marginal posterior medians $_doc_string
See also [`posteriormean()`](@ref), [`posteriorquantile()`](@ref).
"""
posteriormedian(θ::AbstractMatrix) = median(θ; dims = 2)
posteriormedian(θ::AbstractVector{<:AbstractMatrix}) = reduce(hcat, posteriormedian.(θ))
posteriormedian(estimator::Union{PosteriorEstimator, RatioEstimator}, Z, N::Integer = 1000; kwargs...) = posteriormedian(sampleposterior(estimator, Z, N; kwargs...))
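For clarity, a small sketch (illustrative values only) of what the new vector methods compute: each element of the input is a d × N matrix of posterior draws for one data set, and the column-wise summaries are concatenated into a d × K matrix.

draws = [randn(3, 1000), randn(3, 1000)]   # posterior draws for K = 2 data sets, d = 3 parameters
posteriormean(draws)                       # 3 × 2 matrix; column k summarises data set k
posteriormedian(draws)                     # 3 × 2 matrix of marginal posterior medians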

"""
@@ -91,29 +93,29 @@ posteriorquantile(estimator::Union{PosteriorEstimator, RatioEstimator}, Z, probs

# ---- Posterior sampling ----

#TODO Parallel computations in outer broadcasting functions
#TODO Basic MCMC sampler (initialised with θ₀)
@doc raw"""
sampleposterior(estimator::PosteriorEstimator, Z, N::Integer = 1000)
sampleposterior(estimator::RatioEstimator, Z, N::Integer = 1000; θ_grid, prior::Function = θ -> 1f0)
Samples from the approximate posterior distribution implied by `estimator`.
The positional argument `N` controls the size of the posterior sample.
Returns a $d$ × `N` matrix of posterior samples, where $d$ is the dimension of the parameter vector. If `Z` is a vector containing multiple data sets, a vector of matrices will be returned.
When sampling based on a `RatioEstimator`, the sampling algorithm is based on a fine-gridding of the
parameter space, specified through the keyword argument `θ_grid` (or `theta_grid`).
The approximate posterior density is evaluated over this grid, which is then
used to draw samples. This is very effective when making inference with a
used to draw samples. This is effective when making inference with a
small number of parameters. For models with a large number of parameters,
other sampling algorithms may be needed (please feel free to contact the
package maintainer for discussion). The prior distribution $p(\boldsymbol{\theta})$ is controlled through the keyword argument `prior` (by default, a uniform prior is used).
other sampling algorithms may be needed (please contact the package maintainer).
The prior distribution $p(\boldsymbol{\theta})$ is controlled through the keyword argument `prior` (by default, a uniform prior is used).
"""
function sampleposterior(est::RatioEstimator,
Z,
N::Integer = 1000;
prior::Function = θ -> 1f0,
θ_grid = nothing, theta_grid = nothing,
# θ₀ = nothing, theta0 = nothing,
# θ₀ = nothing, theta0 = nothing, #TODO Basic MCMC sampler (initialised with θ₀)
kwargs...)

# Check duplicated arguments that are needed so that the R interface uses ASCII characters only
@@ -135,10 +137,27 @@ function sampleposterior(est::RatioEstimator,
reduce(hcat, θ)
end
end
function sampleposterior(est::RatioEstimator, Z::AbstractVector, args...; kwargs...)
sampleposterior.(Ref(est), Z, args...; kwargs...)
function sampleposterior(est::RatioEstimator, Z::AbstractVector, N::Integer = 1000; kwargs...)
# Simple broadcasting to handle multiple data sets (NB this could be done in parallel)
if length(Z) == 1
sampleposterior(est, Z[1], N; kwargs...)
else
sampleposterior.(Ref(est), Z, N; kwargs...)
end
end
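For reference, a hedged sketch of grid-based sampling with a `RatioEstimator` (not part of the diff): `est` and `Z` are assumed to be a trained ratio estimator and observed data, the grid bounds are illustrative, and parameter vectors are assumed to be stored as columns, consistent with the package's d × N convention.

θ1 = range(0f0, 1f0, length = 100)
θ2 = range(0f0, 1f0, length = 100)
θ_grid = reduce(hcat, vec([[a; b] for a in θ1, b in θ2]))   # 2 × 10000 grid of candidate parameter vectors
samples = sampleposterior(est, Z, 1000; θ_grid = θ_grid)    # d × 1000 matrix of draws

The `prior` keyword accepts a function of θ and defaults to a flat prior, as described in the docstring above.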

sampleposterior(estimator::PosteriorEstimator, Z, N::Integer = 1000; kwargs...) = sampleposterior(estimator.q, estimator.network(Z), N; kwargs...)
function sampleposterior(est::PosteriorEstimator, Z::AbstractVector, N::Integer = 1000; kwargs...)
# Simple broadcasting to handle multiple data sets (NB this could be done in parallel)
if length(Z) == 1
sampleposterior(est, Z[1], N; kwargs...)
else
sampleposterior.(Ref(est), Z, N; kwargs...)
end
end
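And a hedged sketch of the new multiple-data-set behaviour for a `PosteriorEstimator`, continuing the docstring example in src/Estimators.jl (`estimator`, `simulate`, and `m` as defined there; it is assumed that `simulate` returns one data set per column of θ, as in that example).

θ = [0.8f0 0.3f0; 0.1f0 0.5f0]              # two parameter vectors, stored as columns
Z = simulate(θ, m)                          # vector of two data sets
samples = sampleposterior(estimator, Z)     # vector of two d × 1000 matrices of posterior draws
posteriormedian(samples)                    # 2 × 2 matrix of marginal posterior medians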



# ---- Optimisation-based point estimates ----

# TODO density evaluation with PosteriorEstimator would allow us to use these grid and gradient-based methods for computing the posterior mode
3 changes: 2 additions & 1 deletion test/runtests.jl
@@ -1230,11 +1230,12 @@ end
estimator = PosteriorEstimator(q, network)
estimator = train(estimator, sample, simulate, m = m, epochs = 1, verbose = false)
@test numdistributionalparams(estimator) == numdistributionalparams(q)
θ = [0.8f0; 0.1f0]
θ = [0.8f0 0.1f0]'
Z = simulate(θ, m)
sampleposterior(estimator, Z) # posterior draws
posteriormean(estimator, Z) # point estimate
posteriorquantile(estimator, Z, [0.1, 0.5]) # quantiles
assessment = assess(estimator, θ, Z)
end
end

