
[R-package] multiclass predictions: set reshape = TRUE by default? #6131

Closed
mayer79 opened this issue Oct 8, 2023 · 6 comments

@mayer79
Contributor

mayer79 commented Oct 8, 2023

Similar to dmlc/xgboost#9640

I repeatedly stumble over the result of predict() with a multiclass model in R. It returns a vector of length $n \cdot m$, where $n$ is the number of observations and $m$ is the number of classes. Setting reshape = TRUE gives the much more intuitive $n \times m$ matrix. Could we change the default to reshape = TRUE? I'd be happy to open a PR for this.

Example

library(lightgbm)

params <- list(objective = "multiclass", num_class = 3, learning_rate = 0.2)
X_pred <- data.matrix(iris[, -5])
dtrain <- lgb.Dataset(X_pred, label = as.integer(iris[, 5]) - 1)

fit <- lgb.train(params = params, data = dtrain, nrounds = 100)

predict(fit, head(X_pred, 2))
# Gives an unintuitive vector:  9.999995e-01 4.673014e-07 2.194433e-10 9.999990e-01 9.719635e-07 6.442607e-09

predict(fit, head(X_pred, 2), reshape = TRUE)
# Gives a beautiful matrix:
#           [,1]         [,2]         [,3]
# [1,] 0.9999995 4.673014e-07 2.194433e-10
# [2,] 0.9999990 9.719635e-07 6.442607e-09
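As a workaround on released versions, the flat vector can be reshaped by hand in base R. The sketch below assumes LightGBM's ordering (all class probabilities for observation 1 first, then observation 2, and so on) and reuses the values printed above:

```r
# Flat prediction vector: classes 1..3 for observation 1, then observation 2
p <- c(9.999995e-01, 4.673014e-07, 2.194433e-10,
       9.999990e-01, 9.719635e-07, 6.442607e-09)
num_class <- 3

# byrow = TRUE fills row by row, so each observation becomes one row,
# matching the reshape = TRUE matrix shown above
pm <- matrix(p, ncol = num_class, byrow = TRUE)
dim(pm)  # [1] 2 3
```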
@jameslamb
Collaborator

Thanks as always for your interest in LightGBM.

I don't disagree that an (n_samples, n_classes) matrix is a more "intuitive" response from predict() for a multiclass model.

But in my opinion, that by itself isn't worth potentially breaking users' and reverse dependencies' code by changing this default.

I'd find it more convincing if {lightgbm} and {xgboost} were unique in their handling of multiclass predictions like this and if other popular R libraries produce an (n_samples, n_classes) matrix by default.

Could you please investigate and summarize here the behavior of all the libraries listed at #4295 (comment) and any others that you think are similarly popular? That'd be a stronger case for potentially changing this default.

@mayer79
Contributor Author

mayer79 commented Oct 8, 2023

Thanks, James, for the hint and for the post by @david-cortes.

I'd say that LightGBM and XGBoost are fairly unusual in this respect. I have added examples for the three major meta-learners and some other packages. I have not tested {gbm} because it discourages multiclass classification.

Meta-learner: caret -> matrix

library(caret)

set.seed(1)

fit <- train(
  Species ~ ., 
  data = iris, 
  method = "ranger",
  trControl = trainControl(classProbs = TRUE)
)
predict(fit, head(iris, 3), type = "prob")
#   setosa versicolor virginica
# 1      1          0         0
# 2      1          0         0
# 3      1          0         0

Meta-learner: {mlr3} -> matrix

library(mlr3)
library(mlr3learners)

set.seed(1)

task_iris <- TaskClassif$new(id = "class", backend = iris, target = "Species")
fit_rf <- lrn("classif.ranger", predict_type = "prob")
fit_rf$train(task_iris)
predict(fit_rf, head(iris, 3), predict_type = "prob")
#      setosa versicolor virginica
# [1,] 1.0000      0e+00         0
# [2,] 0.9995      5e-04         0
# [3,] 1.0000      0e+00         0

Meta-learner: {tidymodels} -> tibble

library(tidymodels)

set.seed(1)

iris_recipe <- iris %>%
  recipe(Species ~ .)

reg <- rand_forest() %>%
  set_mode("classification")

iris_wf <- workflow() %>%
  add_recipe(iris_recipe) %>%
  add_model(reg)

fit <- iris_wf %>%
  fit(iris)

predict(fit, head(iris, 3), type = "prob")
# # A tibble: 3 × 3
#   .pred_setosa .pred_versicolor .pred_virginica
#          <dbl>            <dbl>           <dbl>
# 1         1            0                      0
# 2         1.00         0.000222               0
# 3         1            0                      0

{ranger} -> list with matrix attached

fit <- ranger::ranger(Species ~ ., data = iris, probability = TRUE)
predict(fit, head(iris, 3))$predictions  # matrix

#      setosa versicolor virginica
# [1,] 1.0000     0.0000     0e+00
# [2,] 0.9984     0.0012     4e-04
# [3,] 1.0000     0.0000     0e+00

{randomForest} -> matrix

fit <- randomForest::randomForest(Species ~ ., data = iris)
predict(fit, head(iris, 3), type = "prob")

#   setosa versicolor virginica
# 1      1          0         0
# 2      1          0         0
# 3      1          0         0
# attr(,"class")
# [1] "matrix" "array"  "votes" 

{glmnet} -> 3D array (one slice per penalty)

fit <- glmnet::glmnet(x = data.matrix(iris[1:4]), y = iris[, 5], family = "multinomial")
predict(fit, data.matrix(iris[1:3, 1:4]), type = "response", s = 0)

# , , 1
#
#      setosa   versicolor    virginica
# 1 0.9999824 1.758585e-05 9.562213e-31
# 2 0.9998478 1.521641e-04 3.072000e-28
# 3 0.9999785 2.150941e-05 8.188653e-30
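For what it's worth, the third dimension of the glmnet array indexes the penalty values passed via s, so with a single penalty the result can be collapsed to a plain n × m matrix with base R's drop(). A small stand-in example (not an actual glmnet call):

```r
# Stand-in for a glmnet multinomial prediction array:
# 3 observations x 3 classes x 1 penalty value
a <- array(runif(9), dim = c(3, 3, 1))

# drop() removes dimensions of extent 1, leaving an n x m matrix
m <- drop(a)
dim(m)  # [1] 3 3
```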

@david-cortes
Contributor

@mayer79 There is no reshape parameter in the latest version of lightgbm.

@jameslamb
Collaborator

there is no reshape parameter in the latest version of lightgbm

Ah right! Forgot about #4971, and that we've already been over all of this 🙈 : #4971 (review)

Sorry for making you put together that list of how other packages do this @mayer79.

On the development version of {lightgbm}...

sh build-cran-package.sh --no-build-vignettes
R CMD INSTALL ./lightgbm_4.1.0.99.tar.gz

... predict() already does what you're asking for on multiclass models.

library(lightgbm)

X_pred <- data.matrix(iris[, -5])
dtrain <- lgb.Dataset(X_pred, label = as.integer(iris[, 5]) - 1)

bst <- lgb.train(
    params = list(
        objective = "multiclass"
        , num_class = 3
        , min_data_in_bin = 1
        , min_data_in_leaf = 1
    )
    , data = dtrain
    , nrounds = 10
)

predict(bst, head(X_pred, 2))
#           [,1]     [,2]     [,3]
# [1,] 0.7995459 0.100227 0.100227
# [2,] 0.7995459 0.100227 0.100227

So you'll get the default behavior you're asking for if/when we're able to get {lightgbm} v4.x to CRAN. You can subscribe to #5987 to follow the progress on that.

@mayer79
Contributor Author

mayer79 commented Oct 9, 2023

Mr. @david-cortes and Mr. @jameslamb are fixing issues faster than the speed of light... thanks for implementing this and for the patience :-).


github-actions bot commented Oct 9, 2024

This issue has been automatically locked since there has not been any recent activity since it was closed. To start a new related discussion, open a new issue at https://github.com/microsoft/LightGBM/issues including a reference to this.

@github-actions github-actions bot locked as resolved and limited conversation to collaborators Oct 9, 2024