Skip to content

Commit

Permalink
add get_param_array function (#235)
Browse files Browse the repository at this point in the history
* add get_param_array function

* Update get_param_array.jl

* Simpler utility using `rowtable`
  • Loading branch information
SamuelBrand1 authored May 22, 2024
1 parent 2d3b4c8 commit 919b8a8
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 1 deletion.
2 changes: 2 additions & 0 deletions EpiAware/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

[compat]
Expand All @@ -37,5 +38,6 @@ Random = "1.9"
Reexport = "1.2"
SparseArrays = "1.10"
Statistics = "1.10"
Tables = "1.11"
Turing = "0.30"
julia = "1.10"
4 changes: 3 additions & 1 deletion EpiAware/src/EpiAwareUtils/EpiAwareUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@ using DataFramesMeta: DataFrame, @rename!
using DynamicPPL: Model, fix, condition
using MCMCChains: Chains
using Random: AbstractRNG
using Tables: rowtable

using Distributions, DocStringExtensions, QuadGK, Statistics, Turing

#Export Structures
export HalfNormal, DirectSample

#Export functions
export scan, spread_draws, censored_pmf
export scan, spread_draws, censored_pmf, get_param_array

include("docstrings.jl")
include("censored_pmf.jl")
Expand All @@ -25,5 +26,6 @@ include("scan.jl")
include("turing-methods.jl")
include("DirectSample.jl")
include("post-inference.jl")
include("get_param_array.jl")

end
32 changes: 32 additions & 0 deletions EpiAware/src/EpiAwareUtils/get_param_array.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""
Extract a parameter array from a `Chains` object `chn` that matches the shape of number of sample
and chain pairs in `chn`.
# Arguments
- `chn::Chains`: The `Chains` object containing the MCMC samples.
# Returns
- `param_array`: An array of parameter samples, where each element corresponds to a single
MCMC sample as a `NamedTuple`.
# Example
Sampling from a simple model which has both scalar and vector quantity random variables across
4 chains.
```julia
using Turing, MCMCChains, EpiAware
@model function testmodel()
x ~ MvNormal(2, 1.)
y ~ Normal()
end
mdl = testmodel()
chn = sample(mdl, Prior(), MCMCSerial(), 250, 4)
A = get_param_array(chn)
```
"""
function get_param_array(chn::Chains)
rowtable(chn) |> x -> reshape(x, size(chn, 1), size(chn, 3))
end
14 changes: 14 additions & 0 deletions EpiAware/test/EpiAwareUtils/get_param_array.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
@testitem "test unpacking a Chains object" begin
using Turing, Distributions

@model function testmodel()
x ~ MvNormal(2, 1.0)
y ~ Normal()
end
mdl = testmodel()
chn = sample(mdl, Prior(), MCMCSerial(), 250, 4; progress = false)
A = get_param_array(chn)

@test size(A) == (size(chn)[1], size(chn)[3])
@test eltype(A) <: NamedTuple
end

0 comments on commit 919b8a8

Please sign in to comment.