From 0c33590400ac9831ebaed32b62baf0c2232d2e7c Mon Sep 17 00:00:00 2001 From: Sebastian Funk Date: Wed, 22 Feb 2023 18:48:09 +0000 Subject: [PATCH] scoringutils updates --- DESCRIPTION | 2 +- NAMESPACE | 8 +-- R/score.R | 82 +++------------------------ R/summarise.R | 4 +- man/evaluate_model.Rd | 3 - man/iterative_direct_case_forecast.Rd | 3 - man/iterative_rt_forecast.Rd | 3 - man/plot_scores.Rd | 3 - 8 files changed, 12 insertions(+), 96 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 85244e9..083a613 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -50,7 +50,7 @@ License: MIT + file LICENSE Encoding: UTF-8 LazyData: true Roxygen: list(markdown = TRUE) -RoxygenNote: 7.1.0 +RoxygenNote: 7.2.3 Depends: R (>= 3.3.0) Imports: diff --git a/NAMESPACE b/NAMESPACE index fad6be9..c0f7f41 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -84,13 +84,7 @@ importFrom(purrr,map_dfr) importFrom(purrr,safely) importFrom(purrr,transpose) importFrom(rlang,has_name) -importFrom(scoringRules,crps_sample) -importFrom(scoringRules,dss_sample) -importFrom(scoringRules,logs_sample) -importFrom(scoringutils,bias) -importFrom(scoringutils,interval_score) -importFrom(scoringutils,pit) -importFrom(scoringutils,sharpness) +importFrom(scoringutils,score) importFrom(stats,median) importFrom(stats,quantile) importFrom(stats,rpois) diff --git a/R/score.R b/R/score.R index 5fb4a76..50e9a08 100644 --- a/R/score.R +++ b/R/score.R @@ -11,8 +11,7 @@ #' @importFrom dplyr filter select select_if #' @importFrom tidyr spread #' @importFrom tibble tibble -#' @importFrom scoringRules dss_sample crps_sample logs_sample -#' @importFrom scoringutils bias sharpness pit interval_score +#' @importFrom scoringutils score #' @inheritParams summarise_forecast #' @examples #' \dontrun{ @@ -38,88 +37,23 @@ score_forecast <- function(fit_samples, observations, scores = "all") { date >= min(fit_samples$date), date <= max(fit_samples$date) ) + observations <- + dplyr::rename(observations, true_value = rt) fit_samples <- dplyr::filter(fit_samples, date >= min(observations$date), date <= max(observations$date) ) + fit_samples <- + dplyr::rename(fit_samples, prediction = rt) + combined <- + dplyr::inner_join(observations, fit_samples, by = "date", multiple = "all") - obs <- observations$rt - - samples_matrix <- - tidyr::spread(fit_samples, key = "sample", value = "rt") %>% - dplyr::select(-horizon, -date) %>% - as.matrix - - data_length <- length(observations$date) - - ##Define interval_score - interval_score <- function(lower, upper, range) { - suppressMessages( - suppressWarnings(scoringutils::interval_score(true_values = obs, - lower = apply(samples_matrix, 1, - quantile, probs = lower), - upper = apply(samples_matrix, 1, - quantile, probs = upper), - interval_range = range)) - ) - } - - scores <- tibble::tibble( - date = observations$date, - horizon = 1:data_length, - dss = if(any(c("all", "dss") %in% scores)) { - scoringRules::dss_sample(y = obs, dat = samples_matrix) - }else{ - NA - }, - crps = if(any(c("all", "crps") %in% scores)) { - scoringRules::crps_sample(y = obs, dat = samples_matrix) - }else{ - NA - }, - logs = if(any(c("all", "logs") %in% scores)) { - scoringRules::logs_sample(y = obs, dat = samples_matrix) - }else{ - NA - }, - bias = if(any(c("all", "bias") %in% scores)) { - suppressWarnings(scoringutils::bias(obs, samples_matrix)) - }else{ - NA - }, - sharpness = if(any(c("all", "sharpness") %in% scores)) { - suppressWarnings(scoringutils::sharpness(samples_matrix)) - }else{ - NA - }, - calibration = if(any(c("all", "calibration") %in% scores)) { - suppressWarnings(scoringutils::pit(obs, samples_matrix)$p_value) - }else{ - NA - }, - median = if(any(c("all", "median") %in% scores)) { - interval_score(0.5, 0.5, 0) - }else{ - NA - }, - iqr = if(any(c("all", "iqr") %in% scores)) { - interval_score(0.25, 0.75, 50) - }else{ - NA - }, - ci = if(any(c("all", "ci") %in% scores)) { - interval_score(0.025, 0.975, 95) - }else{ - NA - } - ) - + scores <- scoringutils::score(combined) scores <- dplyr::select_if(scores, ~ any(!is.na(.))) - return(scores) } diff --git a/R/summarise.R b/R/summarise.R index 0d49d09..4f3bbda 100644 --- a/R/summarise.R +++ b/R/summarise.R @@ -142,8 +142,8 @@ summarise_scores <- function(scores, variables = NULL, sel_scores = NULL) { } summarised_scores <- scores %>% - tidyr::gather(key = "score", value = "value", dss, crps, logs, bias, - sharpness, calibration, median, iqr, ci) + tidyr::gather(key = "score", value = "value", mad, bias, dss, crps, + ae_median, se_mean) if (!is.null(sel_scores)) { diff --git a/man/evaluate_model.Rd b/man/evaluate_model.Rd index 1d52a92..86ea10d 100644 --- a/man/evaluate_model.Rd +++ b/man/evaluate_model.Rd @@ -50,9 +50,6 @@ arguments with the first specfying the number of samples and the second the mean to \code{rpois} if not supplied} \item{return_raw}{Logical, should raw cases and rt forecasts be returned. Defaults to \code{FALSE}.} -} -\value{ - } \description{ Evaluate a Model for Forecasting Rts diff --git a/man/iterative_direct_case_forecast.Rd b/man/iterative_direct_case_forecast.Rd index a76f229..4edd118 100644 --- a/man/iterative_direct_case_forecast.Rd +++ b/man/iterative_direct_case_forecast.Rd @@ -31,9 +31,6 @@ equal to 0.} \item{min_points}{Numeric, defaults to 3. The minimum number of time points at which to begin iteratively evaluating the forecast.} -} -\value{ - } \description{ Iteratively forecast directly on cases diff --git a/man/iterative_rt_forecast.Rd b/man/iterative_rt_forecast.Rd index 5ccffc5..df0165d 100644 --- a/man/iterative_rt_forecast.Rd +++ b/man/iterative_rt_forecast.Rd @@ -32,9 +32,6 @@ equal to 0.} \item{min_points}{Numeric, defaults to 3. The minimum number of time points at which to begin iteratively evaluating the forecast.} -} -\value{ - } \description{ Iteratively Forecast diff --git a/man/plot_scores.Rd b/man/plot_scores.Rd index 841672f..078f15e 100644 --- a/man/plot_scores.Rd +++ b/man/plot_scores.Rd @@ -12,6 +12,3 @@ A dataframe of summarised scores in a tidy format. \description{ Plot forecast scores } -\examples{ - -}