Skip to content

Commit

Permalink
fixes following more thorough testing
Browse files Browse the repository at this point in the history
  • Loading branch information
sbfnk committed Nov 1, 2022
1 parent 3f01a21 commit 7c6bfa3
Show file tree
Hide file tree
Showing 9 changed files with 66 additions and 33 deletions.
5 changes: 4 additions & 1 deletion R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,10 @@ create_stan_data <- function(reported_cases, generation_time,
delays$seeding_time <- max(delays$seeding_time, generation_time$max)

## for backwards compatibility call generation_time_opts internally
generation_time <- do.call(generation_time_opts, generation_time)
if (is.list(generation_time) &&
all(c("mean", "mean_sd", "sd", "sd_sd") %in% names(generation_time))) {
generation_time <- do.call(generation_time_opts, generation_time)
}

cases <- reported_cases[(delays$seeding_time + 1):(.N - horizon)]$confirm

Expand Down
10 changes: 9 additions & 1 deletion R/estimate_truncation.R
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,13 @@
#'
#' options(old_opts)
estimate_truncation <- function(obs, trunc_max = 10,
trunc_dist = c("lognormal", "gamma"),
model = NULL,
CrIs = c(0.2, 0.5, 0.9),
verbose = TRUE,
...) {
trunc_dist <- match.arg(trunc_dist)

# combine into ordered matrix
dirty_obs <- purrr::map(obs, data.table::as.data.table)
nrow_obs <- order(purrr::map_dbl(dirty_obs, nrow))
Expand All @@ -129,9 +132,14 @@ estimate_truncation <- function(obs, trunc_max = 10,
obs_dist = obs_dist,
t = nrow(obs_data),
obs_sets = ncol(obs_data),
trunc_max = trunc_max
trunc_max = trunc_max,
trunc_dist = trunc_dist
)

## convert to integer
data$trunc_dist <-
which(eval(formals()[["trunc_dist"]]) == trunc_dist) - 1

# initial conditions
init_fn <- function() {
data <- list(
Expand Down
17 changes: 10 additions & 7 deletions R/opts.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
#' 1 will be assumed, if the \code{max} parameter not set then the \code{max} will
#' be set to 15 to ensure backwards compatibility, and if no \code{dist} parameter
#' is given then a gamma distribution will be used for backwards compatibility.
#' @param fixed Logical, defaults to `FALSE`. Should the generation time be
#' treated as coming from fixed (vs uncertain) distributions.
#' @inheritParams get_generation_time
#' @seealso convert_to_logmean convert_to_logsd bootstrapped_dist_fit delay_dist
#' @return A list summarising the input delay distributions.
Expand All @@ -24,7 +26,7 @@
#'
#' # An uncertain gamma distributed generation time
#' generation_time_opts(mean = 3, sd = 2, mean_sd = 1, sd_sd = 0.5)
generation_time_opts <- function(..., disease, source) {
generation_time_opts <- function(..., disease, source, max = 15, fixed = FALSE) {
dot_options <- list(...) ## options for delay_dist
## check consistent options are given
type_options <- (length(dot_options) > 0) + ## distributional parameters
Expand All @@ -38,20 +40,21 @@ generation_time_opts <- function(..., disease, source) {
dist <- get_generation_time(
disease = disease, source = source, max_value = max
)
dist$fixed <- fixed
gt <- do.call(delay_dist, dist)
} else { ## generation time provided as distributional parameters or not at all
## make gamma default for backwards compatibility
if (!("dist" %in% names(dot_options))) {
dot_options$dist <- "gamma"
}
## set default of max=15 for backwards compatibility
if (!("max" %in% names(dot_options))) {
dot_options$max <- 15
}
## set max
dot_options$max <- max
## set default of mean=1 for backwards compatibility
if (!("mean" %in% names(dot_options))) {
dot_options$mean <- 1
}
gt <- do.call(delay_dist, dot_options)
dot_options$fixed <- fixed
gt <- do.call(delay_dist, dot_options)
}
names(gt) <- paste0("gt_", names(gt))

Expand Down Expand Up @@ -104,7 +107,7 @@ delay_opts <- function(..., fixed = FALSE) {

if (fixed) { ## set all to fixed
data <- lapply(data, function(x) {
x$fixed <- 1
x$fixed <- 1L
x$mean_sd <- 0
x$sd_sd <- 0
return(x)
Expand Down
12 changes: 7 additions & 5 deletions inst/stan/estimate_secondary.stan
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,20 @@ data {
}

transformed data {
int delay_max_fixed =
sum(delay_max[fixed_delays]) - num_elements(fixed_delays) + 1;
int delay_max_fixed = (n_fixed_delays == 0 ? 0 :
sum(delay_max[fixed_delays]) - num_elements(fixed_delays) + 1);
int delay_max_total = (delays == 0 ? 0 :
sum(delay_max) - num_elements(delay_max) + 1);
vector[truncation && trunc_fixed[1] ? trunc_max[1] : 0] trunc_fixed_pmf;
vector[delay_max_fixed] fixed_delays_pmf;
vector[delay_max_fixed] delays_fixed_pmf;

if (truncation && trunc_fixed[1]) {
trunc_fixed_pmf = discretised_pmf(
trunc_mean_mean[1], trunc_sd_mean[1], trunc_max[1], trunc_dist[1], 0
);
}
if (n_fixed_delays) {
fixed_delays_pmf = combine_pmfs(
delays_fixed_pmf = combine_pmfs(
to_vector([ 1 ]),
delay_mean_mean[fixed_delays],
delay_sd_mean[fixed_delays],
Expand All @@ -55,15 +55,17 @@ parameters{
transformed parameters {
vector<lower=0>[t] secondary;
// calculate secondary reports from primary

{
vector[delay_max_total] delay_pmf;
delay_pmf = combine_pmfs(
fixed_delays_pmf, delay_mean, delay_sd, delay_max, delay_dist, delay_max_total, 0
delays_fixed_pmf, delay_mean, delay_sd, delay_max, delay_dist, delay_max_total, 0
);
secondary = calculate_secondary(primary, obs, frac_obs, delay_pmf, cumulative,
historic, primary_hist_additive,
current, primary_current_additive, t);
}

// weekly reporting effect
if (week_effect > 1) {
secondary = day_of_week_effect(secondary, day_of_week, day_of_week_simplex);
Expand Down
8 changes: 5 additions & 3 deletions tests/testthat/_snaps/calc_CrI.md
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
# calc_CrI works as expected with default arguments

value CrI
<num> <char>
1: 1.45 lower_90
2: 9.55 upper_90

# calc_CrI works as expected when grouping

type value CrI
1: car 1.45 lower_90
2: car 9.55 upper_90
type value CrI
<char> <num> <char>
1: car 1.45 lower_90
2: car 9.55 upper_90

18 changes: 12 additions & 6 deletions tests/testthat/_snaps/calc_CrIs.md
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
# calc_CrI works as expected with default arguments

. lower_90 lower_50 lower_20 upper_20 upper_50 upper_90
1: . 1.45 3.25 4.6 6.4 7.75 9.55
Key: <.>
. lower_90 lower_50 lower_20 upper_20 upper_50 upper_90
<char> <num> <num> <num> <num> <num> <num>
1: . 1.45 3.25 4.6 6.4 7.75 9.55

# calc_CrI works as expected when grouping

type lower_90 lower_50 lower_20 upper_20 upper_50 upper_90
1: car 1.45 3.25 4.6 6.4 7.75 9.55
Key: <type>
type lower_90 lower_50 lower_20 upper_20 upper_50 upper_90
<char> <num> <num> <num> <num> <num> <num>
1: car 1.45 3.25 4.6 6.4 7.75 9.55

# calc_CrI works as expected when given a custom CrI list

. lower_95 lower_40 lower_10 upper_10 upper_40 upper_95
1: . 1.225 3.7 5.05 5.95 7.3 9.775
Key: <.>
. lower_95 lower_40 lower_10 upper_10 upper_40 upper_95
<char> <num> <num> <num> <num> <num> <num>
1: . 1.225 3.7 5.05 5.95 7.3 9.775

18 changes: 12 additions & 6 deletions tests/testthat/_snaps/calc_summary_measures.md
Original file line number Diff line number Diff line change
@@ -1,21 +1,27 @@
# calc_summary_measures works as expected with default arguments

type median mean sd lower_90 lower_50 lower_20 upper_20 upper_50
1: car 5.5 5.5 3.02765 1.45 3.25 4.6 6.4 7.75
type median mean sd lower_90 lower_50 lower_20 upper_20 upper_50
<char> <num> <num> <num> <num> <num> <num> <num> <num>
1: car 5.5 5.5 3.02765 1.45 3.25 4.6 6.4 7.75
upper_90
<num>
1: 9.55

# calc_CrI works as expected when grouping

type median mean sd lower_90 lower_50 lower_20 upper_20 upper_50
1: car 5.5 5.5 3.02765 1.45 3.25 4.6 6.4 7.75
type median mean sd lower_90 lower_50 lower_20 upper_20 upper_50
<char> <num> <num> <num> <num> <num> <num> <num> <num>
1: car 5.5 5.5 3.02765 1.45 3.25 4.6 6.4 7.75
upper_90
<num>
1: 9.55

# calc_CrI works as expected when given a custom CrI list

type median mean sd lower_95 lower_40 lower_10 upper_10 upper_40
1: car 5.5 5.5 3.02765 1.225 3.7 5.05 5.95 7.3
type median mean sd lower_95 lower_40 lower_10 upper_10 upper_40
<char> <num> <num> <num> <num> <num> <num> <num> <num>
1: car 5.5 5.5 3.02765 1.225 3.7 5.05 5.95 7.3
upper_95
<num>
1: 9.775

10 changes: 6 additions & 4 deletions tests/testthat/_snaps/calc_summary_stats.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
# calc_summary_stats works as expected with default arguments

median mean sd
1: 5.5 5.5 3.02765
median mean sd
<num> <num> <num>
1: 5.5 5.5 3.02765

# calc_summary_stats works as expected when grouping

type median mean sd
1: car 5.5 5.5 3.02765
type median mean sd
<char> <num> <num> <num>
1: car 5.5 5.5 3.02765

1 change: 1 addition & 0 deletions tests/testthat/test-estimate_secondary.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
skip_on_cran()
library(data.table)
library(purrr)

# make some example secondary incidence data
cases <- example_confirmed
Expand Down

0 comments on commit 7c6bfa3

Please sign in to comment.