Skip to content

Commit

Permalink
Fix issue #346: Add option to keep resamples by fold, rather than summing.
Browse files Browse the repository at this point in the history
  • Loading branch information
openhands-agent committed Dec 13, 2024
1 parent a1a29a9 commit 5bfe0cf
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 11 deletions.
9 changes: 7 additions & 2 deletions R/caretList.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ caretList <- function(
tuneList = NULL,
metric = NULL,
continue_on_fail = FALSE,
trim = TRUE) {
trim = TRUE,
aggregate_resamples = TRUE) {
# Checks
if (is.null(tuneList) && is.null(methodList)) {
stop("Please either define a methodList or tuneList", call. = FALSE)
Expand Down Expand Up @@ -79,7 +80,11 @@ caretList <- function(
global_args[["metric"]] <- metric

# Loop through the tuneLists and fit caret models with those specs
modelList <- lapply(tuneList, caretTrain, global_args = global_args, continue_on_fail = continue_on_fail, trim = trim)
modelList <- lapply(tuneList, function(x) {
# Add aggregate_resamples to model args
x$aggregate_resamples <- aggregate_resamples
caretTrain(x, global_args = global_args, continue_on_fail = continue_on_fail, trim = trim)
})
names(modelList) <- names(tuneList)
nulls <- vapply(modelList, is.null, logical(1L))
modelList <- modelList[!nulls]
Expand Down
25 changes: 16 additions & 9 deletions R/caretPredict.R
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ caretTrain <- function(local_args, global_args, continue_on_fail = FALSE, trim =

# Only save stacked predictions for the best model
if ("pred" %in% names(model)) {
model[["pred"]] <- extractBestPreds(model)
model[["pred"]] <- extractBestPreds(model, aggregate_resamples = if (!is.null(model_args$aggregate_resamples)) model_args$aggregate_resamples else TRUE)

Check notice on line 103 in R/caretPredict.R

View check run for this annotation

codefactor.io / CodeFactor

R/caretPredict.R#L103

Lines should not be more than 120 characters. This line is 156 characters. (line_length_linter)

Check warning on line 103 in R/caretPredict.R

View workflow job for this annotation

GitHub Actions / lint

file=R/caretPredict.R,line=103,col=121,[line_length_linter] Lines should not be more than 120 characters. This line is 156 characters.
}

if (trim) {
Expand Down Expand Up @@ -147,9 +147,10 @@ aggregate_mean_or_first <- function(x) {
#' @title Extract the best predictions from a train object
#' @description Extract the best predictions from a train object.
#' @param x a train object
#' @param aggregate_resamples logical, should resamples be aggregated (default TRUE)
#' @return a data.table::data.table with predictions
#' @keywords internal
extractBestPreds <- function(x) {
extractBestPreds <- function(x, aggregate_resamples = TRUE) {
stopifnot(methods::is(x, "train"))
if (is.null(x$pred)) {
stop("No predictions saved during training. Please set savePredictions = 'final' in trainControl", call. = FALSE)
Expand All @@ -167,14 +168,20 @@ extractBestPreds <- function(x) {
# Drop rows for other tunes
pred <- pred[best_tune, ]

# If we have multiple resamples per row
# e.g. for repeated CV, we need to average the predictions
keys <- "rowIndex"
data.table::setkeyv(pred, keys)
pred <- pred[, lapply(.SD, aggregate_mean_or_first), by = keys]

# Order results consistently
data.table::setorderv(pred, keys)
data.table::setorderv(pred, "rowIndex")

if (aggregate_resamples) {
# If we have multiple resamples per row
# e.g. for repeated CV, we need to average the predictions
keys <- "rowIndex"
data.table::setkeyv(pred, keys)
pred <- pred[, lapply(.SD, aggregate_mean_or_first), by = keys]
} else {
# Keep all resamples
# Remove columns that are not needed
pred[, c("intercept", "Resample") := NULL]

Check warning on line 183 in R/caretPredict.R

View workflow job for this annotation

GitHub Actions / lint

file=R/caretPredict.R,line=183,col=39,[object_usage_linter] no visible global function definition for ':='
}

# Return
pred
Expand Down
62 changes: 62 additions & 0 deletions tests/testthat/test-aggregate-resamples.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# File-level label for this group of tests (aggregate_resamples behaviour).
context("Test aggregate_resamples functionality")

test_that("extractBestPreds respects aggregate_resamples parameter", {
  # Fit a small linear model with repeated CV so that every training row is
  # held out more than once: once per repeat, regardless of the fold count.
  data(iris)
  n_folds <- 2L
  n_repeats <- 2L
  tc <- caret::trainControl(
    method = "repeatedcv",
    number = n_folds,
    repeats = n_repeats,
    savePredictions = "final"
  )

  model <- caret::train(
    Sepal.Length ~ .,
    data = iris,
    method = "lm",
    trControl = tc
  )

  # With aggregation (the default) there is exactly one row per training row.
  preds_agg <- extractBestPreds(model, aggregate_resamples = TRUE)
  expect_identical(nrow(preds_agg), nrow(iris))

  # Without aggregation every held-out prediction is kept. In k-fold CV each
  # row lands in the hold-out set exactly once per repeat, so the expected
  # number of rows is n * repeats (not n * folds * repeats).
  preds_no_agg <- extractBestPreds(model, aggregate_resamples = FALSE)
  expect_identical(nrow(preds_no_agg), nrow(iris) * n_repeats)
  expect_gt(nrow(preds_no_agg), nrow(preds_agg))
})

test_that("caretList respects aggregate_resamples parameter", {
  # methodList below includes "rf", which needs the randomForest package.
  skip_if_not_installed("randomForest")

  data(iris)
  n_folds <- 2L
  n_repeats <- 2L
  tc <- caret::trainControl(
    method = "repeatedcv",
    number = n_folds,
    repeats = n_repeats,
    savePredictions = "final"
  )

  # With aggregation (the default)
  model_list_agg <- caretList(
    Sepal.Length ~ .,
    data = iris,
    trControl = tc,
    methodList = c("lm", "rf"),
    tuneLength = 1L
  )

  # Without aggregation
  model_list_no_agg <- caretList(
    Sepal.Length ~ .,
    data = iris,
    trControl = tc,
    methodList = c("lm", "rf"),
    tuneLength = 1L,
    aggregate_resamples = FALSE
  )

  # Compare the stacked predictions of the "lm" member of both lists.
  pred_agg <- model_list_agg[["lm"]][["pred"]]
  pred_no_agg <- model_list_no_agg[["lm"]][["pred"]]

  # Aggregated: one row per training row.
  expect_identical(nrow(pred_agg), nrow(iris))

  # Not aggregated: each row is held out once per repeat, so n * repeats rows
  # (not n * folds * repeats).
  expect_identical(nrow(pred_no_agg), nrow(iris) * n_repeats)
  expect_gt(nrow(pred_no_agg), nrow(pred_agg))
})

0 comments on commit 5bfe0cf

Please sign in to comment.