From 34b7484288746985d0e0186ff4d9a880130ecf93 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Wed, 22 Dec 2021 09:47:48 -0600 Subject: [PATCH] [R-package] fix CVBooster reset_parameter() method (fixes #4900) (#4901) * [R-package] fix CVBooster reset_parameter() method (fixes #4900) * make it clear that there should be one booster per fold --- R-package/R/lgb.cv.R | 4 ++-- R-package/tests/testthat/test_basic.R | 25 +++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/R-package/R/lgb.cv.R b/R-package/R/lgb.cv.R index 29d8f9c104ff..06e646ada3dd 100644 --- a/R-package/R/lgb.cv.R +++ b/R-package/R/lgb.cv.R @@ -12,8 +12,8 @@ CVBooster <- R6::R6Class( return(invisible(NULL)) }, reset_parameter = function(new_params) { - for (x in boosters) { - x$reset_parameter(params = new_params) + for (x in self$boosters) { + x[["booster"]]$reset_parameter(params = new_params) } return(invisible(self)) } diff --git a/R-package/tests/testthat/test_basic.R b/R-package/tests/testthat/test_basic.R index 6ef2c333b35d..c16d73baf8df 100644 --- a/R-package/tests/testthat/test_basic.R +++ b/R-package/tests/testthat/test_basic.R @@ -380,6 +380,31 @@ test_that("cv works", { expect_false(is.null(bst$record_evals)) }) +test_that("CVBooster$reset_parameter() works as expected", { + dtrain <- lgb.Dataset(train$data, label = train$label) + n_folds <- 2L + cv_bst <- lgb.cv( + params = list( + objective = "regression" + , min_data = 1L + , num_leaves = 7L + , verbose = VERBOSITY + ) + , data = dtrain + , nrounds = 3L + , nfold = n_folds + ) + expect_is(cv_bst, "lgb.CVBooster") + expect_length(cv_bst$boosters, n_folds) + for (bst in cv_bst$boosters) { + expect_equal(bst[["booster"]]$params[["num_leaves"]], 7L) + } + cv_bst$reset_parameter(list(num_leaves = 11L)) + for (bst in cv_bst$boosters) { + expect_equal(bst[["booster"]]$params[["num_leaves"]], 11L) + } +}) + test_that("lgb.cv() rejects negative or 0 value passed to nrounds", { dtrain <- lgb.Dataset(train$data, label = train$label) params <- list(