Skip to content

Commit

Permalink
Merge pull request #78 from epiforecasts/update-scoringutils
Browse files Browse the repository at this point in the history
  • Loading branch information
sbfnk authored Feb 22, 2023
2 parents 35d0c96 + 0c33590 commit f093bb0
Show file tree
Hide file tree
Showing 8 changed files with 12 additions and 96 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ License: MIT + file LICENSE
Encoding: UTF-8
LazyData: true
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.1.0
RoxygenNote: 7.2.3
Depends:
R (>= 3.3.0)
Imports:
Expand Down
8 changes: 1 addition & 7 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,7 @@ importFrom(purrr,map_dfr)
importFrom(purrr,safely)
importFrom(purrr,transpose)
importFrom(rlang,has_name)
importFrom(scoringRules,crps_sample)
importFrom(scoringRules,dss_sample)
importFrom(scoringRules,logs_sample)
importFrom(scoringutils,bias)
importFrom(scoringutils,interval_score)
importFrom(scoringutils,pit)
importFrom(scoringutils,sharpness)
importFrom(scoringutils,score)
importFrom(stats,median)
importFrom(stats,quantile)
importFrom(stats,rpois)
Expand Down
82 changes: 8 additions & 74 deletions R/score.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
#' @importFrom dplyr filter select select_if
#' @importFrom tidyr spread
#' @importFrom tibble tibble
#' @importFrom scoringRules dss_sample crps_sample logs_sample
#' @importFrom scoringutils bias sharpness pit interval_score
#' @importFrom scoringutils score
#' @inheritParams summarise_forecast
#' @examples
#' \dontrun{
Expand All @@ -38,88 +37,23 @@ score_forecast <- function(fit_samples, observations, scores = "all") {
date >= min(fit_samples$date),
date <= max(fit_samples$date)
)
observations <-
dplyr::rename(observations, true_value = rt)

fit_samples <-
dplyr::filter(fit_samples,
date >= min(observations$date),
date <= max(observations$date)
)
fit_samples <-
dplyr::rename(fit_samples, prediction = rt)

combined <-
dplyr::inner_join(observations, fit_samples, by = "date", multiple = "all")

obs <- observations$rt

samples_matrix <-
tidyr::spread(fit_samples, key = "sample", value = "rt") %>%
dplyr::select(-horizon, -date) %>%
as.matrix

data_length <- length(observations$date)

##Define interval_score
interval_score <- function(lower, upper, range) {
suppressMessages(
suppressWarnings(scoringutils::interval_score(true_values = obs,
lower = apply(samples_matrix, 1,
quantile, probs = lower),
upper = apply(samples_matrix, 1,
quantile, probs = upper),
interval_range = range))
)
}

scores <- tibble::tibble(
date = observations$date,
horizon = 1:data_length,
dss = if(any(c("all", "dss") %in% scores)) {
scoringRules::dss_sample(y = obs, dat = samples_matrix)
}else{
NA
},
crps = if(any(c("all", "crps") %in% scores)) {
scoringRules::crps_sample(y = obs, dat = samples_matrix)
}else{
NA
},
logs = if(any(c("all", "logs") %in% scores)) {
scoringRules::logs_sample(y = obs, dat = samples_matrix)
}else{
NA
},
bias = if(any(c("all", "bias") %in% scores)) {
suppressWarnings(scoringutils::bias(obs, samples_matrix))
}else{
NA
},
sharpness = if(any(c("all", "sharpness") %in% scores)) {
suppressWarnings(scoringutils::sharpness(samples_matrix))
}else{
NA
},
calibration = if(any(c("all", "calibration") %in% scores)) {
suppressWarnings(scoringutils::pit(obs, samples_matrix)$p_value)
}else{
NA
},
median = if(any(c("all", "median") %in% scores)) {
interval_score(0.5, 0.5, 0)
}else{
NA
},
iqr = if(any(c("all", "iqr") %in% scores)) {
interval_score(0.25, 0.75, 50)
}else{
NA
},
ci = if(any(c("all", "ci") %in% scores)) {
interval_score(0.025, 0.975, 95)
}else{
NA
}
)

scores <- scoringutils::score(combined)
scores <- dplyr::select_if(scores, ~ any(!is.na(.)))


return(scores)
}

Expand Down
4 changes: 2 additions & 2 deletions R/summarise.R
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,8 @@ summarise_scores <- function(scores, variables = NULL, sel_scores = NULL) {
}

summarised_scores <- scores %>%
tidyr::gather(key = "score", value = "value", dss, crps, logs, bias,
sharpness, calibration, median, iqr, ci)
tidyr::gather(key = "score", value = "value", mad, bias, dss, crps,
ae_median, se_mean)


if (!is.null(sel_scores)) {
Expand Down
3 changes: 0 additions & 3 deletions man/evaluate_model.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 0 additions & 3 deletions man/iterative_direct_case_forecast.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 0 additions & 3 deletions man/iterative_rt_forecast.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 0 additions & 3 deletions man/plot_scores.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit f093bb0

Please sign in to comment.