Skip to content

Commit

Permalink
rebuild
Browse files Browse the repository at this point in the history
  • Loading branch information
zachmayer committed Aug 13, 2024
1 parent bfaa703 commit 5241cc4
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 30 deletions.
18 changes: 9 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,21 +56,21 @@ print(greedy_stack)
#>
#> No pre-processing
#> Resampling: Cross-Validated (5 fold)
#> Summary of sample sizes: 399, 401, 400, 400, 400
#> Summary of sample sizes: 400, 400, 400, 400, 400
#> Resampling results:
#>
#> RMSE Rsquared MAE
#> 925.6101 0.9454321 513.4172
#> 926.6459 0.9467796 558.8281
#>
#> Tuning parameter 'max_iter' was held constant at a value of 100
#>
#> Final model:
#> Greedy MSE
#> RMSE: 927.4321
#> RMSE: 944.4731
#> Weights:
#> [,1]
#> rf 0.76
#> glmnet 0.24
#> rf 0.58
#> glmnet 0.42
ggplot2::autoplot(greedy_stack, training_data = dat, xvars = c("carat", "table"))
```

Expand All @@ -89,11 +89,11 @@ print(rf_stack)
#>
#> No pre-processing
#> Resampling: Cross-Validated (5 fold)
#> Summary of sample sizes: 399, 400, 401, 400, 400
#> Summary of sample sizes: 400, 400, 400, 400, 400
#> Resampling results:
#>
#> RMSE Rsquared MAE
#> 935.9032 0.9469527 525.7192
#> 1039.019 0.9269161 558.9834
#>
#> Tuning parameter 'mtry' was held constant at a value of 2
#>
Expand All @@ -105,8 +105,8 @@ print(rf_stack)
#> Number of trees: 500
#> No. of variables tried at each split: 2
#>
#> Mean of squared residuals: 919446.8
#> % Var explained: 94.1
#> Mean of squared residuals: 1031204
#> % Var explained: 93.15
ggplot2::autoplot(rf_stack, training_data = dat, xvars = c("carat", "table"))
```

Expand Down
Binary file modified man/figures/README-unnamed-chunk-3-1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified man/figures/README-unnamed-chunk-4-1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion vignettes/Version-4.0-New-Features.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ author: "Zach Deane-Mayer"
date: "`r Sys.Date()`"
output: rmarkdown::html_vignette
vignette: >
%\VignetteIndexEntry{Version-4.0-New-Features}
%\VignetteIndexEntry{Version 4.0 New Features}
%\VignetteEngine{knitr::rmarkdown}
%\VignetteEncoding{UTF-8}
---
Expand Down
44 changes: 24 additions & 20 deletions vignettes/caretEnsemble-intro.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,19 @@ output: rmarkdown::html_vignette
vignette: >
%\VignetteIndexEntry{A Brief Introduction to caretEnsemble}
%\VignetteEngine{knitr::rmarkdown}
%\usepackage[utf8]{inputenc}
%\VignetteEncoding{UTF-8}
---

```{r, include = FALSE}
knitr::opts_chunk$set(
collapse = TRUE,
comment = "#>",
echo = TRUE,
warning = FALSE,
message = FALSE
)
```

`caretEnsemble` is a package for making ensembles of caret models. You should already be somewhat familiar with the caret package before trying out `caretEnsemble`.

`caretEnsemble` has 3 primary functions: **`caretList`**, **`caretEnsemble`** and **`caretStack`**. `caretList` is used to build lists of caret models on the same training data, with the same re-sampling parameters. `caretEnsemble` and `caretStack` are used to create ensemble models from such lists of caret models. `caretEnsemble` uses a glm to create a simple linear blend of models and `caretStack` uses a caret model to combine the outputs from several component caret models.
Expand All @@ -18,8 +28,7 @@ vignette: >
`caretList` is a flexible function for fitting many different caret models, with the same resampling parameters, to the same dataset. It returns a convenient `list` of caret objects which can later be passed to `caretEnsemble` and `caretStack`. `caretList` has almost exactly the same arguments as `train` (from the caret package), with the exception that the `trControl` argument comes last. It can handle both the formula interface and the explicit `x`, `y` interface to train. As in caret, the formula interface introduces some overhead and the `x`, `y` interface is preferred.

`caretEnsemble` has 2 arguments that can be used to specify which models to fit: `methodList` and `tuneList`. `methodList` is a simple character vector of methods that will be fit with the default `train` parameters, while `tuneList` can be used to customize the call to each component model and will be discussed in more detail later. First, lets build an example dataset (adapted from the caret vignette):
```{r, echo=TRUE, results="hide"}
# Adapted from the caret vignette
```{r, results="hide"}
data(Sonar, package = "mlbench")
set.seed(107L)
inTrain <- caret::createDataPartition(y = Sonar$Class, p = 0.75, list = FALSE)
Expand All @@ -33,19 +42,17 @@ model_list <- caretEnsemble::caretList(
data = training,
methodList = c("glmnet", "rpart")
)
summary(model_list)
```
(As with `train`, the formula interface is convenient but introduces move overhead. For large datasets the explicitly passing `x` and `y` is preferred).
We can use the `predict` function to extract predictions from this object for new data:
```{r, echo=TRUE, results="hide"}
```{r}
p <- predict(model_list, newdata = head(testing))
print(p)
```
```{r, echo=FALSE, results="asis"}
knitr::kable(p)
```

