-
Notifications
You must be signed in to change notification settings - Fork 74
Commit
…ming.
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
context("Test aggregate_resamples functionality") | ||
|
||
test_that("extractBestPreds respects aggregate_resamples parameter", { | ||
# Create a simple model with repeated CV | ||
data(iris) | ||
tc <- caret::trainControl( | ||
method = "repeatedcv", | ||
number = 2, | ||
Check notice on line 8 in tests/testthat/test-aggregate-resamples.R
|
||
repeats = 2, | ||
Check notice on line 9 in tests/testthat/test-aggregate-resamples.R
|
||
savePredictions = "final" | ||
) | ||
|
||
Check notice on line 12 in tests/testthat/test-aggregate-resamples.R
|
||
model <- caret::train( | ||
Sepal.Length ~ ., | ||
data = iris, | ||
method = "lm", | ||
trControl = tc | ||
) | ||
|
||
Check notice on line 19 in tests/testthat/test-aggregate-resamples.R
|
||
# Test with aggregation (default) | ||
preds_agg <- extractBestPreds(model, aggregate_resamples = TRUE) | ||
expect_equal(nrow(preds_agg), nrow(iris)) | ||
Check warning on line 22 in tests/testthat/test-aggregate-resamples.R
|
||
|
||
Check notice on line 23 in tests/testthat/test-aggregate-resamples.R
|
||
# Test without aggregation | ||
preds_no_agg <- extractBestPreds(model, aggregate_resamples = FALSE) | ||
expect_equal(nrow(preds_no_agg), nrow(iris) * 4) # 2 folds * 2 repeats | ||
Check warning on line 26 in tests/testthat/test-aggregate-resamples.R
|
||
expect_true(nrow(preds_no_agg) > nrow(preds_agg)) | ||
}) | ||
|
||
test_that("caretList respects aggregate_resamples parameter", { | ||
data(iris) | ||
tc <- caret::trainControl( | ||
method = "repeatedcv", | ||
number = 2, | ||
repeats = 2, | ||
savePredictions = "final" | ||
) | ||
|
||
# Test with aggregation (default) | ||
model_list_agg <- caretList( | ||
Sepal.Length ~ ., | ||
data = iris, | ||
trControl = tc, | ||
methodList = c("lm", "rf"), | ||
tuneLength = 1 | ||
) | ||
|
||
# Test without aggregation | ||
model_list_no_agg <- caretList( | ||
Sepal.Length ~ ., | ||
data = iris, | ||
trControl = tc, | ||
methodList = c("lm", "rf"), | ||
tuneLength = 1, | ||
aggregate_resamples = FALSE | ||
) | ||
|
||
# Check predictions from both models | ||
expect_equal(nrow(model_list_agg$lm$pred), nrow(iris)) | ||
Check warning on line 59 in tests/testthat/test-aggregate-resamples.R
|
||
expect_equal(nrow(model_list_no_agg$lm$pred), nrow(iris) * 4) # 2 folds * 2 repeats | ||
Check notice on line 60 in tests/testthat/test-aggregate-resamples.R
|
||
expect_true(nrow(model_list_no_agg$lm$pred) > nrow(model_list_agg$lm$pred)) | ||
}) |