rebuild
zachmayer committed Aug 13, 2024
1 parent 01a1e6d commit cbac717
Showing 9 changed files with 59 additions and 49 deletions.
5 changes: 4 additions & 1 deletion R/caretList.R
@@ -175,14 +175,17 @@ predict.caretList <- function(object, newdata = NULL, verbose = FALSE, excluded_
#' models. We also construct explicit fold indexes and return the stacked predictions,
#' which are needed for stacking. For classification models we return class probabilities.
#' @param target the target variable.
#' @param method the method to use for trainControl.
#' @param number the number of folds to use.
#' @param savePredictions the type of predictions to save.
#' @param index the fold indexes to use.
#' @param is_class logical, is this a classification or regression problem.
#' @param is_binary logical, is this binary classification.
#' @param ... other arguments to pass to \code{\link[caret]{trainControl}}
#' @export
defaultControl <- function(
target,
method = 'cv',
method = "cv",
number = 5L,
savePredictions = "final",
index = caret::createFolds(target, k = number, list = TRUE, returnTrain = TRUE),
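The documentation in this hunk says that `defaultControl` builds explicit fold indexes so that stacked predictions are available. A minimal base-R sketch of what "training-index folds" look like follows. `make_train_folds` is a hypothetical helper, not part of caret or caretEnsemble, and unlike `caret::createFolds` it does not balance the folds on the target.

```r
# Hypothetical helper (an assumption for illustration, not caret's implementation):
# build k lists of training-row indexes, each omitting one held-out block.
make_train_folds <- function(n, k = 5L) {
  # Randomly assign every row to one of k held-out blocks of near-equal size.
  fold_id <- sample(rep_len(seq_len(k), n))
  # The training index for fold i is every row NOT held out in fold i.
  folds <- lapply(seq_len(k), function(i) which(fold_id != i))
  names(folds) <- paste0("Fold", seq_len(k))
  folds
}

set.seed(42L)
folds <- make_train_folds(n = 500L, k = 5L)
lengths(folds)
```

Passing the same list of training indexes to every model means each row is held out exactly once per model, so the out-of-fold predictions line up row by row when they are stacked.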
2 changes: 1 addition & 1 deletion R/permutationImportance.R
@@ -54,7 +54,7 @@ shuffled_mae <- function(model, original_data, target, pred_type, shuffle_idx) {
new_preds <- as.matrix(stats::predict(model, original_data, type = pred_type))
data.table::set(original_data, j = var, value = old_var)

if(anyNA(new_preds)) { # This shoudn't happen, but it does with rpart.
if (anyNA(new_preds)) { # This shoudn't happen, but it does with rpart.
new_preds[is.na(new_preds)] <- 0.0
}

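The guard in this hunk replaces missing predictions with zero before they are scored. The sketch below shows the same pattern on a toy matrix. The reference predictions and the mean-absolute-error score are assumptions for illustration, inferred only from the function name `shuffled_mae`.

```r
# Toy prediction matrix with one missing value, standing in for model output.
new_preds <- matrix(c(0.2, NA, 0.7, 0.1), nrow = 2L)

# Same guard as in the diff: replace missing predictions with 0 before scoring.
if (anyNA(new_preds)) {
  new_preds[is.na(new_preds)] <- 0.0
}

# Hypothetical reference predictions to score against (assumed, not from the package).
original_preds <- matrix(c(0.25, 0.5, 0.75, 0.0), nrow = 2L)

# Mean absolute error between the guarded and the reference predictions.
mae <- mean(abs(new_preds - original_preds))
mae
```

Without the guard, a single `NA` would propagate through `mean()` and make the whole score `NA`, which is what the new test further down checks with `is.finite`.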
35 changes: 16 additions & 19 deletions README.md
@@ -56,21 +56,21 @@ print(greedy_stack)
#>
#> No pre-processing
#> Resampling: Cross-Validated (5 fold)
#> Summary of sample sizes: 400, 400, 400, 400, 400
#> Summary of sample sizes: 399, 401, 400, 400, 400
#> Resampling results:
#>
#> RMSE Rsquared MAE
#> 1019.069 0.9401024 530.9788
#> 925.6101 0.9454321 513.4172
#>
#> Tuning parameter 'max_iter' was held constant at a value of 100
#>
#> Final model:
#> Greedy MSE
#> RMSE: 1013.725
#> RMSE: 927.4321
#> Weights:
#> [,1]
#> rf 0.9
#> glmnet 0.1
#> rf 0.76
#> glmnet 0.24
ggplot2::autoplot(greedy_stack, training_data = dat, xvars = c("carat", "table"))
```

@@ -89,11 +89,11 @@ print(rf_stack)
#>
#> No pre-processing
#> Resampling: Cross-Validated (5 fold)
#> Summary of sample sizes: 400, 400, 400, 400, 400
#> Summary of sample sizes: 399, 400, 401, 400, 400
#> Resampling results:
#>
#> RMSE Rsquared MAE
#> 1076.369 0.9338026 531.9601
#> 935.9032 0.9469527 525.7192
#>
#> Tuning parameter 'mtry' was held constant at a value of 2
#>
@@ -105,8 +105,8 @@ print(rf_stack)
#> Number of trees: 500
#> No. of variables tried at each split: 2
#>
#> Mean of squared residuals: 1136154
#> % Var explained: 93.38
#> Mean of squared residuals: 919446.8
#> % Var explained: 94.1
ggplot2::autoplot(rf_stack, training_data = dat, xvars = c("carat", "table"))
```

@@ -141,7 +141,7 @@ will work in a variety of environments.

# Package development

This packages uses a Makefile. Use `make help` to see the supported
This package uses a Makefile. Use `make help` to see the supported
options.

Use `make fix-style` to fix simple linting errors.
@@ -157,20 +157,17 @@ CHECK and a code coverage check. This runs
## First time dev setup:

run `make install` from the git repository to install the dev version of
caretEnsemble, along with the necessary package dependencies.

# Inspiration and similar packages:

caretEnsemble was inspired by [medley](https://github.com/mewo2/medley),
which in turn was inspired by Caruana et. al.’s (2004) paper [Ensemble
Selection from Libraries of
caretEnsemble, along with the necessary package dependencies. \#
Inspiration and similar packages: caretEnsemble was inspired by
[medley](https://github.com/mewo2/medley), which in turn was inspired by
Caruana et. al.’s (2004) paper [Ensemble Selection from Libraries of
Models.](http://www.cs.cornell.edu/~caruana/ctp/ct.papers/caruana.icml04.icdm06long.pdf)

If you want to do something similar in python, check out
[vecstack](https://github.com/vecxoz/vecstack)
[vecstack](https://github.com/vecxoz/vecstack).

# Code of Conduct:

Please note that this project is released with a [Contributor Code of
Conduct](https://github.com/zachmayer/caretEnsemble/blob/master/.github/CONTRIBUTING.md).
By participating in this project, you agree to abide by its terms.
By participating in this project you agree to abide by its terms.
2 changes: 1 addition & 1 deletion man/caretEnsemble.Rd

9 changes: 9 additions & 0 deletions man/defaultControl.Rd

Binary file modified man/figures/README-unnamed-chunk-3-1.png
Binary file modified man/figures/README-unnamed-chunk-4-1.png
9 changes: 4 additions & 5 deletions tests/testthat/test-permutationImportance.R
@@ -216,15 +216,14 @@ testthat::test_that("permutationImportance handles various edge cases", {
testthat::context("NAN predictions from rpart")
######################################################################


testthat::test_that("permutationImportance handles NAN predictions from rpart", {
set.seed(42)
set.seed(42L)
model_list <- caretEnsemble::caretList(
x = iris[, 1:4],
y = iris[, 5],
x = iris[, 1L:4L],
y = iris[, 5L],
methodList = "rpart"
)
ens <- caretEnsemble(model_list)
imp <- caret::varImp(ens)
testthat::expect_true(all(is.finite(imp)))
})
})
46 changes: 24 additions & 22 deletions vignettes/Version-4.0-New-Features.Rmd
@@ -17,17 +17,17 @@ knitr::opts_chunk$set(
warning = FALSE,
message = FALSE
)
set.seed(42L)
```

caretEnsemble 4.0.0 introduces many new features! Let's quickly go over them.

# Multiclass support
caretEnsemble now fully supports multiclass problems:
```{r}
set.seed(42)
model_list <- caretEnsemble::caretList(
x = iris[, 1:4],
y = iris[, 5],
x = iris[, 1L:4L],
y = iris[, 5L],
methodList = c("rpart", "rf")
)
```
@@ -62,21 +62,20 @@ print(ls(reg_control))
# Mixed Resampling Strategies
Models with different resampling strategies can now be ensembled:
```{r}
set.seed(42)
y <- iris[, 1]
x <- iris[, 2:3]
y <- iris[, 1L]
x <- iris[, 2L:3L]
flex_list <- caretEnsemble::caretList(
x = x,
y = y,
methodList = c("rpart", "rf"),
trControl = caretEnsemble::defaultControl(y, number = 3)
trControl = caretEnsemble::defaultControl(y, number = 3L)
)
flex_list$glm_boot <- caret::train(
x = x,
y = y,
method = "glm",
trControl = caretEnsemble::defaultControl(y, method = "boot", number=50)
trControl = caretEnsemble::defaultControl(y, method = "boot", number = 25L)
)
flex_ens <- caretEnsemble::caretEnsemble(flex_list)
@@ -86,35 +85,38 @@ print(flex_ens)
# Mixed Model Types
caretEnsemble now allows ensembling of mixed lists of classification and regression models:
```{r}
set.seed(42)
X <- iris[,1:4]
target_class <- iris[, 5]
target_reg <- as.integer(iris[, 5] == 'virginica')
model_class <- caret::train(iris[, 1:4], target_class, method = "rf", trControl = caretEnsemble::defaultControl(target_class))
model_reg <- caret::train(iris[, 1:4], target_reg, method = "rf", trControl = caretEnsemble::defaultControl(target_reg))
mixed_list <- caretEnsemble::as.caretList(list(class=model_class, reg=model_reg))
X <- iris[, 1L:4L]
target_class <- iris[, 5L]
target_reg <- as.integer(iris[, 5L] == "virginica")
ctrl_class <- caretEnsemble::defaultControl(target_class)
ctrl_reg <- caretEnsemble::defaultControl(target_reg)
model_class <- caret::train(iris[, 1L:4L], target_class, method = "rf", trControl = ctrl_class)
model_reg <- caret::train(iris[, 1L:4L], target_reg, method = "rf", trControl = ctrl_reg)
mixed_list <- caretEnsemble::as.caretList(list(class = model_class, reg = model_reg))
mixed_ens <- caretEnsemble::caretEnsemble(mixed_list)
print(mixed_ens)
```

# Transfer Learning
caretStack now supports transfer learning for ensembling models trained on different datasets:
```{r}
set.seed(42)
train_idx <- sample(1:nrow(iris), 100)
train_idx <- sample.int(nrow(iris), 100L)
train_data <- iris[train_idx, ]
new_data <- iris[-train_idx, ]
model_list <- caretEnsemble::caretList(
x = train_data[, 1:4],
y = train_data[, 5],
x = train_data[, 1L:4L],
y = train_data[, 5L],
methodList = c("rpart", "rf")
)
transfer_ens <- caretEnsemble::caretEnsemble(
model_list,
new_X = new_data[, 1:4],
new_y = new_data[, 5]
new_X = new_data[, 1L:4L],
new_y = new_data[, 5L]
)
print(transfer_ens)
@@ -133,4 +135,4 @@ importance <- caret::varImp(transfer_ens)
print(importance)
```
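The transfer-learning chunk in this file swaps `sample(1:nrow(iris), 100)` for `sample.int(nrow(iris), 100L)`. The base-R sketch below shows that the two calls draw the same indexes for a given seed, and why the `1:n` idiom is fragile. The row count is hard-coded to 150, the size of `iris`, so the snippet does not depend on any dataset.

```r
n <- 150L  # number of rows in iris

# Old form from the diff: sample from an explicit 1:n vector.
set.seed(42L)
idx_old <- sample(1:n, 100)

# New form from the diff: sample.int() draws from 1..n directly.
set.seed(42L)
idx_new <- sample.int(n, 100L)

# Same seed, same draws.
same_draws <- identical(idx_old, idx_new)
same_draws

# The 1:n idiom misbehaves for empty inputs: 1:0 counts down to c(1, 0).
n_empty <- 0L
length(1:n_empty)
length(seq_len(n_empty))
```

Because the results match, the change is stylistic here. `sample.int` simply avoids materialising `1:n` and sidesteps the empty-input pitfall.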

This completes our demonstration of the key new features in caretEnsemble 4.0. These enhancements provide greater flexibility, improved performance, and easier usage for ensemble modeling in R.
This completes our demonstration of the key new features in caretEnsemble 4.0. These enhancements provide greater flexibility, improved performance, and easier usage for ensemble modeling in R.