If you desire more control over the model fit, use the `caretModelSpec` to construct a list of model specifications for the `tuneList` argument. This argument can be used to fit several different variants of the same model, and can also be used to pass arguments through `train` down to the component functions (e.g. `trace=FALSE` for `nnet`):
```{r, echo=TRUE, results="hide", warning=FALSE}
```{r}
model_list_big <- caretEnsemble::caretList(
Class ~ .,
data = training,
Expand All @@ -56,30 +63,31 @@ model_list_big <- caretEnsemble::caretList(
nn = caretEnsemble::caretModelSpec(method = "nnet", tuneLength = 2L, trace = FALSE)
)
)
summary(model_list_big)
```

Finally, you should note that `caretList` does not support custom caret models. Fitting those models are beyond the scope of this vignette, but if you do so, you can manually add them to the model list (e.g. `model_list_big[["my_custom_model"]] <- my_custom_model`). Just be sure to use the same re-sampling indexes in `trControl` as you use in the `caretList` models!

## caretEnsemble
`caretList` is the preferred way to construct list of caret models in this package, as it will ensure the resampling indexes are identical across all models. Lets take a closer look at our list of models:
```{r, echo=TRUE, fig.show="hold"}
```{r}
lattice::xyplot(caret::resamples(model_list))
```

As you can see from this plot, these 2 models are uncorrelated, and the rpart model is occasionally anti-predictive, with a one re-sample showing AUC of 0.46.

We can confirm the 2 model"s correlation with the `modelCor` function from caret (caret has a lot of convenient functions for analyzing lists of models):
```{r, echo=TRUE}
```{r}
caret::modelCor(caret::resamples(model_list))
```

These 2 models make a good candidate for an ensemble: their predictions are fairly uncorrelated, but their overall accuracy is similar. We do a simple, linear greedy optimization on AUC using caretEnsemble:
```{r, echo=TRUE}
```{r}
greedy_ensemble <- caretEnsemble::caretEnsemble(model_list)
summary(greedy_ensemble)
```

```{r, echo=TRUE}
```{r}
model_preds <- predict(model_list, newdata = testing, excluded_class_id = 2L)
ens_preds <- predict(greedy_ensemble, newdata = testing, excluded_class_id = 2L)
model_preds$ensemble <- ens_preds
Expand All @@ -90,21 +98,17 @@ print(auc)
The ensemble has an AUC on the training set resamples of `r round(auc[1, 'ensemble'], 2)` which is about `r round(auc[1, 'ensemble'] - max(auc[1, 'glmnet'], auc[1, 'rpart']), 3) * 100`% better than the best individual model.

Note that the levels for the Sonar Data are "M" and "R", where M is level 1 and R is level 2. "M" stands for "metal cylinder" and "R" stands for rock. M is the positive class, so we exclude class 2L from our predictions. You can set excluded_class_id = 0L
```{r, echo=TRUE, results="hide", warning=FALSE}
```{r}
predict(greedy_ensemble, newdata = head(testing), excluded_class_id = 0L)
```

```{r, echo=FALSE, results="asis"}
knitr::kable(predict(greedy_ensemble, newdata = head(testing), excluded_class_id = 0L))
```

We can also use varImp to extract the variable importances from each member of the ensemble, as well as the final ensemble model:
```{r, echo=TRUE, results="hidshowe"}
```{r}
round(caret::varImp(greedy_ensemble), 4L)
```

## caretStack
```{r, echo=TRUE}
```{r}
glm_ensemble <- caretEnsemble::caretStack(model_list, method = "glm")
model_preds2 <- model_preds
model_preds2$ensemble <- predict(glm_ensemble, newdata = testing, excluded_class_id = 2L)
Expand All @@ -116,7 +120,7 @@ print(CF / sum(CF))
Note that `glm_ensemble$ens_model` is a regular caret object of class `train`. The glm-weighted model weights (glm vs rpart) and test-set AUCs are extremely similar to the caretEnsemble greedy optimization.

We can also use more sophisticated ensembles than simple linear weights, but these models are much more susceptible to over-fitting, and generally require large sets of resamples to train on (n=50 or higher for bootstrap samples). Lets try one anyways:
```{r, echo=TRUE}
```{r}
gbm_ensemble <- caretEnsemble::caretStack(
model_list,
method = "gbm",
Expand Down

0 comments on commit 5241cc4

Please sign in to comment.