Skip to content

Commit

Permalink
rolling out a pipeline type system for dispatching on analysis details (
Browse files Browse the repository at this point in the history
#224)

* move default constructors into own folder

* pipeline types and a new default constructor + tests

* More default constructors

* reduce to one run_pipeline

* Update generate_inference_results.jl

* pipeline functions

* Update make_truthdata.jl

* Constructor functions + tests refactored into folders and added dispatch on pipeline type

* rename pipeline functions to `do_...`, refactor in folders

* simulation of truth data functions/tests refactored into folders and now dispatch on pipeline type

* inference functions/tests refactored into folders and now dispatch on pipeline type

* Add a rmprocs at end to clear worker processes

* Rewrite toy end-to-end test of inference for generated data

* Update runtests.jl

* Update AnalysisPipeline.jl

* fix style of AnalysisPipeline to be consistent

* fix using removed type

* move defaults into make_ functions with a dispatch on pipeline type

* fix use of default

* plot function dispatch on pipeline type

* Change generate inference so that name of latent model and model itself are passed as a Pair

* refactor unit tests of default functions into unit tests on constructors

* Update runtests.jl

* Update test_full_inference.jl

* remove default methods

* abstract the map on inference results to another function

* Update plot_functions.jl

* catch failure to pass truthdata and reformat
  • Loading branch information
SamuelBrand1 authored May 17, 2024
1 parent 576e0ee commit c980e22
Show file tree
Hide file tree
Showing 43 changed files with 480 additions and 400 deletions.
84 changes: 0 additions & 84 deletions pipeline/scripts/analysis_pipeline.jl

This file was deleted.

26 changes: 26 additions & 0 deletions pipeline/scripts/run_pipeline.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Local environment script to run the analysis pipeline
using Pkg
Pkg.activate(joinpath(@__DIR__(), ".."))
using Dagger

@info("""
Running the analysis pipeline.
--------------------------------------------
""")

# Define the backend resources to use for the pipeline
# in this case we are using distributed local workers with loaded modules
using Distributed
pids = addprocs()

@everywhere include("../src/AnalysisPipeline.jl")
@everywhere using .AnalysisPipeline

# Create an instance of the pipeline behaviour
pipeline = RtwithoutRenewalPipeline()

# Run the pipeline
do_pipeline(pipeline)

# Remove the workers
rmprocs(pids)
49 changes: 29 additions & 20 deletions pipeline/src/AnalysisPipeline.jl
Original file line number Diff line number Diff line change
@@ -1,34 +1,43 @@
"""
This module contains the analysis pipeline for the `Rt-without-renewal` project.
# Pipeline Components
In this module the meaning of a _pipeline component_ is a directed-acylic-graph
(DAG) of tasks defined using `Dagger.jl` via dispatch on an `AbstractEpiAwarePipeline`
sub-type from a function with prefix `do_`. A full pipeline is a sequence of DAGs,
with execution determined by available computational resources.
"""
module AnalysisPipeline

using Dates: default
using CSV, Dagger, DataFramesMeta, Dates, Distributions, DocStringExtensions, DrWatson,
EpiAware, Plots, Statistics, ADTypes, AbstractMCMC, Plots, JLD2

# Exported struct types
export TruthSimulationConfig, InferenceConfig
export AbstractEpiAwarePipeline, EpiAwarePipeline, RtwithoutRenewalPipeline,
TruthSimulationConfig, InferenceConfig

# Exported functions: constructors
export make_gi_params, make_inf_generating_processes, make_latent_model_priors,
make_epiaware_name_model_pairs, make_Rt, make_truth_data_configs,
make_inference_configs, make_tspan, make_inference_method

# Exported functions: pipeline components
export do_truthdata, do_inference, do_pipeline

# Exported functions: simulate functions
export simulate, generate_truthdata

# Exported functions: infer functions
export infer, generate_inference_results, map_inference_results

# Exported functions
export simulate, infer, default_gi_params, default_Rt, default_tspan,
default_latent_model_priors, default_epiaware_models, default_inference_method,
default_latent_models_names, make_truth_data_configs, make_inference_configs,
generate_truthdata_from_config, generate_inference_results, plot_truth_data, plot_Rt
# Exported functions: plot functions
export plot_truth_data, plot_Rt

include("docstrings.jl")
include("default_gi_params.jl")
include("default_Rt.jl")
include("default_tspan.jl")
include("default_latent_model_priors.jl")
include("default_epiaware_models.jl")
include("default_inference_method.jl")
include("make_truth_data_configs.jl")
include("make_inference_configs.jl")
include("default_latent_models_names.jl")
include("TruthSimulationConfig.jl")
include("InferenceConfig.jl")
include("generate_truthdata.jl")
include("generate_inference_results.jl")
include("pipeline/pipeline.jl")
include("constructors/constructors.jl")
include("simulate/simulate.jl")
include("infer/infer.jl")
include("plot_functions.jl")
end
10 changes: 10 additions & 0 deletions pipeline/src/constructors/constructors.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
include("make_gi_params.jl")
include("make_inf_generating_processes.jl")
include("make_latent_model_priors.jl")
include("make_epiaware_name_model_pairs.jl")
include("make_inference_method.jl")
include("make_truth_data_configs.jl")
include("make_inference_configs.jl")
include("make_Rt.jl")
include("make_tspan.jl")
include("make_inference_method.jl")
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""
Compute the default Rt values over time.
Compute the default Rt values over time for generating truth data. This is the
default method.
## keyword Arguments
- `A`: Amplitude of the sinusoidal variation in Rt. Default is 0.3.
Expand All @@ -9,7 +10,7 @@ Compute the default Rt values over time.
- `true_Rt`: Array of default Rt values over time.
"""
function default_Rt(; A = 0.3, P = 30.0)
function make_Rt(pipeline::AbstractEpiAwarePipeline; A = 0.3, P = 30.0)
ϕ = asin(-0.1 / 0.3) * P / (2 * π)
N = 160
true_Rt = vcat(fill(1.1, 2 * 7), fill(2.0, 2 * 7), fill(0.5, 2 * 7),
Expand Down
29 changes: 29 additions & 0 deletions pipeline/src/constructors/make_epiaware_name_model_pairs.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""
Constructs a dictionary of name-model pairs for the EpiAware pipeline. This is
the default method.
# Arguments
- `pipeline::AbstractEpiaAwarePipeline`: The EpiAware pipeline object.
# Returns
A dictionary containing the name-model pairs.
"""
function make_epiaware_name_model_pairs(pipeline::AbstractEpiAwarePipeline)
prior_dict = make_latent_model_priors(pipeline)

ar = AR(damp_priors = [prior_dict["damp_param_prior"]],
std_prior = prior_dict["std_prior"],
init_priors = [prior_dict["transformed_process_init_prior"]])

rw = RandomWalk(
std_prior = prior_dict["std_prior"], init_prior = prior_dict["transformed_process_init_prior"])

diff_ar = DiffLatentModel(;
model = ar, init_priors = [prior_dict["transformed_process_init_prior"]])

wkly_ar, wkly_rw, wkly_diff_ar = [ar, rw, diff_ar] .|>
model -> BroadcastLatentModel(model, 7, RepeatBlock())

return ["wkly_ar" => wkly_ar, "wkly_rw" => wkly_rw, "wkly_diff_ar" => wkly_diff_ar]
end
16 changes: 16 additions & 0 deletions pipeline/src/constructors/make_gi_params.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""
Constructs a dictionary of GI (Generation Interval) parameters. This is the
default method.
# Arguments
- `pipeline`: An instance of the `AbstractEpiAwarePipeline` type.
# Returns
A dictionary containing the GI means and GI standard deviations.
"""
function make_gi_params(pipeline::AbstractEpiAwarePipeline)
gi_means = [2.0, 10.0, 20.0]
gi_stds = [2.0]
return Dict("gi_means" => gi_means, "gi_stds" => gi_stds)
end
14 changes: 14 additions & 0 deletions pipeline/src/constructors/make_inf_generating_processes.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""
Constructs and returns a vector of infection-generating process types for the given
pipeline. This is the default method.
# Arguments
- `pipeline`: An instance of `AbstractEpiAwarePipeline`.
# Returns
An array of infection-generating process types.
"""
function make_inf_generating_processes(pipeline::AbstractEpiAwarePipeline)
return [DirectInfections, ExpGrowthRate, Renewal]
end
21 changes: 21 additions & 0 deletions pipeline/src/constructors/make_inference_configs.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""
Create inference configurations for the given pipeline. This is the default method.
# Arguments
- `pipeline`: An instance of `AbstractEpiAwarePipeline`.
# Returns
- An object representing the inference configurations.
"""
function make_inference_configs(pipeline::AbstractEpiAwarePipeline)
gi_param_dict = make_gi_params(pipeline)
namemodel_vect = make_epiaware_name_model_pairs(pipeline)
igps = make_inf_generating_processes(pipeline)

inference_configs = Dict("igp" => igps, "latent_namemodels" => namemodel_vect,
"gi_mean" => gi_param_dict["gi_means"], "gi_std" => gi_param_dict["gi_stds"]) |>
dict_list

return inference_configs
end
19 changes: 19 additions & 0 deletions pipeline/src/constructors/make_inference_method.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""
Constructs an inference method for the given pipeline. This is a default method.
# Arguments
- `pipeline`: An instance of `AbstractEpiAwarePipeline`.
# Returns
- An inference method.
"""
function make_inference_method(pipeline::AbstractEpiAwarePipeline; ndraws::Integer = 2000,
mcmc_ensemble::AbstractMCMC.AbstractMCMCEnsemble = MCMCSerial(),
nruns_pthf::Integer = 4, maxiters_pthf::Integer = 100, nchains::Integer = 4)
return EpiMethod(
pre_sampler_steps = [ManyPathfinder(nruns = nruns_pthf, maxiters = maxiters_pthf)],
sampler = NUTSampler(adtype = AutoForwardDiff(), ndraws = ndraws,
nchains = nchains, mcmc_parallel = mcmc_ensemble)
)
end
26 changes: 26 additions & 0 deletions pipeline/src/constructors/make_latent_model_priors.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""
Constructs and returns a dictionary of prior distributions for the latent model
parameters. This is the default method.
# Arguments
- `pipeline`: An instance of the `AbstractEpiAwarePipeline` type.
# Returns
A dictionary containing the following prior distributions:
- `"transformed_process_init_prior"`: A normal distribution with mean 0.0 and
standard deviation 0.25.
- `"std_prior"`: A half-normal distribution with standard deviation 0.25.
- `"damp_param_prior"`: A beta distribution with shape parameters 0.5 and 0.5.
"""
function make_latent_model_priors(pipeline::AbstractEpiAwarePipeline)
transformed_process_init_prior = Normal(0.0, 0.25)
std_prior = HalfNormal(0.25)
damp_param_prior = Beta(0.5, 0.5)

return Dict(
"transformed_process_init_prior" => transformed_process_init_prior,
"std_prior" => std_prior,
"damp_param_prior" => damp_param_prior
)
end
15 changes: 15 additions & 0 deletions pipeline/src/constructors/make_truth_data_configs.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""
Create a dictionary of truth data configurations for `pipeline <: AbstractEpiAwarePipeline`.
This is the default method.
# Returns
A vector of dictionaries containing the mean and standard deviation values for
the generation interval.
"""
function make_truth_data_configs(pipeline::AbstractEpiAwarePipeline)
gi_param_dict = make_gi_params(pipeline)
return Dict(
"gi_mean" => gi_param_dict["gi_means"], "gi_std" => gi_param_dict["gi_stds"]) |>
dict_list
end
16 changes: 16 additions & 0 deletions pipeline/src/constructors/make_tspan.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""
Constructs the time span for the given `pipeline` object.
# Arguments
- `pipeline::AbstractEpiAwarePipeline`: The pipeline object for which the time
span is constructed. This is the default method.
# Returns
- `tspan::Tuple{Float64, Float64}`: The time span as a tuple of start and end times.
"""
function make_tspan(pipeline::AbstractEpiAwarePipeline; backhorizon = 21)
N = size(make_Rt(pipeline), 1)
@assert backhorizon<N "Backhorizon must be less than the length of the default Rt."
return (1, N - backhorizon)
end
Loading

0 comments on commit c980e22

Please sign in to comment.