From 44e61e6066a540b561f9cfcb7818937d372e229f Mon Sep 17 00:00:00 2001 From: david-cortes Date: Sat, 22 Jul 2023 20:54:36 +0200 Subject: [PATCH 1/2] fix error when passing categorical features to lightgbm() --- R-package/R/lgb.Dataset.R | 6 +++++- R-package/R/lightgbm.R | 13 ++++++++++++- R-package/tests/testthat/test_basic.R | 15 +++++++++++++++ 3 files changed, 32 insertions(+), 2 deletions(-) diff --git a/R-package/R/lgb.Dataset.R b/R-package/R/lgb.Dataset.R index e2892ea4bae0..b9f4038abd9d 100644 --- a/R-package/R/lgb.Dataset.R +++ b/R-package/R/lgb.Dataset.R @@ -313,7 +313,7 @@ Dataset <- R6::R6Class( # Should we free raw data? if (isTRUE(private$free_raw_data)) { - private$raw_data <- NULL + self$drop_raw_data() } # Get private information @@ -692,6 +692,10 @@ Dataset <- R6::R6Class( , path.expand(fname) ) return(invisible(self)) + }, + + drop_raw_data = function() { + private$raw_data <- NULL } ), diff --git a/R-package/R/lightgbm.R b/R-package/R/lightgbm.R index cb3ef31e8afa..8df5dfa4d542 100644 --- a/R-package/R/lightgbm.R +++ b/R-package/R/lightgbm.R @@ -218,7 +218,13 @@ lightgbm <- function(data, # Check whether data is lgb.Dataset, if not then create lgb.Dataset manually if (!lgb.is.Dataset(x = dtrain)) { - dtrain <- lgb.Dataset(data = data, label = label, weight = weights, init_score = init_score) + dtrain <- lgb.Dataset( + data = data + , label = label + , weight = weights + , init_score = init_score + , free_raw_data = FALSE + ) } train_args <- list( @@ -246,6 +252,11 @@ lightgbm <- function(data, ) bst$data_processor <- data_processor + # Since the dataset got passed 'free_raw_data = FALSE', need to do the step manually + if (!lgb.is.Dataset(data)) { + bst$.__enclos_env__$private$train_set$drop_raw_data() + } + return(bst) } diff --git a/R-package/tests/testthat/test_basic.R b/R-package/tests/testthat/test_basic.R index b0253b1e488e..80391089786b 100644 --- a/R-package/tests/testthat/test_basic.R +++ b/R-package/tests/testthat/test_basic.R @@ -3773,3 +3773,18 @@ test_that("lightgbm() model predictions retain factor levels for binary classifi expect_true(is.numeric(pred)) expect_false(any(pred %in% y)) }) + +test_that("lightgbm() accepts named categorical_features", { + data(mtcars) + y <- mtcars$mpg + x <- as.matrix(mtcars[, -1L]) + model <- lightgbm( + x + , y + , categorical_feature = "cyl" + , verbose = .LGB_VERBOSITY + , nrounds = 5L + , num_threads = .LGB_MAX_THREADS + ) + expect_true(length(model$params$categorical_feature) > 0L) +}) From ed0650f8328651562f3c5d315df3c164b82696ba Mon Sep 17 00:00:00 2001 From: david-cortes Date: Mon, 24 Jul 2023 19:43:57 +0200 Subject: [PATCH 2/2] remove workaround, switch set_categorical_feature to before call to dataset construct --- R-package/R/lgb.Dataset.R | 6 +----- R-package/R/lgb.train.R | 8 +++----- R-package/R/lightgbm.R | 13 +------------ 3 files changed, 5 insertions(+), 22 deletions(-) diff --git a/R-package/R/lgb.Dataset.R b/R-package/R/lgb.Dataset.R index b9f4038abd9d..e2892ea4bae0 100644 --- a/R-package/R/lgb.Dataset.R +++ b/R-package/R/lgb.Dataset.R @@ -313,7 +313,7 @@ Dataset <- R6::R6Class( # Should we free raw data? if (isTRUE(private$free_raw_data)) { - self$drop_raw_data() + private$raw_data <- NULL } # Get private information @@ -692,10 +692,6 @@ Dataset <- R6::R6Class( , path.expand(fname) ) return(invisible(self)) - }, - - drop_raw_data = function() { - private$raw_data <- NULL } ), diff --git a/R-package/R/lgb.train.R b/R-package/R/lgb.train.R index 4260f81cd3fe..20916c9844b5 100644 --- a/R-package/R/lgb.train.R +++ b/R-package/R/lgb.train.R @@ -154,6 +154,9 @@ lgb.train <- function(params = list(), # Construct datasets, if needed data$update_params(params = params) + if (!is.null(categorical_feature)) { + data$set_categorical_feature(categorical_feature) + } data$construct() # Check interaction constraints @@ -179,11 +182,6 @@ lgb.train <- function(params = list(), data$set_colnames(colnames) } - # Write categorical features - if (!is.null(categorical_feature)) { - data$set_categorical_feature(categorical_feature) - } - valid_contain_train <- FALSE train_data_name <- "train" reduced_valid_sets <- list() diff --git a/R-package/R/lightgbm.R b/R-package/R/lightgbm.R index 8df5dfa4d542..cb3ef31e8afa 100644 --- a/R-package/R/lightgbm.R +++ b/R-package/R/lightgbm.R @@ -218,13 +218,7 @@ lightgbm <- function(data, # Check whether data is lgb.Dataset, if not then create lgb.Dataset manually if (!lgb.is.Dataset(x = dtrain)) { - dtrain <- lgb.Dataset( - data = data - , label = label - , weight = weights - , init_score = init_score - , free_raw_data = FALSE - ) + dtrain <- lgb.Dataset(data = data, label = label, weight = weights, init_score = init_score) } train_args <- list( @@ -252,11 +246,6 @@ lightgbm <- function(data, ) bst$data_processor <- data_processor - # Since the dataset got passed 'free_raw_data = FALSE', need to do the step manually - if (!lgb.is.Dataset(data)) { - bst$.__enclos_env__$private$train_set$drop_raw_data() - } - return(bst) }