Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[324] Fix rev dep #325

Merged
merged 15 commits into from
Aug 14, 2024
9 changes: 6 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,23 @@ Suggests:
caTools,
covr,
devtools,
earth,
gbm,
glmnet,
klaR,
knitr,
lintr,
mgcv,
mlbench,
nnet,
pkgdown,
randomForest,
rmarkdown,
rhub,
rpart,
spelling,
testthat,
usethis,
pkgdown,
rhub
usethis
Imports:
caret,
data.table,
Expand Down
8 changes: 7 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ help:
@echo " readme Build readme"
@echo " check-win Run R CMD on the winbuilder service from CRAN"
@echo " check-rhub Run R CMD on the rhub service"
@echo " check-many-preds Check that caretList can predict on ~200 caret models"
@echo " release Release to CRAN"
@echo " preview-site Preview pkgdown site"
@echo " dev-guide Open the R package development guide"
Expand Down Expand Up @@ -135,6 +136,10 @@ preview-site:
Rscript -e "pkgdown::build_site()"
open docs/index.html

.PHONY: check-many-preds
check-many-preds:
Rscript inst/data-raw/test-all_models.R

.PHONY: check-win
check-win:
rm -rf lib/
Expand All @@ -146,7 +151,7 @@ check-rhub:
Rscript -e "rhub::rhub_check(platform='linux')"

.PHONY: release
release: check-rhub check-win
release: check-many-preds check-rhub check-win
R --no-save --quiet --interactive
devtools::release()

