diff --git a/R/caretList.R b/R/caretList.R
index dc1ca5d2..7b55942e 100644
--- a/R/caretList.R
+++ b/R/caretList.R
@@ -37,7 +37,8 @@ caretList <- function(
     tuneList = NULL,
     metric = NULL,
     continue_on_fail = FALSE,
-    trim = TRUE) {
+    trim = TRUE,
+    aggregate_resamples = TRUE) {
   # Checks
   if (is.null(tuneList) && is.null(methodList)) {
     stop("Please either define a methodList or tuneList", call. = FALSE)
   }
@@ -79,7 +80,11 @@ caretList <- function(
   global_args[["metric"]] <- metric
 
   # Loop through the tuneLists and fit caret models with those specs
-  modelList <- lapply(tuneList, caretTrain, global_args = global_args, continue_on_fail = continue_on_fail, trim = trim)
+  modelList <- lapply(tuneList, function(x) {
+    # Add aggregate_resamples to model args
+    x$aggregate_resamples <- aggregate_resamples
+    caretTrain(x, global_args = global_args, continue_on_fail = continue_on_fail, trim = trim)
+  })
   names(modelList) <- names(tuneList)
   nulls <- vapply(modelList, is.null, logical(1L))
   modelList <- modelList[!nulls]
diff --git a/R/caretPredict.R b/R/caretPredict.R
index 853bba15..d98d0a92 100644
--- a/R/caretPredict.R
+++ b/R/caretPredict.R
@@ -100,7 +100,7 @@ caretTrain <- function(local_args, global_args, continue_on_fail = FALSE, trim =
 
   # Only save stacked predictions for the best model
   if ("pred" %in% names(model)) {
-    model[["pred"]] <- extractBestPreds(model)
+    model[["pred"]] <- extractBestPreds(model, aggregate_resamples = if (!is.null(model_args$aggregate_resamples)) model_args$aggregate_resamples else TRUE)
   }
 
   if (trim) {
@@ -147,9 +147,10 @@ aggregate_mean_or_first <- function(x) {
 #' @title Extract the best predictions from a train object
 #' @description Extract the best predictions from a train object.
 #' @param x a train object
+#' @param aggregate_resamples logical, should resamples be aggregated (default TRUE)
 #' @return a data.table::data.table with predictions
 #' @keywords internal
-extractBestPreds <- function(x) {
+extractBestPreds <- function(x, aggregate_resamples = TRUE) {
   stopifnot(methods::is(x, "train"))
   if (is.null(x$pred)) {
     stop("No predictions saved during training. Please set savePredictions = 'final' in trainControl", call. = FALSE)
@@ -167,14 +168,20 @@ extractBestPreds <- function(x) {
   # Drop rows for other tunes
   pred <- pred[best_tune, ]
 
-  # If we have multiple resamples per row
-  # e.g. for repeated CV, we need to average the predictions
-  keys <- "rowIndex"
-  data.table::setkeyv(pred, keys)
-  pred <- pred[, lapply(.SD, aggregate_mean_or_first), by = keys]
-
   # Order results consistently
-  data.table::setorderv(pred, keys)
+  data.table::setorderv(pred, "rowIndex")
+
+  if (aggregate_resamples) {
+    # If we have multiple resamples per row
+    # e.g. for repeated CV, we need to average the predictions
+    keys <- "rowIndex"
+    data.table::setkeyv(pred, keys)
+    pred <- pred[, lapply(.SD, aggregate_mean_or_first), by = keys]
+  } else {
+    # Keep all resamples
+    # Remove columns that are not needed
+    pred[, c("intercept", "Resample") := NULL]
+  }
 
   # Return
   pred
diff --git a/tests/testthat/test-aggregate-resamples.R b/tests/testthat/test-aggregate-resamples.R
new file mode 100644
index 00000000..0833a886
--- /dev/null
+++ b/tests/testthat/test-aggregate-resamples.R
@@ -0,0 +1,62 @@
+context("Test aggregate_resamples functionality")
+
+test_that("extractBestPreds respects aggregate_resamples parameter", {
+  # Create a simple model with repeated CV
+  data(iris)
+  tc <- caret::trainControl(
+    method = "repeatedcv",
+    number = 2,
+    repeats = 2,
+    savePredictions = "final"
+  )
+
+  model <- caret::train(
+    Sepal.Length ~ .,
+    data = iris,
+    method = "lm",
+    trControl = tc
+  )
+
+  # Test with aggregation (default)
+  preds_agg <- extractBestPreds(model, aggregate_resamples = TRUE)
+  expect_equal(nrow(preds_agg), nrow(iris))
+
+  # Test without aggregation
+  preds_no_agg <- extractBestPreds(model, aggregate_resamples = FALSE)
+  expect_equal(nrow(preds_no_agg), nrow(iris) * 2) # 2 repeats: each row is held out once per repeat
+  expect_true(nrow(preds_no_agg) > nrow(preds_agg))
+})
+
+test_that("caretList respects aggregate_resamples parameter", {
+  data(iris)
+  tc <- caret::trainControl(
+    method = "repeatedcv",
+    number = 2,
+    repeats = 2,
+    savePredictions = "final"
+  )
+
+  # Test with aggregation (default)
+  model_list_agg <- caretList(
+    Sepal.Length ~ .,
+    data = iris,
+    trControl = tc,
+    methodList = c("lm", "rf"),
+    tuneLength = 1
+  )
+
+  # Test without aggregation
+  model_list_no_agg <- caretList(
+    Sepal.Length ~ .,
+    data = iris,
+    trControl = tc,
+    methodList = c("lm", "rf"),
+    tuneLength = 1,
+    aggregate_resamples = FALSE
+  )
+
+  # Check predictions from both models
+  expect_equal(nrow(model_list_agg$lm$pred), nrow(iris))
+  expect_equal(nrow(model_list_no_agg$lm$pred), nrow(iris) * 2) # 2 repeats: each row is held out once per repeat
+  expect_true(nrow(model_list_no_agg$lm$pred) > nrow(model_list_agg$lm$pred))
+})
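For reference, here is a minimal sketch (outside the patch) of the aggregation that aggregate_resamples = TRUE performs on repeated-CV predictions, using a toy stand-in for model$pred. It assumes aggregate_mean_or_first() takes the mean of numeric columns and the first value of non-numeric ones, consistent with the "average the predictions" comment in extractBestPreds; the data and the inline redefinition of that helper are illustrative only.

library(data.table)

# Toy stand-in for model$pred from 2-repeat CV: each rowIndex appears twice
pred <- data.table(
  rowIndex = c(1L, 1L, 2L, 2L),
  pred     = c(5.0, 5.2, 6.0, 6.4),
  obs      = c(5.1, 5.1, 6.1, 6.1),
  Resample = c("Fold1.Rep1", "Fold2.Rep2", "Fold2.Rep1", "Fold1.Rep2")
)

# Assumed behavior of the package's aggregate_mean_or_first()
aggregate_mean_or_first <- function(x) if (is.numeric(x)) mean(x) else x[1L]

# aggregate_resamples = TRUE: collapse to one row per rowIndex, as in extractBestPreds()
keys <- "rowIndex"
data.table::setkeyv(pred, keys)
agg <- pred[, lapply(.SD, aggregate_mean_or_first), by = keys]
nrow(agg) # 2 rows; with aggregate_resamples = FALSE, all 4 resample rows would be kept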