Skip to content

Commit

Permalink
fixes to small issues introduced by PR #305 (#336)
Browse files Browse the repository at this point in the history
  • Loading branch information
sbfnk authored Nov 24, 2022
1 parent 4568af6 commit 9d13484
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 23 deletions.
2 changes: 1 addition & 1 deletion R/estimate_truncation.R
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ estimate_truncation <- function(obs, max_truncation, trunc_max = 10,
out$last_obs <- last_obs
# summarise estimated cmf of the truncation distribution
out$cmf <- extract_stan_param(fit, "cmf", CrIs = CrIs)
out$cmf <- data.table::as.data.table(out$cmf)[, index := .N:1]
out$cmf <- data.table::as.data.table(out$cmf)[, index := seq_len(.N)]
data.table::setcolorder(out$cmf, "index")
out$data <- data
out$fit <- fit
Expand Down
22 changes: 9 additions & 13 deletions R/opts.R
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ delay_opts <- function(..., fixed = FALSE) {

names(data) <- paste0("delay_", names(data))
# Estimate the mean delay -----------------------------------------------
data$seeding_time <- sum(
purrr::map2_dbl(data$mean_mean, data$sd_mean, ~ exp(.x + .y^2 / 2))
data$seeding_time <- sum(purrr::map2_dbl(
data$delay_mean_mean, data$delay_sd_mean, ~ exp(.x + .y^2 / 2))
)
if (data$seeding_time < 1) {
data$seeding_time <- 1
Expand Down Expand Up @@ -162,10 +162,9 @@ delay_opts <- function(..., fixed = FALSE) {
#' Returns a truncation distribution formatted for usage by
#' downstream functions. See `estimate_truncation()` for an approach to
#' estimate this distribution.
#' @param ... Any parameters to be passed to [delay_dist]. If the `max` parameter
#' is not set but other distributional parameters given then the `max` will be
#' set to 15 to ensure backwards compatibility. Also if no `dist` parameter is given
#' then a gamma distribution will be used for backwards compatibility.
#' @param dist Parameters of a discretised (upper-)truncated lognormal
#' truncation distribution as a list with parameters that are all passed to
#' [delay_dist].
#' @seealso convert_to_logmean convert_to_logsd bootstrapped_dist_fit delay_dist
#' @return A list summarising the input truncation distribution.
#' @export
Expand All @@ -174,14 +173,11 @@ delay_opts <- function(..., fixed = FALSE) {
#' trunc_opts()
#'
#' # truncation dist
#' trunc_opts(mean = 3, sd = 2)
trunc_opts <- function(...) {
dot_options <- list(...) ## options for delay_dist
present <- (length(dot_options) > 0)
data <- do.call(delay_dist, dot_options)
#' trunc_opts(dist = list(mean = 3, sd = 2))
trunc_opts <- function(dist = list()) {
data <- do.call(delay_dist, dist)
names(data) <- paste0("trunc_", names(data))
data$truncation <- as.integer(present)

data$truncation <- as.integer(length(data$trunc_max) > 0)
return(data)
}

Expand Down
11 changes: 5 additions & 6 deletions man/trunc_opts.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 4 additions & 2 deletions tests/testthat/test-delays.R
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,10 @@ test_that("truncation parameters can be specified in different ways", {
c("trunc_mean_mean", "trunc_mean_sd", "trunc_sd_mean", "trunc_sd_sd",
"trunc_max")
expect_equal(
test_stan_data(truncation = trunc_opts(mean = 3, sd = 1, max = 5),
params = trunc_params),
test_stan_data(
truncation = trunc_opts(dist = list(mean = 3, sd = 1, max = 5)),
params = trunc_params
),
c(3, 0, 1, 0, 5)
)
})
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-simulate_infections.R
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ test_that("simulate infections fails as expected", {
})

test_that("simulate_infections works to simulate a passed in estimate_infections object with an adjusted Rt in data frame", {
R <- c(rep(NA_real_, 31), rep(0.5, 17))
R <- c(rep(NA_real_, 32), rep(0.5, 17))
R_dt <- data.frame(date = summary(out, type = "parameters", param = "R")$date, value = R)
sims_dt <- simulate_infections(out, R_dt)
expect_equal(names(sims_dt), c("samples", "summarised", "observations"))
Expand Down

0 comments on commit 9d13484

Please sign in to comment.