From 4cc29b969a1b28b4d4123a52b27675883cdf8480 Mon Sep 17 00:00:00 2001 From: David Cortes Date: Thu, 7 Apr 2022 21:26:28 +0200 Subject: [PATCH 01/14] switch to single prediction type argument --- R-package/NAMESPACE | 1 + R-package/R/lgb.Booster.R | 75 +++++++++++++------- R-package/demo/boost_from_prediction.R | 4 +- R-package/demo/leaf_stability.R | 6 +- R-package/demo/multiclass.R | 4 +- R-package/demo/multiclass_custom_objective.R | 4 +- R-package/man/predict.lgb.Booster.Rd | 39 ++++++---- R-package/tests/testthat/test_Predictor.R | 28 ++++---- R-package/tests/testthat/test_basic.R | 4 +- R-package/tests/testthat/test_lgb.Booster.R | 8 +-- 10 files changed, 105 insertions(+), 68 deletions(-) diff --git a/R-package/NAMESPACE b/R-package/NAMESPACE index 33152e33bc1d..790fc89c1a5d 100644 --- a/R-package/NAMESPACE +++ b/R-package/NAMESPACE @@ -54,6 +54,7 @@ importFrom(jsonlite,fromJSON) importFrom(methods,is) importFrom(parallel,detectCores) importFrom(stats,quantile) +importFrom(utils,head) importFrom(utils,modifyList) importFrom(utils,read.delim) useDynLib(lib_lightgbm , .registration = TRUE) diff --git a/R-package/R/lgb.Booster.R b/R-package/R/lgb.Booster.R index ac9f2404b606..ffb7cd56f810 100644 --- a/R-package/R/lgb.Booster.R +++ b/R-package/R/lgb.Booster.R @@ -713,6 +713,22 @@ Booster <- R6::R6Class( #' @param object Object of class \code{lgb.Booster} #' @param newdata a \code{matrix} object, a \code{dgCMatrix} object or #' a character representing a path to a text file (CSV, TSV, or LibSVM) +#' @param type Type of prediction to output. Allowed types are:\itemize{ +#' \item \code{"link"}: will output the predicted score according to the objective function being +#' optimized (depending on the link function that the objective uses), after applying any necessary +#' transformations - for example, for \code{objective="binary"}, it will output class probabilities. +#' \item \code{"response"}: for classification objectives, will output the class with the highest predicted +#' probability. For other objectives, will output the same as "link". +#' \item \code{"raw"}: will output the non-transformed numbers (sum of predictions from boosting iterations' +#' results) from which the "link" number is produced for a given objective function - for example, for +#' \code{objective="binary"}, this corresponds to log-odds. For many objectives such as "regression", +#' since no transformation is applied, the output will be the same as for "link". +#' \item \code{"leaf"}: will output the index of the terminal node / leaf at which each observations falls in +#' each tree in the model, outputted as as integers, with one column per tree. +#' \item \code{"contrib"}: will return the per-feature contributions for each prediction, including an intercept +#' (each feature will produce one column). If there are multiple classes, each class will have separate +#' feature contributions (thus the number of columns is feaures+1 multiplied by the number of classes). +#' } #' @param start_iteration int or None, optional (default=None) #' Start index of the iteration to predict. #' If None or <= 0, starts from the first iteration. @@ -721,22 +737,19 @@ Booster <- R6::R6Class( #' If None, if the best iteration exists and start_iteration is None or <= 0, the #' best iteration is used; otherwise, all iterations from start_iteration are used. #' If <= 0, all iterations from start_iteration are used (no limits). -#' @param rawscore whether the prediction should be returned in the for of original untransformed -#' sum of predictions from boosting iterations' results. E.g., setting \code{rawscore=TRUE} -#' for logistic regression would result in predictions for log-odds instead of probabilities. -#' @param predleaf whether predict leaf index instead. -#' @param predcontrib return per-feature contributions for each record. #' @param header only used for prediction for text file. True if text file has header #' @param params a list of additional named parameters. See #' \href{https://lightgbm.readthedocs.io/en/latest/Parameters.html#predict-parameters}{ #' the "Predict Parameters" section of the documentation} for a list of parameters and #' valid values. #' @param ... ignored -#' @return For regression or binary classification, it returns a vector of length \code{nrows(data)}. -#' For multiclass classification, it returns a matrix of dimensions \code{(nrows(data), num_class)}. +#' @return For prediction types that are meant to always return one output per observation (e.g. when predicting +#' \code{type="link"} on a binary classification or regression objective), will return a vector with one +#' row per observation in \code{newdata}. #' -#' When passing \code{predleaf=TRUE} or \code{predcontrib=TRUE}, the output will always be -#' returned as a matrix. +#' For prediction types that are meant to return more than one output per observation (e.g. when predicting +#' \code{type="link"} on a multi-class objective, or when predicting \code{type="leaf"}, regardless of +#' objective), will return a matrix with one row per observation in \code{newdata} and one column per output. #' #' @examples #' \donttest{ @@ -770,15 +783,13 @@ Booster <- R6::R6Class( #' ) #' ) #' } -#' @importFrom utils modifyList +#' @importFrom utils modifyList head #' @export predict.lgb.Booster <- function(object, newdata, + type = c("link", "response", "raw", "leaf", "contrib"), start_iteration = NULL, num_iteration = NULL, - rawscore = FALSE, - predleaf = FALSE, - predcontrib = FALSE, header = FALSE, params = list(), ...) { @@ -799,18 +810,34 @@ predict.lgb.Booster <- function(object, )) } - return( - object$predict( - data = newdata - , start_iteration = start_iteration - , num_iteration = num_iteration - , rawscore = rawscore - , predleaf = predleaf - , predcontrib = predcontrib - , header = header - , params = params - ) + type <- head(type, 1L) + rawscore <- FALSE + predleaf <- FALSE + predcontrib <- FALSE + if (type == "leaf") { + predleaf <- TRUE + } else if (type == "contrib") { + predcontrib <- TRUE + } + + pred <- object$predict( + data = newdata + , start_iteration = start_iteration + , num_iteration = num_iteration + , rawscore = rawscore + , predleaf = predleaf + , predcontrib = predcontrib + , header = header + , params = params ) + if (type == "response") { + if (object$params$objective == "binary") { + pred <- as.integer(pred >= 0.5) + } else if (object$params$objective %in% c("multiclass", "multiclassova")) { + pred <- max.col(pred) - 1L + } + } + return(pred) } #' @name print.lgb.Booster diff --git a/R-package/demo/boost_from_prediction.R b/R-package/demo/boost_from_prediction.R index b6b3f1ceba7b..5ee04f2d756e 100644 --- a/R-package/demo/boost_from_prediction.R +++ b/R-package/demo/boost_from_prediction.R @@ -22,8 +22,8 @@ param <- list( bst <- lgb.train(param, dtrain, 1L, valids = valids) # Note: we need the margin value instead of transformed prediction in set_init_score -ptrain <- predict(bst, agaricus.train$data, rawscore = TRUE) -ptest <- predict(bst, agaricus.test$data, rawscore = TRUE) +ptrain <- predict(bst, agaricus.train$data, type = "raw") +ptest <- predict(bst, agaricus.test$data, type = "raw") # set the init_score property of dtrain and dtest # base margin is the base prediction we will boost from diff --git a/R-package/demo/leaf_stability.R b/R-package/demo/leaf_stability.R index af1c533ac5b1..0733f31c3f87 100644 --- a/R-package/demo/leaf_stability.R +++ b/R-package/demo/leaf_stability.R @@ -111,7 +111,7 @@ new_data <- data.frame( X = rowMeans(predict( model , agaricus.test$data - , predleaf = TRUE + , type = "leaf" )) , Y = pmin( pmax( @@ -162,7 +162,7 @@ new_data2 <- data.frame( X = rowMeans(predict( model2 , agaricus.test$data - , predleaf = TRUE + , type = "leaf" )) , Y = pmin( pmax( @@ -218,7 +218,7 @@ new_data3 <- data.frame( X = rowMeans(predict( model3 , agaricus.test$data - , predleaf = TRUE + , type = "leaf" )) , Y = pmin( pmax( diff --git a/R-package/demo/multiclass.R b/R-package/demo/multiclass.R index afc7a4086b98..35441ccec983 100644 --- a/R-package/demo/multiclass.R +++ b/R-package/demo/multiclass.R @@ -64,7 +64,7 @@ my_preds <- predict(model, test[, 1L:4L]) my_preds <- predict(model, test[, 1L:4L]) # We can also get the predicted scores before the Sigmoid/Softmax application -my_preds <- predict(model, test[, 1L:4L], rawscore = TRUE) +my_preds <- predict(model, test[, 1L:4L], type = "raw") # We can also get the leaf index -my_preds <- predict(model, test[, 1L:4L], predleaf = TRUE) +my_preds <- predict(model, test[, 1L:4L], type = "leaf") diff --git a/R-package/demo/multiclass_custom_objective.R b/R-package/demo/multiclass_custom_objective.R index ebc3e2bbdeb2..09bdd322179c 100644 --- a/R-package/demo/multiclass_custom_objective.R +++ b/R-package/demo/multiclass_custom_objective.R @@ -36,7 +36,7 @@ model_builtin <- lgb.train( , obj = "multiclass" ) -preds_builtin <- predict(model_builtin, test[, 1L:4L], rawscore = TRUE) +preds_builtin <- predict(model_builtin, test[, 1L:4L], type = "raw") probs_builtin <- exp(preds_builtin) / rowSums(exp(preds_builtin)) # Method 2 of training with custom objective function @@ -109,7 +109,7 @@ model_custom <- lgb.train( , eval = custom_multiclass_metric ) -preds_custom <- predict(model_custom, test[, 1L:4L], rawscore = TRUE) +preds_custom <- predict(model_custom, test[, 1L:4L], type = "raw") probs_custom <- exp(preds_custom) / rowSums(exp(preds_custom)) # compare predictions diff --git a/R-package/man/predict.lgb.Booster.Rd b/R-package/man/predict.lgb.Booster.Rd index d4ddfe0ff668..574cfab6496e 100644 --- a/R-package/man/predict.lgb.Booster.Rd +++ b/R-package/man/predict.lgb.Booster.Rd @@ -7,11 +7,9 @@ \method{predict}{lgb.Booster}( object, newdata, + type = c("link", "response", "raw", "leaf", "contrib"), start_iteration = NULL, num_iteration = NULL, - rawscore = FALSE, - predleaf = FALSE, - predcontrib = FALSE, header = FALSE, params = list(), ... @@ -23,6 +21,23 @@ \item{newdata}{a \code{matrix} object, a \code{dgCMatrix} object or a character representing a path to a text file (CSV, TSV, or LibSVM)} +\item{type}{Type of prediction to output. Allowed types are:\itemize{ +\item \code{"link"}: will output the predicted score according to the objective function being + optimized (depending on the link function that the objective uses), after applying any necessary + transformations - for example, for \code{objective="binary"}, it will output class probabilities. +\item \code{"response"}: for classification objectives, will output the class with the highest predicted + probability. For other objectives, will output the same as "link". +\item \code{"raw"}: will output the non-transformed numbers (sum of predictions from boosting iterations' + results) from which the "link" number is produced for a given objective function - for example, for + \code{objective="binary"}, this corresponds to log-odds. For many objectives such as "regression", + since no transformation is applied, the output will be the same as for "link". +\item \code{"leaf"}: will output the index of the terminal node / leaf at which each observations falls in + each tree in the model, outputted as as integers, with one column per tree. +\item \code{"contrib"}: will return the per-feature contributions for each prediction, including an intercept + (each feature will produce one column). If there are multiple classes, each class will have separate + feature contributions (thus the number of columns is feaures+1 multiplied by the number of classes). +}} + \item{start_iteration}{int or None, optional (default=None) Start index of the iteration to predict. If None or <= 0, starts from the first iteration.} @@ -33,14 +48,6 @@ If None, if the best iteration exists and start_iteration is None or <= 0, the best iteration is used; otherwise, all iterations from start_iteration are used. If <= 0, all iterations from start_iteration are used (no limits).} -\item{rawscore}{whether the prediction should be returned in the for of original untransformed -sum of predictions from boosting iterations' results. E.g., setting \code{rawscore=TRUE} -for logistic regression would result in predictions for log-odds instead of probabilities.} - -\item{predleaf}{whether predict leaf index instead.} - -\item{predcontrib}{return per-feature contributions for each record.} - \item{header}{only used for prediction for text file. True if text file has header} \item{params}{a list of additional named parameters. See @@ -51,11 +58,13 @@ valid values.} \item{...}{ignored} } \value{ -For regression or binary classification, it returns a vector of length \code{nrows(data)}. - For multiclass classification, it returns a matrix of dimensions \code{(nrows(data), num_class)}. +For prediction types that are meant to always return one output per observation (e.g. when predicting + \code{type="link"} on a binary classification or regression objective), will return a vector with one + row per observation in \code{newdata}. - When passing \code{predleaf=TRUE} or \code{predcontrib=TRUE}, the output will always be - returned as a matrix. + For prediction types that are meant to return more than one output per observation (e.g. when predicting + \code{type="link"} on a multi-class objective, or when predicting \code{type="leaf"}, regardless of + objective), will return a matrix with one row per observation in \code{newdata} and one column per output. } \description{ Predicted values based on class \code{lgb.Booster} diff --git a/R-package/tests/testthat/test_Predictor.R b/R-package/tests/testthat/test_Predictor.R index 19360daa688f..3213563c9c2a 100644 --- a/R-package/tests/testthat/test_Predictor.R +++ b/R-package/tests/testthat/test_Predictor.R @@ -81,8 +81,8 @@ test_that("start_iteration works correctly", { , early_stopping_rounds = 2L ) expect_true(lgb.is.Booster(bst)) - pred1 <- predict(bst, newdata = test$data, rawscore = TRUE) - pred_contrib1 <- predict(bst, test$data, predcontrib = TRUE) + pred1 <- predict(bst, newdata = test$data, type = "raw") + pred_contrib1 <- predict(bst, test$data, type = "contrib") pred2 <- rep(0.0, length(pred1)) pred_contrib2 <- rep(0.0, length(pred2)) step <- 11L @@ -96,7 +96,7 @@ test_that("start_iteration works correctly", { inc_pred <- predict(bst, test$data , start_iteration = start_iter , num_iteration = n_iter - , rawscore = TRUE + , type = "raw" ) inc_pred_contrib <- bst$predict(test$data , start_iteration = start_iter @@ -109,8 +109,8 @@ test_that("start_iteration works correctly", { expect_equal(pred2, pred1) expect_equal(pred_contrib2, pred_contrib1) - pred_leaf1 <- predict(bst, test$data, predleaf = TRUE) - pred_leaf2 <- predict(bst, test$data, start_iteration = 0L, num_iteration = end_iter + 1L, predleaf = TRUE) + pred_leaf1 <- predict(bst, test$data, type = "leaf") + pred_leaf2 <- predict(bst, test$data, start_iteration = 0L, num_iteration = end_iter + 1L, type = "leaf") expect_equal(pred_leaf1, pred_leaf2) }) @@ -139,11 +139,11 @@ test_that("start_iteration works correctly", { # dense matrix with row names pred <- predict(bst, X) .expect_has_row_names(pred, X) - pred <- predict(bst, X, rawscore = TRUE) + pred <- predict(bst, X, type = "raw") .expect_has_row_names(pred, X) - pred <- predict(bst, X, predleaf = TRUE) + pred <- predict(bst, X, type = "leaf") .expect_has_row_names(pred, X) - pred <- predict(bst, X, predcontrib = TRUE) + pred <- predict(bst, X, type = "contrib") .expect_has_row_names(pred, X) # dense matrix without row names @@ -156,11 +156,11 @@ test_that("start_iteration works correctly", { Xcsc <- as(X, "CsparseMatrix") pred <- predict(bst, Xcsc) .expect_has_row_names(pred, Xcsc) - pred <- predict(bst, Xcsc, rawscore = TRUE) + pred <- predict(bst, Xcsc, type = "raw") .expect_has_row_names(pred, Xcsc) - pred <- predict(bst, Xcsc, predleaf = TRUE) + pred <- predict(bst, Xcsc, type = "leaf") .expect_has_row_names(pred, Xcsc) - pred <- predict(bst, Xcsc, predcontrib = TRUE) + pred <- predict(bst, Xcsc, type = "contrib") .expect_has_row_names(pred, Xcsc) # sparse matrix without row names @@ -245,7 +245,7 @@ test_that("predictions for regression and binary classification are returned as pred <- predict(model, X) expect_true(is.vector(pred)) expect_equal(length(pred), nrow(X)) - pred <- predict(model, X, rawscore = TRUE) + pred <- predict(model, X, type = "raw") expect_true(is.vector(pred)) expect_equal(length(pred), nrow(X)) @@ -262,7 +262,7 @@ test_that("predictions for regression and binary classification are returned as pred <- predict(model, X) expect_true(is.vector(pred)) expect_equal(length(pred), nrow(X)) - pred <- predict(model, X, rawscore = TRUE) + pred <- predict(model, X, type = "raw") expect_true(is.vector(pred)) expect_equal(length(pred), nrow(X)) }) @@ -283,7 +283,7 @@ test_that("predictions for multiclass classification are returned as matrix", { expect_true(is.matrix(pred)) expect_equal(nrow(pred), nrow(X)) expect_equal(ncol(pred), 3L) - pred <- predict(model, X, rawscore = TRUE) + pred <- predict(model, X, type = "raw") expect_true(is.matrix(pred)) expect_equal(nrow(pred), nrow(X)) expect_equal(ncol(pred), 3L) diff --git a/R-package/tests/testthat/test_basic.R b/R-package/tests/testthat/test_basic.R index e59139c01a9e..a1313b0aaa51 100644 --- a/R-package/tests/testthat/test_basic.R +++ b/R-package/tests/testthat/test_basic.R @@ -2898,7 +2898,7 @@ test_that("lightgbm() accepts init_score as function argument", { , nrounds = 5L , verbose = -1L ) - pred1 <- predict(bst1, train$data, rawscore = TRUE) + pred1 <- predict(bst1, train$data, type = "raw") bst2 <- lightgbm( data = train$data @@ -2908,7 +2908,7 @@ test_that("lightgbm() accepts init_score as function argument", { , nrounds = 5L , verbose = -1L ) - pred2 <- predict(bst2, train$data, rawscore = TRUE) + pred2 <- predict(bst2, train$data, type = "raw") expect_true(any(pred1 != pred2)) }) diff --git a/R-package/tests/testthat/test_lgb.Booster.R b/R-package/tests/testthat/test_lgb.Booster.R index a5f609b682a5..98348efaddce 100644 --- a/R-package/tests/testthat/test_lgb.Booster.R +++ b/R-package/tests/testthat/test_lgb.Booster.R @@ -293,8 +293,8 @@ test_that("Saving a large model to string should work", { ) pred <- predict(bst, train$data) - pred_leaf_indx <- predict(bst, train$data, predleaf = TRUE) - pred_raw_score <- predict(bst, train$data, rawscore = TRUE) + pred_leaf_indx <- predict(bst, train$data, type = "leaf") + pred_raw_score <- predict(bst, train$data, type = "raw") model_string <- bst$save_model_to_string() # make sure this test is still producing a model bigger than the default @@ -312,8 +312,8 @@ test_that("Saving a large model to string should work", { model_str = model_string ) pred2 <- predict(bst2, train$data) - pred2_leaf_indx <- predict(bst2, train$data, predleaf = TRUE) - pred2_raw_score <- predict(bst2, train$data, rawscore = TRUE) + pred2_leaf_indx <- predict(bst2, train$data, type = "leaf") + pred2_raw_score <- predict(bst2, train$data, type = "raw") expect_identical(pred, pred2) expect_identical(pred_leaf_indx, pred2_leaf_indx) expect_identical(pred_raw_score, pred2_raw_score) From 7195b2e79e0da9a8bc2586cecfdd1cf2d77db81c Mon Sep 17 00:00:00 2001 From: David Cortes Date: Thu, 7 Apr 2022 21:44:16 +0200 Subject: [PATCH 02/14] linter --- R-package/R/lgb.Booster.R | 11 ++++++----- R-package/man/predict.lgb.Booster.Rd | 11 ++++++----- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/R-package/R/lgb.Booster.R b/R-package/R/lgb.Booster.R index ffb7cd56f810..a0da565c27bb 100644 --- a/R-package/R/lgb.Booster.R +++ b/R-package/R/lgb.Booster.R @@ -723,11 +723,12 @@ Booster <- R6::R6Class( #' results) from which the "link" number is produced for a given objective function - for example, for #' \code{objective="binary"}, this corresponds to log-odds. For many objectives such as "regression", #' since no transformation is applied, the output will be the same as for "link". -#' \item \code{"leaf"}: will output the index of the terminal node / leaf at which each observations falls in -#' each tree in the model, outputted as as integers, with one column per tree. -#' \item \code{"contrib"}: will return the per-feature contributions for each prediction, including an intercept -#' (each feature will produce one column). If there are multiple classes, each class will have separate -#' feature contributions (thus the number of columns is feaures+1 multiplied by the number of classes). +#' \item \code{"leaf"}: will output the index of the terminal node / leaf at which each observations falls +#' in each tree in the model, outputted as as integers, with one column per tree. +#' \item \code{"contrib"}: will return the per-feature contributions for each prediction, including an +#' intercept (each feature will produce one column). If there are multiple classes, each class will +#' have separate feature contributions (thus the number of columns is feaures+1 multiplied by the +#' number of classes). #' } #' @param start_iteration int or None, optional (default=None) #' Start index of the iteration to predict. diff --git a/R-package/man/predict.lgb.Booster.Rd b/R-package/man/predict.lgb.Booster.Rd index 574cfab6496e..1931603ec97e 100644 --- a/R-package/man/predict.lgb.Booster.Rd +++ b/R-package/man/predict.lgb.Booster.Rd @@ -31,11 +31,12 @@ a character representing a path to a text file (CSV, TSV, or LibSVM)} results) from which the "link" number is produced for a given objective function - for example, for \code{objective="binary"}, this corresponds to log-odds. For many objectives such as "regression", since no transformation is applied, the output will be the same as for "link". -\item \code{"leaf"}: will output the index of the terminal node / leaf at which each observations falls in - each tree in the model, outputted as as integers, with one column per tree. -\item \code{"contrib"}: will return the per-feature contributions for each prediction, including an intercept - (each feature will produce one column). If there are multiple classes, each class will have separate - feature contributions (thus the number of columns is feaures+1 multiplied by the number of classes). +\item \code{"leaf"}: will output the index of the terminal node / leaf at which each observations falls + in each tree in the model, outputted as as integers, with one column per tree. +\item \code{"contrib"}: will return the per-feature contributions for each prediction, including an + intercept (each feature will produce one column). If there are multiple classes, each class will + have separate feature contributions (thus the number of columns is feaures+1 multiplied by the + number of classes). }} \item{start_iteration}{int or None, optional (default=None) From 398f41a6155c62c7bbcbbe72774c8bbf1a5d936b Mon Sep 17 00:00:00 2001 From: David Cortes Date: Thu, 7 Apr 2022 21:54:17 +0200 Subject: [PATCH 03/14] missing piece of code --- R-package/R/lgb.Booster.R | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/R-package/R/lgb.Booster.R b/R-package/R/lgb.Booster.R index a0da565c27bb..138c81cb035b 100644 --- a/R-package/R/lgb.Booster.R +++ b/R-package/R/lgb.Booster.R @@ -815,7 +815,9 @@ predict.lgb.Booster <- function(object, rawscore <- FALSE predleaf <- FALSE predcontrib <- FALSE - if (type == "leaf") { + if (type == "raw") { + rawscore <- TRUE + } else if (type == "leaf") { predleaf <- TRUE } else if (type == "contrib") { predcontrib <- TRUE From ed3668633df9bc6f13e3c164bc749db28d7cf9ca Mon Sep 17 00:00:00 2001 From: David Cortes Date: Sun, 5 Jun 2022 14:20:19 +0200 Subject: [PATCH 04/14] comments --- R-package/DESCRIPTION | 2 +- R-package/NAMESPACE | 1 - R-package/R/lgb.Booster.R | 34 ++++++++++++++---- R-package/man/predict.lgb.Booster.Rd | 37 ++++++++++--------- R-package/tests/testthat/test_Predictor.R | 44 +++++++++++++++++++++++ 5 files changed, 93 insertions(+), 25 deletions(-) diff --git a/R-package/DESCRIPTION b/R-package/DESCRIPTION index 7efb865f33e8..eb4ca2404f63 100755 --- a/R-package/DESCRIPTION +++ b/R-package/DESCRIPTION @@ -66,4 +66,4 @@ Imports: utils SystemRequirements: C++11 -RoxygenNote: 7.1.2 +RoxygenNote: 7.2.0 diff --git a/R-package/NAMESPACE b/R-package/NAMESPACE index 790fc89c1a5d..33152e33bc1d 100644 --- a/R-package/NAMESPACE +++ b/R-package/NAMESPACE @@ -54,7 +54,6 @@ importFrom(jsonlite,fromJSON) importFrom(methods,is) importFrom(parallel,detectCores) importFrom(stats,quantile) -importFrom(utils,head) importFrom(utils,modifyList) importFrom(utils,read.delim) useDynLib(lib_lightgbm , .registration = TRUE) diff --git a/R-package/R/lgb.Booster.R b/R-package/R/lgb.Booster.R index 138c81cb035b..399979bdd117 100644 --- a/R-package/R/lgb.Booster.R +++ b/R-package/R/lgb.Booster.R @@ -724,12 +724,15 @@ Booster <- R6::R6Class( #' \code{objective="binary"}, this corresponds to log-odds. For many objectives such as "regression", #' since no transformation is applied, the output will be the same as for "link". #' \item \code{"leaf"}: will output the index of the terminal node / leaf at which each observations falls -#' in each tree in the model, outputted as as integers, with one column per tree. +#' in each tree in the model, outputted as integers, with one column per tree. #' \item \code{"contrib"}: will return the per-feature contributions for each prediction, including an #' intercept (each feature will produce one column). If there are multiple classes, each class will #' have separate feature contributions (thus the number of columns is feaures+1 multiplied by the #' number of classes). #' } +#' +#' Note that, if using custom objectives, types "link" and "response" will not be available and will +#' default towards using "raw" instead. #' @param start_iteration int or None, optional (default=None) #' Start index of the iteration to predict. #' If None or <= 0, starts from the first iteration. @@ -784,11 +787,11 @@ Booster <- R6::R6Class( #' ) #' ) #' } -#' @importFrom utils modifyList head +#' @importFrom utils modifyList #' @export predict.lgb.Booster <- function(object, newdata, - type = c("link", "response", "raw", "leaf", "contrib"), + type = "link", start_iteration = NULL, num_iteration = NULL, header = FALSE, @@ -801,17 +804,36 @@ predict.lgb.Booster <- function(object, additional_params <- list(...) if (length(additional_params) > 0L) { - if ("reshape" %in% names(additional_params)) { + additional_params_names <- names(additional_params) + if ("reshape" %in% additional_params_names) { stop("'reshape' argument is no longer supported.") } + + old_args_for_type <- list( + "rawscore" = "raw" + , "predleaf" = "leaf" + , "predcontrib" = "contrib" + ) + for (arg in names(old_args_for_type)) { + if (arg %in% additional_params_names) { + stop(sprintf("Argument '%s' is no longer supported. Use type='%s' instead." + , arg + , old_args_for_type[[arg]])) + } + } + warning(paste0( "predict.lgb.Booster: Found the following passed through '...': " - , paste(names(additional_params), collapse = ", ") + , paste(additional_params_names, collapse = ", ") , ". These are ignored. Use argument 'params' instead." )) } - type <- head(type, 1L) + if (object$params$objective == "none" && type %in% c("link", "response")) { + warning("Prediction types 'link' and 'response' are not supported for custom objectives.") + type <- "raw" + } + rawscore <- FALSE predleaf <- FALSE predcontrib <- FALSE diff --git a/R-package/man/predict.lgb.Booster.Rd b/R-package/man/predict.lgb.Booster.Rd index 1931603ec97e..78f96082d62c 100644 --- a/R-package/man/predict.lgb.Booster.Rd +++ b/R-package/man/predict.lgb.Booster.Rd @@ -7,7 +7,7 @@ \method{predict}{lgb.Booster}( object, newdata, - type = c("link", "response", "raw", "leaf", "contrib"), + type = "link", start_iteration = NULL, num_iteration = NULL, header = FALSE, @@ -22,22 +22,25 @@ a character representing a path to a text file (CSV, TSV, or LibSVM)} \item{type}{Type of prediction to output. Allowed types are:\itemize{ -\item \code{"link"}: will output the predicted score according to the objective function being - optimized (depending on the link function that the objective uses), after applying any necessary - transformations - for example, for \code{objective="binary"}, it will output class probabilities. -\item \code{"response"}: for classification objectives, will output the class with the highest predicted - probability. For other objectives, will output the same as "link". -\item \code{"raw"}: will output the non-transformed numbers (sum of predictions from boosting iterations' - results) from which the "link" number is produced for a given objective function - for example, for - \code{objective="binary"}, this corresponds to log-odds. For many objectives such as "regression", - since no transformation is applied, the output will be the same as for "link". -\item \code{"leaf"}: will output the index of the terminal node / leaf at which each observations falls - in each tree in the model, outputted as as integers, with one column per tree. -\item \code{"contrib"}: will return the per-feature contributions for each prediction, including an - intercept (each feature will produce one column). If there are multiple classes, each class will - have separate feature contributions (thus the number of columns is feaures+1 multiplied by the - number of classes). -}} + \item \code{"link"}: will output the predicted score according to the objective function being + optimized (depending on the link function that the objective uses), after applying any necessary + transformations - for example, for \code{objective="binary"}, it will output class probabilities. + \item \code{"response"}: for classification objectives, will output the class with the highest predicted + probability. For other objectives, will output the same as "link". + \item \code{"raw"}: will output the non-transformed numbers (sum of predictions from boosting iterations' + results) from which the "link" number is produced for a given objective function - for example, for + \code{objective="binary"}, this corresponds to log-odds. For many objectives such as "regression", + since no transformation is applied, the output will be the same as for "link". + \item \code{"leaf"}: will output the index of the terminal node / leaf at which each observations falls + in each tree in the model, outputted as integers, with one column per tree. + \item \code{"contrib"}: will return the per-feature contributions for each prediction, including an + intercept (each feature will produce one column). If there are multiple classes, each class will + have separate feature contributions (thus the number of columns is feaures+1 multiplied by the + number of classes). + } + + Note that, if using custom objectives, types "link" and "response" will not be available and will + default towards using "raw" instead.} \item{start_iteration}{int or None, optional (default=None) Start index of the iteration to predict. diff --git a/R-package/tests/testthat/test_Predictor.R b/R-package/tests/testthat/test_Predictor.R index 3213563c9c2a..fa75b2ab62bb 100644 --- a/R-package/tests/testthat/test_Predictor.R +++ b/R-package/tests/testthat/test_Predictor.R @@ -288,3 +288,47 @@ test_that("predictions for multiclass classification are returned as matrix", { expect_equal(nrow(pred), nrow(X)) expect_equal(ncol(pred), 3L) }) + +test_that("predict type='response' returns integers for classification objectives", { + data(agaricus.train, package = "lightgbm") + X <- as.matrix(agaricus.train$data) + y <- agaricus.train$label + dtrain <- lgb.Dataset(X, label = y, params = list(max_bins = 5L)) + bst <- lgb.train( + data = dtrain + , obj = "binary" + , nrounds = 5L + , verbose = VERBOSITY + ) + pred <- predict(bst, X, type = "response") + expect_true(all(pred %in% c(0, 1))) + + data(iris) + X <- as.matrix(iris[, -5L]) + y <- as.numeric(iris$Species) - 1.0 + dtrain <- lgb.Dataset(X, label = y) + model <- lgb.train( + data = dtrain + , obj = "multiclass" + , nrounds = 5L + , verbose = VERBOSITY + , params = list(num_class = 3L) + ) + pred <- predict(model, X, type = "response") + expect_true(all(pred %in% c(0, 1, 2))) +}) + +test_that("predict type='response' returns decimals for regression objectives", { + data(agaricus.train, package = "lightgbm") + X <- as.matrix(agaricus.train$data) + y <- agaricus.train$label + dtrain <- lgb.Dataset(X, label = y, params = list(max_bins = 5L)) + bst <- lgb.train( + data = dtrain + , obj = "regression" + , nrounds = 5L + , verbose = VERBOSITY + ) + pred <- predict(bst, X, type = "response") + expect_true(all(!(pred %in% c(0, 1)))) +}) From d33b7a861876e48dc20c0f2b522f507c1a25f363 Mon Sep 17 00:00:00 2001 From: David Cortes Date: Sun, 5 Jun 2022 14:27:39 +0200 Subject: [PATCH 05/14] linter --- R-package/tests/testthat/test_Predictor.R | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/R-package/tests/testthat/test_Predictor.R b/R-package/tests/testthat/test_Predictor.R index fa75b2ab62bb..5bc0546a78ba 100644 --- a/R-package/tests/testthat/test_Predictor.R +++ b/R-package/tests/testthat/test_Predictor.R @@ -301,7 +301,7 @@ test_that("predict type='response' returns integers for classification objective , verbose = VERBOSITY ) pred <- predict(bst, X, type = "response") - expect_true(all(pred %in% c(0, 1))) + expect_true(all(pred %in% c(0L, 1L))) data(iris) X <- as.matrix(iris[, -5L]) @@ -315,7 +315,7 @@ test_that("predict type='response' returns integers for classification objective , params = list(num_class = 3L) ) pred <- predict(model, X, type = "response") - expect_true(all(pred %in% c(0, 1, 2))) + expect_true(all(pred %in% c(0L, 1L, 2L))) }) test_that("predict type='response' returns decimals for regression objectives", { @@ -330,5 +330,5 @@ test_that("predict type='response' returns decimals for regression objectives", , verbose = VERBOSITY ) pred <- predict(bst, X, type = "response") - expect_true(all(!(pred %in% c(0, 1)))) + expect_true(all(!(pred %in% c(0.0, 1.0)))) }) From 8b113c1dd87e89ca5818308f9e67e9c6f82aa2c0 Mon Sep 17 00:00:00 2001 From: David Cortes Date: Sun, 5 Jun 2022 14:40:55 +0200 Subject: [PATCH 06/14] fix test --- R-package/tests/testthat/test_Predictor.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R-package/tests/testthat/test_Predictor.R b/R-package/tests/testthat/test_Predictor.R index 5bc0546a78ba..260bd3365973 100644 --- a/R-package/tests/testthat/test_Predictor.R +++ b/R-package/tests/testthat/test_Predictor.R @@ -101,7 +101,7 @@ test_that("start_iteration works correctly", { inc_pred_contrib <- bst$predict(test$data , start_iteration = start_iter , num_iteration = n_iter - , predcontrib = TRUE + , type = "contrib" ) pred2 <- pred2 + inc_pred pred_contrib2 <- pred_contrib2 + inc_pred_contrib From 4838e29575ea1670d0d6a42449b49125adf398a7 Mon Sep 17 00:00:00 2001 From: David Cortes Date: Sun, 5 Jun 2022 14:57:01 +0200 Subject: [PATCH 07/14] revert incorrect 'fix' --- R-package/tests/testthat/test_Predictor.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R-package/tests/testthat/test_Predictor.R b/R-package/tests/testthat/test_Predictor.R index 260bd3365973..5bc0546a78ba 100644 --- a/R-package/tests/testthat/test_Predictor.R +++ b/R-package/tests/testthat/test_Predictor.R @@ -101,7 +101,7 @@ test_that("start_iteration works correctly", { inc_pred_contrib <- bst$predict(test$data , start_iteration = start_iter , num_iteration = n_iter - , type = "contrib" + , predcontrib = TRUE ) pred2 <- pred2 + inc_pred pred_contrib2 <- pred_contrib2 + inc_pred_contrib From 84ec681d861f4f3ae82d6d0392e4b67651f8c9b0 Mon Sep 17 00:00:00 2001 From: David Cortes Date: Sun, 5 Jun 2022 15:12:14 +0200 Subject: [PATCH 08/14] fix failing test --- R-package/R/lgb.Booster.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R-package/R/lgb.Booster.R b/R-package/R/lgb.Booster.R index 399979bdd117..3f9329da92d4 100644 --- a/R-package/R/lgb.Booster.R +++ b/R-package/R/lgb.Booster.R @@ -829,7 +829,7 @@ predict.lgb.Booster <- function(object, )) } - if (object$params$objective == "none" && type %in% c("link", "response")) { + if (!isnull(object$params$objective) && object$params$objective == "none" && type %in% c("link", "response")) { warning("Prediction types 'link' and 'response' are not supported for custom objectives.") type <- "raw" } From 9e0b83d1fea1c014271e1a219fc448b13dcdbed6 Mon Sep 17 00:00:00 2001 From: David Cortes Date: Sun, 5 Jun 2022 15:19:15 +0200 Subject: [PATCH 09/14] fix test again --- R-package/R/lgb.Booster.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R-package/R/lgb.Booster.R b/R-package/R/lgb.Booster.R index 3f9329da92d4..44ad37dab4fd 100644 --- a/R-package/R/lgb.Booster.R +++ b/R-package/R/lgb.Booster.R @@ -829,7 +829,7 @@ predict.lgb.Booster <- function(object, )) } - if (!isnull(object$params$objective) && object$params$objective == "none" && type %in% c("link", "response")) { + if (!is.null(object$params$objective) && object$params$objective == "none" && type %in% c("link", "response")) { warning("Prediction types 'link' and 'response' are not supported for custom objectives.") type <- "raw" } From acdd7154c7de5f63918a9ac6db8e3b0a26f8ad62 Mon Sep 17 00:00:00 2001 From: David Cortes Date: Sun, 5 Jun 2022 15:46:19 +0200 Subject: [PATCH 10/14] modify recently introduced tests after changes here --- R-package/tests/testthat/test_Predictor.R | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/R-package/tests/testthat/test_Predictor.R b/R-package/tests/testthat/test_Predictor.R index 26af29ed6ec9..2936d5672b24 100644 --- a/R-package/tests/testthat/test_Predictor.R +++ b/R-package/tests/testthat/test_Predictor.R @@ -140,7 +140,7 @@ test_that("predict() params should override keyword argument for raw-score predi # check that the predictions from predict.lgb.Booster() really look like raw score predictions preds_prob <- predict(bst, X) - preds_raw_s3_keyword <- predict(bst, X, rawscore = TRUE) + preds_raw_s3_keyword <- predict(bst, X, type = "raw") preds_prob_from_raw <- 1.0 / (1.0 + exp(-preds_raw_s3_keyword)) expect_equal(preds_prob, preds_prob_from_raw, tolerance = TOLERANCE) accuracy <- sum(as.integer(preds_prob_from_raw > 0.5) == y) / length(y) @@ -160,9 +160,7 @@ test_that("predict() params should override keyword argument for raw-score predi , nm = rawscore_alias ) ) - preds_raw_s3_param <- predict(bst, X, params = params) preds_raw_r6_param <- bst$predict(X, params = params) - expect_equal(preds_raw_s3_keyword, preds_raw_s3_param) expect_equal(preds_raw_s3_keyword, preds_raw_r6_param) } }) @@ -190,7 +188,7 @@ test_that("predict() params should override keyword argument for leaf-index pred ) # check that predictions really look like leaf index predictions - preds_leaf_s3_keyword <- predict(bst, X, predleaf = TRUE) + preds_leaf_s3_keyword <- predict(bst, X, type = "leaf") expect_true(is.matrix(preds_leaf_s3_keyword)) expect_equal(dim(preds_leaf_s3_keyword), c(nrow(X), bst$current_iter())) expect_true(min(preds_leaf_s3_keyword) >= 0L) @@ -213,9 +211,7 @@ test_that("predict() params should override keyword argument for leaf-index pred , nm = predleaf_alias ) ) - preds_leaf_s3_param <- predict(bst, X, params = params) preds_leaf_r6_param <- bst$predict(X, params = params) - expect_equal(preds_leaf_s3_keyword, preds_leaf_s3_param) expect_equal(preds_leaf_s3_keyword, preds_leaf_r6_param) } }) @@ -243,7 +239,7 @@ test_that("predict() params should override keyword argument for feature contrib ) # check that predictions really look like feature contributions - preds_contrib_s3_keyword <- predict(bst, X, predcontrib = TRUE) + preds_contrib_s3_keyword <- predict(bst, X, type = "contrib") num_features <- ncol(X) shap_base_value <- unname(preds_contrib_s3_keyword[, ncol(preds_contrib_s3_keyword)]) expect_true(is.matrix(preds_contrib_s3_keyword)) @@ -266,9 +262,7 @@ test_that("predict() params should override keyword argument for feature contrib , nm = predcontrib_alias ) ) - preds_contrib_s3_param <- predict(bst, X, params = params) preds_contrib_r6_param <- bst$predict(X, params = params) - expect_equal(preds_contrib_s3_keyword, preds_contrib_s3_param) expect_equal(preds_contrib_s3_keyword, preds_contrib_r6_param) } }) From 0288e6e088737b028b84fdd0e6937658523a79af Mon Sep 17 00:00:00 2001 From: David Cortes Date: Tue, 14 Jun 2022 21:14:34 +0200 Subject: [PATCH 11/14] rename prediction types --- R-package/R/lgb.Booster.R | 30 +++++++++++------------ R-package/man/predict.lgb.Booster.Rd | 24 +++++++++--------- R-package/tests/testthat/test_Predictor.R | 10 ++++---- 3 files changed, 32 insertions(+), 32 deletions(-) diff --git a/R-package/R/lgb.Booster.R b/R-package/R/lgb.Booster.R index c83f69a7bfe2..98b3f5b67008 100644 --- a/R-package/R/lgb.Booster.R +++ b/R-package/R/lgb.Booster.R @@ -743,24 +743,24 @@ Booster <- R6::R6Class( #' @param newdata a \code{matrix} object, a \code{dgCMatrix} object or #' a character representing a path to a text file (CSV, TSV, or LibSVM) #' @param type Type of prediction to output. Allowed types are:\itemize{ -#' \item \code{"link"}: will output the predicted score according to the objective function being +#' \item \code{"response"}: will output the predicted score according to the objective function being #' optimized (depending on the link function that the objective uses), after applying any necessary #' transformations - for example, for \code{objective="binary"}, it will output class probabilities. -#' \item \code{"response"}: for classification objectives, will output the class with the highest predicted -#' probability. For other objectives, will output the same as "link". +#' \item \code{"class"}: for classification objectives, will output the class with the highest predicted +#' probability. For other objectives, will output the same as "response". #' \item \code{"raw"}: will output the non-transformed numbers (sum of predictions from boosting iterations' -#' results) from which the "link" number is produced for a given objective function - for example, for -#' \code{objective="binary"}, this corresponds to log-odds. For many objectives such as "regression", -#' since no transformation is applied, the output will be the same as for "link". +#' results) from which the "response" number is produced for a given objective function - for example, +#' for \code{objective="binary"}, this corresponds to log-odds. For many objectives such as +#' "regression", since no transformation is applied, the output will be the same as for "response". #' \item \code{"leaf"}: will output the index of the terminal node / leaf at which each observations falls #' in each tree in the model, outputted as integers, with one column per tree. #' \item \code{"contrib"}: will return the per-feature contributions for each prediction, including an #' intercept (each feature will produce one column). If there are multiple classes, each class will -#' have separate feature contributions (thus the number of columns is feaures+1 multiplied by the +#' have separate feature contributions (thus the number of columns is features+1 multiplied by the #' number of classes). #' } #' -#' Note that, if using custom objectives, types "link" and "response" will not be available and will +#' Note that, if using custom objectives, types "class" and "response" will not be available and will #' default towards using "raw" instead. #' @param start_iteration int or None, optional (default=None) #' Start index of the iteration to predict. @@ -778,11 +778,11 @@ Booster <- R6::R6Class( #' the values in \code{params} take precedence. #' @param ... ignored #' @return For prediction types that are meant to always return one output per observation (e.g. when predicting -#' \code{type="link"} on a binary classification or regression objective), will return a vector with one -#' row per observation in \code{newdata}. +#' \code{type="response"} on a binary classification or regression objective), will return a vector with one +#' element per row in \code{newdata}. #' #' For prediction types that are meant to return more than one output per observation (e.g. when predicting -#' \code{type="link"} on a multi-class objective, or when predicting \code{type="leaf"}, regardless of +#' \code{type="response"} on a multi-class objective, or when predicting \code{type="leaf"}, regardless of #' objective), will return a matrix with one row per observation in \code{newdata} and one column per output. #' #' @examples @@ -821,7 +821,7 @@ Booster <- R6::R6Class( #' @export predict.lgb.Booster <- function(object, newdata, - type = "link", + type = "response", start_iteration = NULL, num_iteration = NULL, header = FALSE, @@ -859,8 +859,8 @@ predict.lgb.Booster <- function(object, )) } - if (!is.null(object$params$objective) && object$params$objective == "none" && type %in% c("link", "response")) { - warning("Prediction types 'link' and 'response' are not supported for custom objectives.") + if (!is.null(object$params$objective) && object$params$objective == "none" && type %in% c("class", "response")) { + warning("Prediction types 'class' and 'response' are not supported for custom objectives.") type <- "raw" } @@ -885,7 +885,7 @@ predict.lgb.Booster <- function(object, , header = header , params = params ) - if (type == "response") { + if (type == "class") { if (object$params$objective == "binary") { pred <- as.integer(pred >= 0.5) } else if (object$params$objective %in% c("multiclass", "multiclassova")) { diff --git a/R-package/man/predict.lgb.Booster.Rd b/R-package/man/predict.lgb.Booster.Rd index 29149e04d551..7d9734d9181f 100644 --- a/R-package/man/predict.lgb.Booster.Rd +++ b/R-package/man/predict.lgb.Booster.Rd @@ -7,7 +7,7 @@ \method{predict}{lgb.Booster}( object, newdata, - type = "link", + type = "response", start_iteration = NULL, num_iteration = NULL, header = FALSE, @@ -22,24 +22,24 @@ a character representing a path to a text file (CSV, TSV, or LibSVM)} \item{type}{Type of prediction to output. Allowed types are:\itemize{ - \item \code{"link"}: will output the predicted score according to the objective function being + \item \code{"response"}: will output the predicted score according to the objective function being optimized (depending on the link function that the objective uses), after applying any necessary transformations - for example, for \code{objective="binary"}, it will output class probabilities. - \item \code{"response"}: for classification objectives, will output the class with the highest predicted - probability. For other objectives, will output the same as "link". + \item \code{"class"}: for classification objectives, will output the class with the highest predicted + probability. For other objectives, will output the same as "response". \item \code{"raw"}: will output the non-transformed numbers (sum of predictions from boosting iterations' - results) from which the "link" number is produced for a given objective function - for example, for - \code{objective="binary"}, this corresponds to log-odds. For many objectives such as "regression", - since no transformation is applied, the output will be the same as for "link". + results) from which the "response" number is produced for a given objective function - for example, + for \code{objective="binary"}, this corresponds to log-odds. For many objectives such as + "regression", since no transformation is applied, the output will be the same as for "response". \item \code{"leaf"}: will output the index of the terminal node / leaf at which each observations falls in each tree in the model, outputted as integers, with one column per tree. \item \code{"contrib"}: will return the per-feature contributions for each prediction, including an intercept (each feature will produce one column). If there are multiple classes, each class will - have separate feature contributions (thus the number of columns is feaures+1 multiplied by the + have separate feature contributions (thus the number of columns is features+1 multiplied by the number of classes). } - Note that, if using custom objectives, types "link" and "response" will not be available and will + Note that, if using custom objectives, types "class" and "response" will not be available and will default towards using "raw" instead.} \item{start_iteration}{int or None, optional (default=None) @@ -64,11 +64,11 @@ the values in \code{params} take precedence.} } \value{ For prediction types that are meant to always return one output per observation (e.g. when predicting - \code{type="link"} on a binary classification or regression objective), will return a vector with one - row per observation in \code{newdata}. + \code{type="response"} on a binary classification or regression objective), will return a vector with one + element per row in \code{newdata}. For prediction types that are meant to return more than one output per observation (e.g. when predicting - \code{type="link"} on a multi-class objective, or when predicting \code{type="leaf"}, regardless of + \code{type="response"} on a multi-class objective, or when predicting \code{type="leaf"}, regardless of objective), will return a matrix with one row per observation in \code{newdata} and one column per output. } \description{ diff --git a/R-package/tests/testthat/test_Predictor.R b/R-package/tests/testthat/test_Predictor.R index 2936d5672b24..1a9e5f3ab796 100644 --- a/R-package/tests/testthat/test_Predictor.R +++ b/R-package/tests/testthat/test_Predictor.R @@ -442,7 +442,7 @@ test_that("predictions for multiclass classification are returned as matrix", { expect_equal(ncol(pred), 3L) }) -test_that("predict type='response' returns integers for classification objectives", { +test_that("predict type='class' returns predicted class for classification objectives", { data(agaricus.train, package = "lightgbm") X <- as.matrix(agaricus.train$data) y <- agaricus.train$label @@ -453,7 +453,7 @@ test_that("predict type='response' returns integers for classification objective , nrounds = 5L , verbose = VERBOSITY ) - pred <- predict(bst, X, type = "response") + pred <- predict(bst, X, type = "class") expect_true(all(pred %in% c(0L, 1L))) data(iris) @@ -467,11 +467,11 @@ test_that("predict type='response' returns integers for classification objective , verbose = VERBOSITY , params = list(num_class = 3L) ) - pred <- predict(model, X, type = "response") + pred <- predict(model, X, type = "class") expect_true(all(pred %in% c(0L, 1L, 2L))) }) -test_that("predict type='response' returns decimals for regression objectives", { +test_that("predict type='class' returns values in the target's range for regression objectives", { data(agaricus.train, package = "lightgbm") X <- as.matrix(agaricus.train$data) y <- agaricus.train$label @@ -482,6 +482,6 @@ test_that("predict type='response' returns decimals for regression objectives", , nrounds = 5L , verbose = VERBOSITY ) - pred <- predict(bst, X, type = "response") + pred <- predict(bst, X, type = "class") expect_true(all(!(pred %in% c(0.0, 1.0)))) }) From 3c0dc29a0d5dbcdba773ebec162831fd3fb466bf Mon Sep 17 00:00:00 2001 From: David Cortes Date: Mon, 20 Jun 2022 20:35:09 +0200 Subject: [PATCH 12/14] rebase --- R-package/tests/testthat/test_Predictor.R | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/R-package/tests/testthat/test_Predictor.R b/R-package/tests/testthat/test_Predictor.R index fc5efad94f73..94a475ab9136 100644 --- a/R-package/tests/testthat/test_Predictor.R +++ b/R-package/tests/testthat/test_Predictor.R @@ -131,20 +131,20 @@ test_that("Feature contributions from sparse inputs produce sparse outputs", { , params = list(min_data_in_leaf = 5L) ) - pred_dense <- predict(bst, X, predcontrib = TRUE) + pred_dense <- predict(bst, X, type = "contrib") Xcsc <- as(X, "CsparseMatrix") - pred_csc <- predict(bst, Xcsc, predcontrib = TRUE) + pred_csc <- predict(bst, Xcsc, type = "contrib") expect_s4_class(pred_csc, "dgCMatrix") expect_equal(unname(pred_dense), unname(as.matrix(pred_csc))) Xcsr <- as(X, "RsparseMatrix") - pred_csr <- predict(bst, Xcsr, predcontrib = TRUE) + pred_csr <- predict(bst, Xcsr, type = "contrib") expect_s4_class(pred_csr, "dgRMatrix") expect_equal(as(pred_csr, "CsparseMatrix"), pred_csc) Xspv <- as(X[1L, , drop = FALSE], "sparseVector") - pred_spv <- predict(bst, Xspv, predcontrib = TRUE) + pred_spv <- predict(bst, Xspv, type = "contrib") expect_s4_class(pred_spv, "dsparseVector") expect_equal(Matrix::t(as(pred_spv, "CsparseMatrix")), unname(pred_csc[1L, , drop = FALSE])) }) @@ -164,14 +164,14 @@ test_that("Sparse feature contribution predictions do not take inputs with wrong X_wrong <- X[, c(1L:10L, 1L:10L)] X_wrong <- as(X_wrong, "CsparseMatrix") - expect_error(predict(bst, X_wrong, predcontrib = TRUE), regexp = "input data has 20 columns") + expect_error(predict(bst, X_wrong, type = "contrib"), regexp = "input data has 20 columns") X_wrong <- as(X_wrong, "RsparseMatrix") - expect_error(predict(bst, X_wrong, predcontrib = TRUE), regexp = "input data has 20 columns") + expect_error(predict(bst, X_wrong, type = "contrib"), regexp = "input data has 20 columns") X_wrong <- as(X_wrong, "CsparseMatrix") X_wrong <- X_wrong[, 1L:3L] - expect_error(predict(bst, X_wrong, predcontrib = TRUE), regexp = "input data has 3 columns") + expect_error(predict(bst, X_wrong, type = "contrib"), regexp = "input data has 3 columns") }) test_that("Feature contribution predictions do not take non-general CSR or CSC inputs", { @@ -192,8 +192,8 @@ test_that("Feature contribution predictions do not take non-general CSR or CSC i , params = list(min_data_in_leaf = 5L) ) - expect_error(predict(bst, SmatC, predcontrib = TRUE)) - expect_error(predict(bst, SmatR, predcontrib = TRUE)) + expect_error(predict(bst, SmatC, type = "contrib")) + expect_error(predict(bst, SmatR, type = "contrib")) }) test_that("predict() params should override keyword argument for raw-score predictions", { @@ -395,7 +395,7 @@ test_that("predict() params should override keyword argument for feature contrib .expect_has_row_names(pred, Xcsc) pred <- predict(bst, Xcsc, type = "contrib") .expect_has_row_names(pred, Xcsc) - pred <- predict(bst, as(Xcsc, "RsparseMatrix"), predcontrib = TRUE) + pred <- predict(bst, as(Xcsc, "RsparseMatrix"), type = "contrib") .expect_has_row_names(pred, Xcsc) # sparse matrix without row names From 59b977677d992ee1a3f0167432ecb143f7e0ab87 Mon Sep 17 00:00:00 2001 From: David Cortes Date: Wed, 22 Jun 2022 20:04:39 +0200 Subject: [PATCH 13/14] restore tests for prediction type in params --- R-package/tests/testthat/test_Predictor.R | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/R-package/tests/testthat/test_Predictor.R b/R-package/tests/testthat/test_Predictor.R index 94a475ab9136..7ebc7e213724 100644 --- a/R-package/tests/testthat/test_Predictor.R +++ b/R-package/tests/testthat/test_Predictor.R @@ -6,8 +6,6 @@ VERBOSITY <- as.integer( TOLERANCE <- 1e-6 -library(Matrix) - test_that("Predictor$finalize() should not fail", { X <- as.matrix(as.integer(iris[, "Species"]), ncol = 1L) y <- iris[["Sepal.Length"]] @@ -240,7 +238,9 @@ test_that("predict() params should override keyword argument for raw-score predi , nm = rawscore_alias ) ) + preds_raw_s3_param <- predict(bst, X, params = params) preds_raw_r6_param <- bst$predict(X, params = params) + expect_equal(preds_raw_s3_keyword, preds_raw_s3_param) expect_equal(preds_raw_s3_keyword, preds_raw_r6_param) } }) @@ -291,7 +291,9 @@ test_that("predict() params should override keyword argument for leaf-index pred , nm = predleaf_alias ) ) + preds_leaf_s3_param <- predict(bst, X, params = params) preds_leaf_r6_param <- bst$predict(X, params = params) + expect_equal(preds_leaf_s3_keyword, preds_leaf_s3_param) expect_equal(preds_leaf_s3_keyword, preds_leaf_r6_param) } }) @@ -342,7 +344,9 @@ test_that("predict() params should override keyword argument for feature contrib , nm = predcontrib_alias ) ) + preds_contrib_s3_param <- predict(bst, X, params = params) preds_contrib_r6_param <- bst$predict(X, params = params) + expect_equal(preds_contrib_s3_keyword, preds_contrib_s3_param) expect_equal(preds_contrib_s3_keyword, preds_contrib_r6_param) } }) From 2d2bb38eaf0bda9193f3b6cc1ce048b44523fb48 Mon Sep 17 00:00:00 2001 From: david-cortes Date: Wed, 22 Jun 2022 21:06:38 +0300 Subject: [PATCH 14/14] Update R-package/tests/testthat/test_Predictor.R Co-authored-by: James Lamb --- R-package/tests/testthat/test_Predictor.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R-package/tests/testthat/test_Predictor.R b/R-package/tests/testthat/test_Predictor.R index 7ebc7e213724..875b24ba0dc1 100644 --- a/R-package/tests/testthat/test_Predictor.R +++ b/R-package/tests/testthat/test_Predictor.R @@ -569,5 +569,5 @@ test_that("predict type='class' returns values in the target's range for regress , verbose = VERBOSITY ) pred <- predict(bst, X, type = "class") - expect_true(all(!(pred %in% c(0.0, 1.0)))) + expect_true(!any(pred %in% c(0.0, 1.0))) })