-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
rolling out a pipeline type system for dispatching on analysis details (
#224) * move default constructors into own folder * pipeline types and a new default constructor + tests * More default constructors * reduce to one run_pipeline * Update generate_inference_results.jl * pipeline functions * Update make_truthdata.jl * Constructor functions + tests refactored into folders and added dispatch on pipeline type * rename pipeline functions to `do_...`, refactor in folders * simulation of truth data functions/tests refactored into folders and now dispatch on pipeline type * inference functions/tests refactored into folders and now dispatch on pipeline type * Add a rmprocs at end to clear worker processes * Rewrite toy end-to-end test of inference for generated data * Update runtests.jl * Update AnalysisPipeline.jl * fix style of AnalysisPipeline to be consistent * fix using removed type * move defaults into make_ functions with a dispatch on pipeline type * fix use of default * plot function dispatch on pipeline type * Change generate inference so that name of latent model and model itself are passed as a Pair * refactor unit tests of default functions into unit tests on constructors * Update runtests.jl * Update test_full_inference.jl * remove default methods * abstract the map on inference results to another function * Update plot_functions.jl * catch failure to pass truthdata and reformat
- Loading branch information
1 parent
576e0ee
commit c980e22
Showing
43 changed files
with
480 additions
and
400 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# Local environment script to run the analysis pipeline | ||
using Pkg | ||
Pkg.activate(joinpath(@__DIR__(), "..")) | ||
using Dagger | ||
|
||
@info(""" | ||
Running the analysis pipeline. | ||
-------------------------------------------- | ||
""") | ||
|
||
# Define the backend resources to use for the pipeline | ||
# in this case we are using distributed local workers with loaded modules | ||
using Distributed | ||
pids = addprocs() | ||
|
||
@everywhere include("../src/AnalysisPipeline.jl") | ||
@everywhere using .AnalysisPipeline | ||
|
||
# Create an instance of the pipeline behaviour | ||
pipeline = RtwithoutRenewalPipeline() | ||
|
||
# Run the pipeline | ||
do_pipeline(pipeline) | ||
|
||
# Remove the workers | ||
rmprocs(pids) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,34 +1,43 @@ | ||
""" | ||
This module contains the analysis pipeline for the `Rt-without-renewal` project. | ||
# Pipeline Components | ||
In this module the meaning of a _pipeline component_ is a directed-acylic-graph | ||
(DAG) of tasks defined using `Dagger.jl` via dispatch on an `AbstractEpiAwarePipeline` | ||
sub-type from a function with prefix `do_`. A full pipeline is a sequence of DAGs, | ||
with execution determined by available computational resources. | ||
""" | ||
module AnalysisPipeline | ||
|
||
using Dates: default | ||
using CSV, Dagger, DataFramesMeta, Dates, Distributions, DocStringExtensions, DrWatson, | ||
EpiAware, Plots, Statistics, ADTypes, AbstractMCMC, Plots, JLD2 | ||
|
||
# Exported struct types | ||
export TruthSimulationConfig, InferenceConfig | ||
export AbstractEpiAwarePipeline, EpiAwarePipeline, RtwithoutRenewalPipeline, | ||
TruthSimulationConfig, InferenceConfig | ||
|
||
# Exported functions: constructors | ||
export make_gi_params, make_inf_generating_processes, make_latent_model_priors, | ||
make_epiaware_name_model_pairs, make_Rt, make_truth_data_configs, | ||
make_inference_configs, make_tspan, make_inference_method | ||
|
||
# Exported functions: pipeline components | ||
export do_truthdata, do_inference, do_pipeline | ||
|
||
# Exported functions: simulate functions | ||
export simulate, generate_truthdata | ||
|
||
# Exported functions: infer functions | ||
export infer, generate_inference_results, map_inference_results | ||
|
||
# Exported functions | ||
export simulate, infer, default_gi_params, default_Rt, default_tspan, | ||
default_latent_model_priors, default_epiaware_models, default_inference_method, | ||
default_latent_models_names, make_truth_data_configs, make_inference_configs, | ||
generate_truthdata_from_config, generate_inference_results, plot_truth_data, plot_Rt | ||
# Exported functions: plot functions | ||
export plot_truth_data, plot_Rt | ||
|
||
include("docstrings.jl") | ||
include("default_gi_params.jl") | ||
include("default_Rt.jl") | ||
include("default_tspan.jl") | ||
include("default_latent_model_priors.jl") | ||
include("default_epiaware_models.jl") | ||
include("default_inference_method.jl") | ||
include("make_truth_data_configs.jl") | ||
include("make_inference_configs.jl") | ||
include("default_latent_models_names.jl") | ||
include("TruthSimulationConfig.jl") | ||
include("InferenceConfig.jl") | ||
include("generate_truthdata.jl") | ||
include("generate_inference_results.jl") | ||
include("pipeline/pipeline.jl") | ||
include("constructors/constructors.jl") | ||
include("simulate/simulate.jl") | ||
include("infer/infer.jl") | ||
include("plot_functions.jl") | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
include("make_gi_params.jl") | ||
include("make_inf_generating_processes.jl") | ||
include("make_latent_model_priors.jl") | ||
include("make_epiaware_name_model_pairs.jl") | ||
include("make_inference_method.jl") | ||
include("make_truth_data_configs.jl") | ||
include("make_inference_configs.jl") | ||
include("make_Rt.jl") | ||
include("make_tspan.jl") | ||
include("make_inference_method.jl") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
29 changes: 29 additions & 0 deletions
29
pipeline/src/constructors/make_epiaware_name_model_pairs.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
""" | ||
Constructs a dictionary of name-model pairs for the EpiAware pipeline. This is | ||
the default method. | ||
# Arguments | ||
- `pipeline::AbstractEpiaAwarePipeline`: The EpiAware pipeline object. | ||
# Returns | ||
A dictionary containing the name-model pairs. | ||
""" | ||
function make_epiaware_name_model_pairs(pipeline::AbstractEpiAwarePipeline) | ||
prior_dict = make_latent_model_priors(pipeline) | ||
|
||
ar = AR(damp_priors = [prior_dict["damp_param_prior"]], | ||
std_prior = prior_dict["std_prior"], | ||
init_priors = [prior_dict["transformed_process_init_prior"]]) | ||
|
||
rw = RandomWalk( | ||
std_prior = prior_dict["std_prior"], init_prior = prior_dict["transformed_process_init_prior"]) | ||
|
||
diff_ar = DiffLatentModel(; | ||
model = ar, init_priors = [prior_dict["transformed_process_init_prior"]]) | ||
|
||
wkly_ar, wkly_rw, wkly_diff_ar = [ar, rw, diff_ar] .|> | ||
model -> BroadcastLatentModel(model, 7, RepeatBlock()) | ||
|
||
return ["wkly_ar" => wkly_ar, "wkly_rw" => wkly_rw, "wkly_diff_ar" => wkly_diff_ar] | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
""" | ||
Constructs a dictionary of GI (Generation Interval) parameters. This is the | ||
default method. | ||
# Arguments | ||
- `pipeline`: An instance of the `AbstractEpiAwarePipeline` type. | ||
# Returns | ||
A dictionary containing the GI means and GI standard deviations. | ||
""" | ||
function make_gi_params(pipeline::AbstractEpiAwarePipeline) | ||
gi_means = [2.0, 10.0, 20.0] | ||
gi_stds = [2.0] | ||
return Dict("gi_means" => gi_means, "gi_stds" => gi_stds) | ||
end |
14 changes: 14 additions & 0 deletions
14
pipeline/src/constructors/make_inf_generating_processes.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
""" | ||
Constructs and returns a vector of infection-generating process types for the given | ||
pipeline. This is the default method. | ||
# Arguments | ||
- `pipeline`: An instance of `AbstractEpiAwarePipeline`. | ||
# Returns | ||
An array of infection-generating process types. | ||
""" | ||
function make_inf_generating_processes(pipeline::AbstractEpiAwarePipeline) | ||
return [DirectInfections, ExpGrowthRate, Renewal] | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
""" | ||
Create inference configurations for the given pipeline. This is the default method. | ||
# Arguments | ||
- `pipeline`: An instance of `AbstractEpiAwarePipeline`. | ||
# Returns | ||
- An object representing the inference configurations. | ||
""" | ||
function make_inference_configs(pipeline::AbstractEpiAwarePipeline) | ||
gi_param_dict = make_gi_params(pipeline) | ||
namemodel_vect = make_epiaware_name_model_pairs(pipeline) | ||
igps = make_inf_generating_processes(pipeline) | ||
|
||
inference_configs = Dict("igp" => igps, "latent_namemodels" => namemodel_vect, | ||
"gi_mean" => gi_param_dict["gi_means"], "gi_std" => gi_param_dict["gi_stds"]) |> | ||
dict_list | ||
|
||
return inference_configs | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
""" | ||
Constructs an inference method for the given pipeline. This is a default method. | ||
# Arguments | ||
- `pipeline`: An instance of `AbstractEpiAwarePipeline`. | ||
# Returns | ||
- An inference method. | ||
""" | ||
function make_inference_method(pipeline::AbstractEpiAwarePipeline; ndraws::Integer = 2000, | ||
mcmc_ensemble::AbstractMCMC.AbstractMCMCEnsemble = MCMCSerial(), | ||
nruns_pthf::Integer = 4, maxiters_pthf::Integer = 100, nchains::Integer = 4) | ||
return EpiMethod( | ||
pre_sampler_steps = [ManyPathfinder(nruns = nruns_pthf, maxiters = maxiters_pthf)], | ||
sampler = NUTSampler(adtype = AutoForwardDiff(), ndraws = ndraws, | ||
nchains = nchains, mcmc_parallel = mcmc_ensemble) | ||
) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
""" | ||
Constructs and returns a dictionary of prior distributions for the latent model | ||
parameters. This is the default method. | ||
# Arguments | ||
- `pipeline`: An instance of the `AbstractEpiAwarePipeline` type. | ||
# Returns | ||
A dictionary containing the following prior distributions: | ||
- `"transformed_process_init_prior"`: A normal distribution with mean 0.0 and | ||
standard deviation 0.25. | ||
- `"std_prior"`: A half-normal distribution with standard deviation 0.25. | ||
- `"damp_param_prior"`: A beta distribution with shape parameters 0.5 and 0.5. | ||
""" | ||
function make_latent_model_priors(pipeline::AbstractEpiAwarePipeline) | ||
transformed_process_init_prior = Normal(0.0, 0.25) | ||
std_prior = HalfNormal(0.25) | ||
damp_param_prior = Beta(0.5, 0.5) | ||
|
||
return Dict( | ||
"transformed_process_init_prior" => transformed_process_init_prior, | ||
"std_prior" => std_prior, | ||
"damp_param_prior" => damp_param_prior | ||
) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
""" | ||
Create a dictionary of truth data configurations for `pipeline <: AbstractEpiAwarePipeline`. | ||
This is the default method. | ||
# Returns | ||
A vector of dictionaries containing the mean and standard deviation values for | ||
the generation interval. | ||
""" | ||
function make_truth_data_configs(pipeline::AbstractEpiAwarePipeline) | ||
gi_param_dict = make_gi_params(pipeline) | ||
return Dict( | ||
"gi_mean" => gi_param_dict["gi_means"], "gi_std" => gi_param_dict["gi_stds"]) |> | ||
dict_list | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
""" | ||
Constructs the time span for the given `pipeline` object. | ||
# Arguments | ||
- `pipeline::AbstractEpiAwarePipeline`: The pipeline object for which the time | ||
span is constructed. This is the default method. | ||
# Returns | ||
- `tspan::Tuple{Float64, Float64}`: The time span as a tuple of start and end times. | ||
""" | ||
function make_tspan(pipeline::AbstractEpiAwarePipeline; backhorizon = 21) | ||
N = size(make_Rt(pipeline), 1) | ||
@assert backhorizon<N "Backhorizon must be less than the length of the default Rt." | ||
return (1, N - backhorizon) | ||
end |
Oops, something went wrong.