Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add get_param_array function #235

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions EpiAware/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

[compat]
Expand All @@ -37,5 +38,6 @@ Random = "1.9"
Reexport = "1.2"
SparseArrays = "1.10"
Statistics = "1.10"
Tables = "1.11"
Turing = "0.30"
julia = "1.10"
4 changes: 3 additions & 1 deletion EpiAware/src/EpiAwareUtils/EpiAwareUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@ using DataFramesMeta: DataFrame, @rename!
using DynamicPPL: Model, fix, condition
using MCMCChains: Chains
using Random: AbstractRNG
using Tables: rowtable

using Distributions, DocStringExtensions, QuadGK, Statistics, Turing

#Export Structures
export HalfNormal, DirectSample

#Export functions
export scan, spread_draws, censored_pmf
export scan, spread_draws, censored_pmf, get_param_array

include("docstrings.jl")
include("censored_pmf.jl")
Expand All @@ -25,5 +26,6 @@ include("scan.jl")
include("turing-methods.jl")
include("DirectSample.jl")
include("post-inference.jl")
include("get_param_array.jl")

end
32 changes: 32 additions & 0 deletions EpiAware/src/EpiAwareUtils/get_param_array.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""
Extract a parameter array from a `Chains` object `chn` that matches the shape of number of sample
and chain pairs in `chn`.

# Arguments
- `chn::Chains`: The `Chains` object containing the MCMC samples.

# Returns
- `param_array`: An array of parameter samples, where each element corresponds to a single
MCMC sample as a `NamedTuple`.

# Example

Sampling from a simple model which has both scalar and vector quantity random variables across
4 chains.

```julia
using Turing, MCMCChains, EpiAware

@model function testmodel()
x ~ MvNormal(2, 1.)
y ~ Normal()
end
mdl = testmodel()
chn = sample(mdl, Prior(), MCMCSerial(), 250, 4)

A = get_param_array(chn)
```
"""
function get_param_array(chn::Chains)
rowtable(chn) |> x -> reshape(x, size(chn, 1), size(chn, 3))
end
14 changes: 14 additions & 0 deletions EpiAware/test/EpiAwareUtils/get_param_array.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
@testitem "test unpacking a Chains object" begin
using Turing, Distributions

@model function testmodel()
x ~ MvNormal(2, 1.0)
y ~ Normal()
end
mdl = testmodel()
chn = sample(mdl, Prior(), MCMCSerial(), 250, 4; progress = false)
A = get_param_array(chn)

@test size(A) == (size(chn)[1], size(chn)[3])
@test eltype(A) <: NamedTuple
end
Loading