-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add get_param_array function * Update get_param_array.jl * Simpler utility using `rowtable`
- Loading branch information
1 parent
2d3b4c8
commit 919b8a8
Showing
4 changed files
with
51 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
""" | ||
Extract a parameter array from a `Chains` object `chn` that matches the shape of number of sample | ||
and chain pairs in `chn`. | ||
# Arguments | ||
- `chn::Chains`: The `Chains` object containing the MCMC samples. | ||
# Returns | ||
- `param_array`: An array of parameter samples, where each element corresponds to a single | ||
MCMC sample as a `NamedTuple`. | ||
# Example | ||
Sampling from a simple model which has both scalar and vector quantity random variables across | ||
4 chains. | ||
```julia | ||
using Turing, MCMCChains, EpiAware | ||
@model function testmodel() | ||
x ~ MvNormal(2, 1.) | ||
y ~ Normal() | ||
end | ||
mdl = testmodel() | ||
chn = sample(mdl, Prior(), MCMCSerial(), 250, 4) | ||
A = get_param_array(chn) | ||
``` | ||
""" | ||
function get_param_array(chn::Chains) | ||
rowtable(chn) |> x -> reshape(x, size(chn, 1), size(chn, 3)) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
@testitem "test unpacking a Chains object" begin | ||
using Turing, Distributions | ||
|
||
@model function testmodel() | ||
x ~ MvNormal(2, 1.0) | ||
y ~ Normal() | ||
end | ||
mdl = testmodel() | ||
chn = sample(mdl, Prior(), MCMCSerial(), 250, 4; progress = false) | ||
A = get_param_array(chn) | ||
|
||
@test size(A) == (size(chn)[1], size(chn)[3]) | ||
@test eltype(A) <: NamedTuple | ||
end |