Skip to content

Commit

Permalink
add more flexibility with delays (#305)
Browse files Browse the repository at this point in the history
  • Loading branch information
sbfnk authored Nov 18, 2022
1 parent d348b9d commit fe7675f
Show file tree
Hide file tree
Showing 54 changed files with 1,025 additions and 429 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,6 @@ Encoding: UTF-8
Language: en-GB
LazyData: true
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.1
RoxygenNote: 7.2.2
SystemRequirements: GNU make
VignetteBuilder: knitr
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ export(create_rt_data)
export(create_shifted_cases)
export(create_stan_args)
export(create_stan_data)
export(delay_dist)
export(delay_opts)
export(dist_fit)
export(dist_skel)
Expand All @@ -44,6 +45,7 @@ export(extract_inits)
export(extract_stan_param)
export(forecast_secondary)
export(gamma_dist_def)
export(generation_time_opts)
export(get_dist)
export(get_generation_time)
export(get_incubation_period)
Expand Down
14 changes: 14 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
# EpiNow2 1.3.3.9000

## New features

* Adds support for fixed delays (mean only or fixed lognormal distributed) or truncations (fixed lognormal distributed), and for pre-computing these delays as well as generation times if they are fixed. By @sbfnk and @seabbs.

## Model changes

## Documentation

* Updated examples to make use of fixed distributions to improve run-times where appropriate.

## Deprecated features

# EpiNow2 1.3.3

This release adds a range of new minor features, squashes bugs, enhances documentation, expands unit testing, implements some minor run-time optimisations, and removes some obsolete features.
Expand Down
77 changes: 35 additions & 42 deletions R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -395,33 +395,13 @@ create_stan_data <- function(reported_cases, generation_time,
backcalc, shifted_cases,
truncation) {

## make sure we have at least max_gt seeding time
## make sure we have at least gt_max seeding time
delays$seeding_time <- max(delays$seeding_time, generation_time$max)

## complete generation time parameters if not all are given
if (is.null(generation_time)) {
generation_time <- list(mean = 1)
}
for (param in c("mean_sd", "sd", "sd_sd")) {
if (!(param %in% names(generation_time))) generation_time[[param]] <- 0
}
## check if generation time is fixed
if (generation_time$sd == 0 && generation_time$sd_sd == 0) {
if ("max_gt" %in% names(generation_time)) {
if (generation_time$max_gt != generation_time$mean) {
stop("Error in generation time defintion: if max_gt(",
generation_time$max_gt,
") is given it must be equal to the mean (",
generation_time$mean,
")")
}
} else {
generation_time$max_gt <- generation_time$mean
}
if (any(generation_time$mean_sd > 0, generation_time$sd_sd > 0)) {
stop("Error in generation time definition: if sd_mean is 0 and ",
"sd_sd is 0 then mean_sd must be 0, too.")
}
## for backwards compatibility call generation_time_opts internally
if (is.list(generation_time) &&
all(c("mean", "mean_sd", "sd", "sd_sd") %in% names(generation_time))) {
generation_time <- do.call(generation_time_opts, generation_time)
}

cases <- reported_cases[(delays$seeding_time + 1):(.N - horizon)]$confirm
Expand All @@ -431,13 +411,10 @@ create_stan_data <- function(reported_cases, generation_time,
shifted_cases = shifted_cases,
t = length(reported_cases$date),
horizon = horizon,
gt_mean_mean = generation_time$mean,
gt_mean_sd = generation_time$mean_sd,
gt_sd_mean = generation_time$sd,
gt_sd_sd = generation_time$sd_sd,
max_gt = generation_time$max,
burn_in = 0
)
# add gt data
data <- c(data, generation_time)
# add delay data
data <- c(data, delays)
# add truncation data
Expand All @@ -459,6 +436,10 @@ create_stan_data <- function(reported_cases, generation_time,
data$prior_infections <- ifelse(is.na(data$prior_infections) | is.null(data$prior_infections),
0, data$prior_infections
)
if (is.null(data$gt_weight)) {
## default: weigh by number of data points
data$gt_weight <- data$t - data$seeding_time - data$horizon
}
if (data$seeding_time > 1) {
safe_lm <- purrr::safely(stats::lm)
data$prior_growth <- safe_lm(log(confirm) ~ t, data = first_week)[[1]]
Expand Down Expand Up @@ -503,26 +484,35 @@ create_stan_data <- function(reported_cases, generation_time,
create_initial_conditions <- function(data) {
init_fun <- function() {
out <- list()
if (data$delays > 0) {
if (data$n_uncertain_mean_delays > 0) {
out$delay_mean <- array(purrr::map2_dbl(
data$delay_mean_mean, data$delay_mean_sd * 0.1,
data$delay_mean_mean[data$uncertain_mean_delays],
data$delay_mean_sd[data$uncertain_mean_delays] * 0.1,
~ rnorm(1, mean = .x, sd = .y)
))
}
if (data$n_uncertain_sd_delays > 0) {
out$delay_sd <- array(purrr::map2_dbl(
data$delay_sd_mean, data$delay_sd_sd * 0.1,
data$delay_sd_mean[data$uncertain_sd_delays],
data$delay_sd_sd[data$uncertain_sd_delays] * 0.1,
~ rnorm(1, mean = .x, sd = .y)
))
}
if (data$truncation > 0) {
out$truncation_mean <- array(rnorm(1,
mean = data$trunc_mean_mean,
sd = data$trunc_mean_sd * 0.1
))
out$truncation_sd <- array(truncnorm::rtruncnorm(1,
a = 0,
mean = data$trunc_sd_mean,
sd = data$trunc_sd_sd * 0.1
))
if (data$trunc_mean_sd > 0) {
out$truncation_mean <- array(rnorm(1,
mean = data$trunc_mean_mean,
sd = data$trunc_mean_sd * 0.1
))
}
if (data$trunc_sd_sd > 0) {
out$truncation_sd <- array(
truncnorm::rtruncnorm(1,
a = 0,
mean = data$trunc_sd_mean,
sd = data$trunc_sd_sd * 0.1
))
}
}
if (data$fixed == 0) {
out$eta <- array(rnorm(data$M, mean = 0, sd = 0.1))
Expand Down Expand Up @@ -578,6 +568,9 @@ create_initial_conditions <- function(data) {
sd = data$obs_scale_sd * 0.1
))
}
if (data$week_effect > 0) {
out$day_of_week_simplex = array(rep(1 / data$week_effect, data$week_effect))
}
return(out)
}
return(init_fun)
Expand Down
80 changes: 80 additions & 0 deletions R/dist.R
Original file line number Diff line number Diff line change
Expand Up @@ -710,3 +710,83 @@ tune_inv_gamma <- function(lower = 2, upper = 21) {
)
return(out)
}

##' Delay distribution.
##'
##' @description `r lifecycle::badge("stable")`
##' Defines the parameters of a delay distribution
##' @param mean Numeric. If the only non-zero summary parameter
##' then this is the fixed interval of the delay distribution. If the `sd` is
##' non-zero then this is the mean of the distribution given by \code{dist}.
##' If this is not given a vector of empty vectors is returned.
##' @param sd Numeric, defaults to 0. Sets the standard deviation of the delay
##' distribution.
##' @param mean_sd Numeric, defaults to 0. Sets the standard deviation of the
##' uncertainty around the mean of the delay distribution.
##' @param sd_sd Numeric, defaults to 0. Sets the standard deviation of the
##' uncertainty around the sd of the delay distribution.
##' @param dist Character, defaults to "lognormal". The (discretised) distribution
##' to be used. If sd == 0 then the delay is fixed and a delta function will be
##' used whatever the choice here.
##' @param max Numeric, maximum value of the delay distribution
##' @param fixed Logical, defaults to `FALSE`. Should delays be treated
##' as coming from fixed (vs uncertain) distributions. Making this simplification
##' drastically reduces compute requirements.
##' @return A list of delay distribution options to be used downstream
##' @author Sebastian Funk
##' @export
delay_dist <- function(mean, sd = 0, mean_sd = 0, sd_sd = 0,
dist = c("lognormal", "gamma"), max = NULL,
fixed = FALSE) {
dist <- match.arg(dist)

if (missing(mean)) {
ret <- list(
mean_mean = numeric(0),
mean_sd = numeric(0),
sd_mean = numeric(0),
sd_sd = numeric(0),
fixed = integer(0),
dist = integer(0)
)
if (is.null(max)) {
ret$max <- integer(0)
} else {
ret$max <- max
}
} else {
ret <- list(
mean_mean = mean,
mean_sd = mean_sd,
sd_mean = sd,
sd_sd = sd_sd
)
if (fixed) {
ret$mean_sd <- 0
ret$sd_sd <- 0
}
ret$fixed <- as.integer(ret$mean_sd == 0 && ret$mean_sd == 0)

## check if it's a fixed value
if (ret$sd_mean == 0 && ret$sd_sd == 0) {
if (ret$mean_mean %% 1 != 0) {
stop(
"When a delay distribution is set to a constant ",
"(sd == 0 and sd_sd == 0) then the mean parameter ",
"must be an integer."
)
}
ret$max <- ret$mean_mean
if (ret$mean_sd > 0) {
stop(
"When a delay distribution has sd == 0 and ",
"sd_sd == 0 then mean_sd must be 0, too."
)
}
} else {
ret$max <- max
}
ret$dist <- which(eval(formals()[["dist"]]) == dist) - 1
}
return(lapply(ret, array))
}
2 changes: 1 addition & 1 deletion R/epinow.R
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
#'
#' # summary of R estimates
#' summary(out, type = "parameters", params = "R")
#'
#'
#' options(old_opts)
#' }
epinow <- function(reported_cases,
Expand Down
Loading

0 comments on commit fe7675f

Please sign in to comment.