Expand All @@ -158,6 +163,7 @@ dev-guide:
clean:
rm -rf *.Rcheck
rm -f *.tar.gz
rm -f *.Rout
rm -rf man/
rm -f README.md
rm -f coverage.rds
Expand Down
12 changes: 10 additions & 2 deletions R/caretPredict.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,18 @@ caretPredict <- function(object, newdata = NULL, excluded_class_id = 1L, ...) {

# Otherwise, predict on newdata
} else {
if (any(object$modelInfo$library %in% c("neuralnet", "klaR"))) {
newdata <- as.matrix(newdata) # I hate some of these packages
}
if (is_class) {
pred <- caret::predict.train(object, type = "prob", newdata = newdata, ...)
pred <- stats::predict(object, type = "prob", newdata = newdata, ...)
stopifnot(is.data.frame(pred))
} else {
pred <- caret::predict.train(object, type = "raw", newdata = newdata, ...)
pred <- stats::predict(object, type = "raw", newdata = newdata, ...)
stopifnot(is.numeric(pred))
if (!is.vector(pred)) {
pred <- as.vector(pred) # Backwards compatability with older earth and caret::train models
}
stopifnot(
is.vector(pred),
is.numeric(pred),
Expand Down
21 changes: 11 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ stack them with another caret model.
First, use caretList to fit many models to the same data:

``` r
set.seed(42L)
data(diamonds, package = "ggplot2")
dat <- data.table::data.table(diamonds)
dat <- dat[sample.int(nrow(diamonds), 500L), ]
Expand All @@ -48,8 +49,8 @@ print(summary(models))
#> Model accuracy:
#> model_name metric value sd
#> <char> <char> <num> <num>
#> 1: rf RMSE 1110.199 114.5286
#> 2: glmnet RMSE 1256.668 100.4436
#> 1: rf RMSE 1076.492 215.4737
#> 2: glmnet RMSE 1142.082 105.6022
```

Then, use caretEnsemble to make a greedy ensemble of these models
Expand All @@ -68,17 +69,17 @@ print(greedy_stack)
#> Resampling results:
#>
#> RMSE Rsquared MAE
#> 1015.885 0.9364999 572.8984
#> 969.2517 0.9406218 557.1987
#>
#> Tuning parameter 'max_iter' was held constant at a value of 100
#>
#> Final model:
#> Greedy MSE
#> RMSE: 1031.297
#> RMSE: 989.2085
#> Weights:
#> [,1]
#> rf 0.63
#> glmnet 0.37
#> rf 0.55
#> glmnet 0.45
```

You can also use caretStack to make a non-linear ensemble
Expand All @@ -97,8 +98,8 @@ print(rf_stack)
#> Summary of sample sizes: 400, 400, 400, 400, 400
#> Resampling results:
#>
#> RMSE Rsquared MAE
#> 1020.387 0.9363342 527.8525
#> RMSE Rsquared MAE
#> 1081.425 0.930012 540.3294
#>
#> Tuning parameter 'mtry' was held constant at a value of 2
#>
Expand All @@ -110,8 +111,8 @@ print(rf_stack)
#> Number of trees: 500
#> No. of variables tried at each split: 2
#>
#> Mean of squared residuals: 1065781
#> % Var explained: 93.33
#> Mean of squared residuals: 925377
#> % Var explained: 93.95
```

Use autoplot from ggplot2 to plot ensemble diagnostics:
Expand Down
1 change: 1 addition & 0 deletions README.rmd
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ Use `caretList` to fit multiple models, and then use `caretStack` to stack them

First, use caretList to fit many models to the same data:
```{r}
set.seed(42L)
data(diamonds, package = "ggplot2")
dat <- data.table::data.table(diamonds)
dat <- dat[sample.int(nrow(diamonds), 500L), ]
Expand Down
4 changes: 2 additions & 2 deletions inst/WORDLIST
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ rpart
savePredictions
scalability
scikit
setosa
SDs
setosa
trainControl
travis
tuneGrid
Expand All @@ -71,4 +71,4 @@ varImp
vecstack
verions
xvars
yhat
yhat
115 changes: 115 additions & 0 deletions inst/data-raw/build_backwards_compatability_data.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# This script is a little big sorry.
# We're using a 3rd party dataset from a package
# It depends on caretEnsemble, and I broke it with the 4.0.0 pre-release
# This script isolates the bad, saved model in that package
# and then removes all the parts of that model that aren't needed to make predictions
# this gives us a minimal test case for the backwards compatability issue
# This script shouldn't ever need to get run again, just use the old saved
# caretlist_with_bad_earth_model.rds file in data in testthat in the tests folder.
# This script is for posterity.

# Note this is not in our depends or suggests. Also note new version may not have the bug.
devtools::install_version("LDLcalc", version = "2.1", repos = "http://cran.us.r-project.org")
devtools::load_all()

# Load the data and fit the model
data(SampleData, package = "LDLcalc")
ldl_model <- LDLcalc:::LDL_ML_train_StackingAlgorithm(SampleData) # nolint undesirable_operator_linter
testthat::expect_s3_class(ldl_model$stackModel, "caretStack")
testthat::expect_s3_class(ldl_model$stackModel$models, "caretList")

# Make a caretList with just the bad model
caretlist_with_old_earth_model <- ldl_model$stackModel$models["earth"]

# Function to test the error and warnings after removing a specific part
test_error <- function(obj, path, SampleData) {
modified_obj <- obj
eval(parse(text = paste0("modified_obj", path, " <- NULL")))

wrns <- NULL

# Capture both errors and warnings
result <- tryCatch(
{
withCallingHandlers(
{
predict(modified_obj, SampleData)
},
warning = function(w) {
wrns <<- c(wrns, conditionMessage(w)) # nolint undesirable_operator_linter
invokeRestart("muffleWarning")
}
)
list(error = NULL, wrns = wrns)
},
error = function(e) {
list(error = e$message, wrns = wrns)
}
)

result
}

# Function to iteratively prune the object
prune_list_iterative <- function(obj, SampleData) { # nolint cyclocomp_linter
the_stack <- list(list(obj = obj, path = ""))
pruned_obj <- obj

while (length(the_stack) > 0L) {
# Pop the last element from the the_stack
current <- the_stack[[length(the_stack)]]
the_stack <- the_stack[-length(the_stack)]

if (is.list(current$obj)) {
keys <- names(current$obj)
for (key in keys) {
current_path <- paste0(current$path, "$", key)

# Test by removing the current element
result <- test_error(pruned_obj, current_path, SampleData)

# Determine if we should keep or remove the element
if ((!is.null(result$error) && result$error != "is.vector(pred) is not TRUE") || !is.null(result$wrns)) {
# If error changes, goes away, or a warning appears, keep the element
the_stack <- c(the_stack, list(list(obj = current$obj[[key]], path = current_path)))
} else {
# If error remains the same and no wrns, remove the element
eval(parse(text = paste0("pruned_obj", current_path, " <- NULL")))
}
}
}
}
pruned_obj
}

# Start the pruning process
pruned_caretlist <- prune_list_iterative(caretlist_with_old_earth_model, SampleData)

# Prune attributes
attr(pruned_caretlist$earth$terms, ".Environment") <- NULL
attr(pruned_caretlist$earth$terms, "dimnames") <- NULL
attr(pruned_caretlist$earth$terms, "term.labels") <- NULL
attr(pruned_caretlist$earth$terms, "order") <- NULL
attr(pruned_caretlist$earth$terms, "intercept") <- NULL
attr(pruned_caretlist$earth$terms, "response") <- NULL
attr(pruned_caretlist$earth$terms, "predvars") <- NULL
attr(pruned_caretlist$earth$terms, "dataClasses") <- NULL

# Test the final pruned object
# Note that once the bug is fixed, this will no longer fail
# this requires version 4.0.0 of caretEnsemble, prior to the PR fixing the prediciton issue
# https://github.com/zachmayer/caretEnsemble/issues/324
testthat::expect_error(
predict(pruned_caretlist, SampleData),
"is.vector(pred) is not TRUE",
fixed = TRUE
)

# Save
saveRDS(
pruned_caretlist,
file.path("tests", "testthat", "data", "caretlist_with_bad_earth_model.rds"),
ascii = FALSE,
version = 3L,
compress = "xz"
)
106 changes: 106 additions & 0 deletions inst/data-raw/test-all_models.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# This test takes a few minutes and needs to install and load a lot of packages
# I don't want to make it a dependency for the package or even for PR tests
# But I do want to run this every release to make sure that the models
# we can run predict correctly.

devtools::load_all()

very_quiet <- function(expr) {
testthat::expect_output(suppressWarnings(suppressMessages(expr)))
}

#################################################################
# Setup data
#################################################################
set.seed(42L)
nrows <- 10L
ncols <- 2L

X <- matrix(stats::rnorm(nrows * ncols), ncol = ncols)
colnames(X) <- paste0("X", 1L:ncols)

y <- X[, 1L] + X[, 2L] + stats::rnorm(nrows) / 10.0
y_bin <- factor(ifelse(y > median(y), "yes", "no"))

all_models <- data.table::data.table(caret::modelLookup())
all_models <- unique(all_models[, c("model", "forReg", "probModel")])

java_models <- c(
"gbm_h2o",
"glmnet_h2o",
"bartMachine",
"M5",
"M5Rules",
"J48",
"JRip",
"LMT",
"PART",
"OneR",
"evtree"
)

#################################################################
# Reg
#################################################################

# From https://github.com/zachmayer/caretEnsemble/issues/324
# Problem models:
# bam - array
# blackboost - matrix, array
# dnn - matrix, array
# earth - matrix, array
# gam - array
# gamboost - matrix, array
# glmboost - matrix, array
# pcaNNet - matrix, array
# rvmLinear - matrix, array
# rvmRadial - matrix, array
# spls - matrix, array
# xyf - matrix, array
reg_models <- sort(unique(all_models[which(forReg), ][["model"]]))
reg_models <- setdiff(reg_models, c( # Can't install or too slow
"elm", "extraTrees", "foba", "logicBag", "mlpSGD", "mxnet",
"mxnetAdam", "nodeHarvest", "relaxo",
java_models
))

#################################################################
# Class
#################################################################

# Problem models: None!
bin_models <- sort(unique(all_models[which(probModel), ][["model"]]))
bin_models <- setdiff(bin_models, c( # Can't install or too slow
"gaussprLinear", "adaboost", "amdai", "chaid", "extraTrees",
"gpls", "logicBag", "mlpSGD", "mxnet", "mxnetAdam", "nodeHarvest",
"ORFlog", "ORFpls", "ORFridge", "ORFsvm", "rrlda", "vbmpRadial",
java_models
))

#################################################################
# Tests
#################################################################

testthat::test_that("Most caret models can predict", {
# Fit the big caret lists
models_reg <- very_quiet(caretList(X, y, methodList = reg_models, tuneLength = 1L, continue_on_fail = TRUE))
models_bin <- very_quiet(caretList(X, y_bin, methodList = bin_models, tuneLength = 1L, continue_on_fail = TRUE))
all_models <- c(models_reg, models_bin)
testthat::expect_gt(length(all_models), 200L) # About 100 each of class/reg

# Make sure we can predict
pred <- very_quiet(predict(all_models, head(X, 5L)))
testthat::expect_identical(nrow(pred), 5L)
testthat::expect_identical(ncol(pred), length(all_models))
testthat::expect_true(all(unlist(lapply(pred, is.finite))))

# Make sure we can stacked predict
# Some of these stupid models predict Infs lol, so whatever.
# I guess beware of what models you ensemble.
# The bagEarth models are bad, as is rvmPoly and some others.
# These are stacked preds btw, so probably it indicates a fit failure
# on one fold. Many ensemble models can handle Nans, but we'll see.
pred_stack <- suppressWarnings(suppressMessages(predict(all_models)))
testthat::expect_identical(nrow(pred_stack), nrow(X))
testthat::expect_identical(ncol(pred_stack), length(all_models))
})
Binary file modified man/figures/README-greedy-stack-6-plot-1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified man/figures/README-unnamed-chunk-5-1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Loading