feat: initial code to compare multiple dx objects
overdodactyl committed Jan 4, 2024
1 parent 7332fa6 commit 235ac97
Showing 20 changed files with 817 additions and 6 deletions.
5 changes: 5 additions & 0 deletions NAMESPACE
@@ -12,6 +12,8 @@ export(dx_brier)
export(dx_chi_square)
export(dx_cm)
export(dx_cohens_kappa)
export(dx_compare)
export(dx_delong)
export(dx_edit_cell)
export(dx_f1)
export(dx_f2)
@@ -28,6 +30,7 @@ export(dx_lrt_neg)
export(dx_lrt_pos)
export(dx_markedness)
export(dx_mcc)
export(dx_mcnemars)
export(dx_npv)
export(dx_odds_ratio)
export(dx_plot_calibration)
@@ -43,9 +46,11 @@ export(dx_plot_pr)
export(dx_plot_predictive_value)
export(dx_plot_probabilities)
export(dx_plot_roc)
export(dx_plot_rocs)
export(dx_plot_thresholds)
export(dx_plot_youden_j)
export(dx_ppv)
export(dx_prevalence)
export(dx_sensitivity)
export(dx_specificity)
export(dx_z_test)
8 changes: 8 additions & 0 deletions NEWS.md
@@ -27,6 +27,9 @@
* Added Chi-Square Test
* Added Fisher's Exact Test
* Added G-Test (Log-Likelihood Ratio Test)
* Added McNemar's Chi-squared Test for Paired Proportions
* Added Z-test for Comparing Two Proportions
* Added DeLong's Test for Comparing Two ROC Curves


**New output**
@@ -53,6 +56,7 @@
* Added Cumulative Accuracy Profile (CAP) curve: `dx_plot_cap`
* Added cost curve: `dx_plot_cost`
* Added plot showing metrics across thresholds: `dx_plot_thresholds`
* Added ability to plot multiple ROC curves: `dx_plot_rocs`

**Renamed Functions**

@@ -72,6 +76,10 @@
* Removed `DescTools` and `e1071` from Suggests
* Added `tibble` to Suggests

**Comparison of dx objects**

* New `dx_compare` function to run pairwise tests on a list of `dx` objects

**Documentation**

* Added a `NEWS.md` file to track changes to the package.
17 changes: 15 additions & 2 deletions R/dx_constructor.R
@@ -26,7 +26,7 @@

dx <- function(data,
               classlabels = c("Negative", "Positive"),
               threshold_range = NA, outcome_label = NA, pred_varname, true_varname,
               setthreshold = .5, poslabel = 1, grouping_variables = NA,
               citype = "exact", bootreps = 2000, bootseed = 20191015,
               doboot = FALSE, direction = "auto", ...) {
@@ -139,6 +139,17 @@ dx <- function(data,

  roc <- get_roc(true_varname, pred_varname, data, direction)

  # main confusion matrix
  predprob <- data[[options$pred_varname]]
  truth <- data[[options$true_varname]]

  cm <- dx_cm(
    predprob = predprob,
    truth = truth,
    threshold = setthreshold,
    poslabel = options$poslabel
  )

  structure(list(
    data = data,
@@ -147,6 +158,8 @@
    thresholds = threshold_analysis,
    prevalence = prevalence_analysis,
    rank = rank_analysis,
    cm = cm,
    n_levels = n_levels,
    roc = roc
  ), class = "dx")
}
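
For context, `dx_cm()` (exported in the NAMESPACE above) can also be called directly with the same arguments the constructor uses. A minimal sketch with made-up values, not taken from this commit:

# Illustrative only: made-up probabilities and outcomes.
predprob <- c(0.10, 0.85, 0.40, 0.95, 0.65)
truth <- c(0, 1, 0, 1, 0)
cm <- dx_cm(predprob = predprob, truth = truth, threshold = 0.5, poslabel = 1)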
3 changes: 2 additions & 1 deletion R/dx_heart_failure.R
@@ -8,7 +8,8 @@
#' \item{AgeGroup}{Age group}
#' \item{Sex}{sex}
#' \item{truth}{Heart failure (outcome)}
#' \item{predicted}{Predicted outcome from a GLM model}
#' \item{predicted_rf}{Predicted outcome from a Random Forest model}
#' \item{AgeSex}{Age and sex group}
#' }
#' @source \url{https://www.kaggle.com/imnikhilanand/heart-attack-prediction}
73 changes: 73 additions & 0 deletions R/dx_measure.R
@@ -265,3 +265,76 @@ dx_group_measure <- function(data, options, group_varname) {
  rbind(res, bd)

}


#' Compare Multiple Classification Models
#'
#' Compares multiple classification models pairwise using various statistical tests
#' to assess differences in performance metrics. It supports both paired and unpaired
#' comparisons.
#'
#' @param dx_list A list of `dx` objects representing the models to be compared.
#' Each `dx` object should be the result of a call to `dx()`.
#' @param paired Logical, indicating whether the comparisons should be treated as paired.
#' Paired comparisons are appropriate when models are evaluated on the
#' same set of instances (e.g., cross-validation or repeated measures).
#'
#' @return A `dx_compare` object containing a list of `dx` objects and a data frame of
#' pairwise comparison results for each test conducted.
#'
#' @details This function is a utility to perform a comprehensive comparison between
#' multiple classification models. Based on the value of `paired`, it will
#'          perform the appropriate tests. The resulting object can be used in further
#'          functions like `dx_plot_rocs()`.
#'
#' @examples
#' dx_glm <- dx(data = dx_heart_failure, true_varname = "truth", pred_varname = "predicted")
#' dx_rf <- dx(data = dx_heart_failure, true_varname = "truth", pred_varname = "predicted_rf")
#' dx_list <- list(dx_glm, dx_rf)
#' dx_comp <- dx_compare(dx_list, paired = TRUE)
#' print(dx_comp$tests)
#' @seealso \code{\link{dx_delong}}, \code{\link{dx_z_test}}, \code{\link{dx_mcnemars}}
#' for more details on the tests used for comparisons.
#'
#' @export
dx_compare <- function(dx_list, paired = TRUE) {
  dx_list <- validate_dx_list(dx_list)

  combinations <- utils::combn(names(dx_list), 2)

  res <- NULL

  for (i in seq_len(ncol(combinations))) {
    n1 <- combinations[1, i]
    n2 <- combinations[2, i]
    dx1 <- dx_list[[n1]]
    dx2 <- dx_list[[n2]]

    pair_res <- dx_delong(dx1, dx2, paired = paired)

    if (!paired) {
      metrics <- c("accuracy", "ppv", "npv", "fnr", "fpr", "fdr", "sensitivity", "specificity")
      for (metric in metrics) {
        ztest <- dx_z_test(dx1, dx2, metric = metric)
        pair_res <- rbind(pair_res, ztest)
      }
    } else {
      mcnemars <- dx_mcnemars(dx1, dx2)
      pair_res <- rbind(pair_res, mcnemars)
    }

    # Label only this pair's rows so earlier pairs keep their own labels
    pair_res$models <- two_model_name(n1, n2)
    res <- rbind(res, pair_res)
  }

  structure(
    list(
      dx_list = dx_list,
      tests = res
    ),
    class = "dx_compare"
  )
}
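
`dx_compare` depends on two helpers, `validate_dx_list()` and `two_model_name()`, that are not part of this diff. A minimal sketch of what they might look like (assumptions, not the package's actual implementations):

# Hypothetical sketches; the real helpers are defined elsewhere in the package.
validate_dx_list <- function(dx_list) {
  stopifnot(is.list(dx_list), length(dx_list) >= 2)
  if (!all(vapply(dx_list, inherits, logical(1), what = "dx"))) {
    stop("All elements of `dx_list` must be `dx` objects.")
  }
  # Assumption: default names are assigned so combn() has labels to work with.
  if (is.null(names(dx_list))) {
    names(dx_list) <- paste0("Model ", seq_along(dx_list))
  }
  dx_list
}

two_model_name <- function(n1, n2) paste(n1, "vs.", n2)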
29 changes: 29 additions & 0 deletions R/dx_metrics.R
@@ -11,6 +11,8 @@
#' @param bootreps The number of bootstrap replications for calculating confidence intervals.
#' @param predprob Numeric vector of predicted probabilities associated with the positive class.
#' @param truth Numeric vector of true binary outcomes, typically 0 or 1, with the same length as `predprob`.
#' @param dx1 A `dx` object.
#' @param dx2 A `dx` object.
#' @return Depending on the `detail` parameter, returns a numeric value
#' representing the calculated metric or a data frame/tibble with
#' detailed diagnostics including confidence intervals and possibly other
@@ -577,6 +579,33 @@ measure_df <- function(measure = "", estimate = "", fraction = "",

}

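# Internal helper: assembles one pairwise test result (test name, summary,
# p-value, estimate, confidence interval) into a standardized data frame row.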
compare_df <- function(models = "",
                       test = "",
                       summary = "",
                       p_value = "",
                       estimate = "",
                       conf_low = NA,
                       conf_high = NA,
                       statistic = "",
                       notes = "") {
  metric <- data.frame(
    models = models,
    test = test,
    summary = summary,
    p_value = p_value,
    estimate = estimate,
    conf_low = conf_low,
    conf_high = conf_high,
    statistic = statistic,
    notes = notes,
    stringsAsFactors = FALSE
  )

  return_df(metric)
}
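
For illustration, a call with hypothetical values showing the row shape this helper produces:

# Hypothetical values, for illustration only.
compare_df(
  models = "Model 1 vs. Model 2",
  test = "DeLong's test for ROC curves",
  summary = "AUC: 0.81 vs. 0.84",
  p_value = 0.031,
  estimate = -0.03,
  conf_low = -0.057,
  conf_high = -0.003,
  statistic = -2.15
)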

#' Calculate a Binomial Metric
#'
#' This internal function calculates a binomial metric and its confidence interval
Expand Down
128 changes: 128 additions & 0 deletions R/dx_plots.R
@@ -1050,3 +1050,131 @@ dx_plot_thresholds <- function(dx_obj) {
  # ggplot2::facet_wrap(~metric, scales = "free_y")
}

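# Internal helper: pulls the overall sensitivity/specificity estimates and the
# per-threshold ROC coordinates out of a dx object for plotting.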
pluck_auc_roc_data <- function(dx_obj) {
  measures <- as.data.frame(dx_obj, variable = "Overall", thresh = dx_obj$options$setthreshold)

  sensdf <- measures[measures$measure == "Sensitivity", ]$estimate
  specdf <- measures[measures$measure == "Specificity", ]$estimate

  auc_df <- data.frame(
    threshold = dx_obj$thresholds$threshold,
    specificity = dx_obj$thresholds$specificity,
    sensitivity = dx_obj$thresholds$sensitivity
  )

  auc_df$lead_specificity <- c(auc_df$specificity[-1], NA)

  list(
    sensitivity = sensdf,
    specificity = specdf,
    auc_df = auc_df
  )
}


#' Plot ROC Curves for Multiple Models
#'
#' Generates Receiver Operating Characteristic (ROC) curves for multiple models
#' and overlays them for comparison. Optionally, it adds text annotations for DeLong's
#' test results to indicate statistical differences between the models' Area Under
#' the ROC Curve (AUC).
#'
#' @param dx_comp A `dx_compare` object containing the results of pairwise model
#' comparisons and the list of `dx` objects with ROC data.
#' @param add_text Logical, whether to add DeLong's test results as text annotations
#' on the plot. Defaults to TRUE.
#' @param axis_color Color of the axis lines, specified as a color name or hex code.
#' Defaults to "#333333".
#' @param text_color Color of the text annotations, specified as a color name or hex code.
#' Defaults to "black".
#'
#' @return A ggplot object representing the ROC curves for the models included in the
#' `dx_comp` object. Each model's ROC curve is color-coded, and the plot
#' includes annotations for DeLong's test p-values if `add_text` is TRUE.
#'
#' @details This function is a visualization tool that plots ROC curves for multiple
#' models to facilitate comparison. It uses DeLong's test to statistically
#' compare AUC values and, if desired, annotates the plot with the results.
#' The function expects a `dx_compare` object as input, which should contain
#' the necessary ROC and test comparison data. Ensure that the ROC data and
#' DeLong's test results are appropriately generated and stored in the
#' `dx_compare` object before using this function.
#'
#' @examples
#' dx_glm <- dx(data = dx_heart_failure, true_varname = "truth", pred_varname = "predicted")
#' dx_rf <- dx(data = dx_heart_failure, true_varname = "truth", pred_varname = "predicted_rf")
#' dx_list <- list(dx_glm, dx_rf)
#' dx_comp <- dx_compare(dx_list, paired = TRUE)
#' dx_plot_rocs(dx_comp)
#' @seealso \code{\link{dx_compare}} to generate the required input object.
#' \code{\link{dx_delong}} for details on DeLong's test used in comparisons.
#' @export
dx_plot_rocs <- function(dx_comp, add_text = TRUE, axis_color = "#333333", text_color = "black") {
  roc_data <- lapply(dx_comp$dx_list, pluck_auc_roc_data)
  delong <- dx_comp$tests
  delong <- delong[delong$test == "DeLong's test for ROC curves", ]

  for (i in seq_along(roc_data)) {
    roc_data[[i]]$auc_df$model <- names(roc_data)[i]
  }

  df <- do.call(rbind, lapply(roc_data, function(model) model$auc_df))
  df <- stats::na.omit(df)

  p <- ggplot2::ggplot(df) +
    ggplot2::geom_line(
      ggplot2::aes(.data$specificity, .data$sensitivity, color = .data$model),
      linewidth = 1
    ) +
    ggplot2::scale_x_reverse() +
    ggplot2::geom_hline(
      yintercept = 0,
      linewidth = 1,
      colour = axis_color
    ) +
    ggplot2::geom_vline(
      xintercept = 1.05,
      linewidth = 1,
      colour = axis_color
    ) +
    ggplot2::coord_fixed() +
    dx_roc_ggtheme() +
    ggplot2::labs(
      x = "\nSpecificity",
      y = "Sensitivity\n\n",
      color = "Model"
    ) +
    ggplot2::theme(legend.position = "bottom")

  if (add_text) {
    numsummary <- nrow(delong)
    ystart <- 0.05
    yend <- ystart + (numsummary - 1) * 0.05
    location_vector <- seq(yend, ystart, -0.05)

    text_df <- data.frame(
      y = location_vector,
      label = paste0(delong$models, ": ", sapply(delong$p_value, format_pvalue))
    )

    p <- p +
      ggplot2::geom_text(
        data = text_df,
        mapping = ggplot2::aes(x = .05, y = .data$y, label = .data$label),
        hjust = 1,
        color = text_color
      )
  }

  return(p)
}
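
The annotation labels above rely on `format_pvalue()`, which is not part of this diff. A hypothetical sketch of such a formatter (an assumption, not the package's actual helper):

# Hypothetical sketch; the real format_pvalue() is defined elsewhere.
format_pvalue <- function(p, digits = 3) {
  if (p < 10^(-digits)) {
    paste0("p < ", format(10^(-digits), scientific = FALSE))
  } else {
    paste0("p = ", formatC(p, format = "f", digits = digits))
  }
}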

