
[python-package] reorganize early stopping callback #6114

Merged
8 commits, merged on Oct 5, 2023
64 changes: 45 additions & 19 deletions python-package/lightgbm/callback.py
@@ -229,7 +229,12 @@ def __call__(self, env: CallbackEnv) -> None:
if new_param != env.params.get(key, None):
new_parameters[key] = new_param
if new_parameters:
env.model.reset_parameter(new_parameters)
if isinstance(env.model, Booster):
env.model.reset_parameter(new_parameters)
else:
# CVBooster holds a list of Booster objects, each needs to be updated
for i in range(len(env.model.boosters)):
env.model.boosters[i].reset_parameter(new_parameters)
env.params.update(new_parameters)
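
The new CVBooster branch above comes into play whenever reset_parameter() is combined with lgb.cv(). A minimal sketch of that combination, using only the public API (the synthetic data and the decay schedule are purely illustrative):

```python
import lightgbm as lgb
import numpy as np

# illustrative synthetic regression data
rng = np.random.default_rng(42)
X = rng.normal(size=(500, 5))
y = X[:, 0] + rng.normal(scale=0.1, size=500)
dtrain = lgb.Dataset(X, label=y)

# With lgb.cv(), env.model inside the callback is a CVBooster, so the updated
# callback loops over CVBooster.boosters and resets parameters on each fold's Booster.
cv_results = lgb.cv(
    params={"objective": "regression", "verbosity": -1},
    train_set=dtrain,
    num_boost_round=10,
    nfold=3,
    stratified=False,
    callbacks=[lgb.reset_parameter(learning_rate=lambda i: 0.1 * (0.99 ** i))],
)
```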


@@ -291,32 +296,49 @@ def _gt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
def _lt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
return curr_score < best_score - delta

def _is_train_set(self, ds_name: str, eval_name: str, train_name: str) -> bool:
return (ds_name == "cv_agg" and eval_name == "train") or ds_name == train_name
def _is_train_set(self, ds_name: str, eval_name: str, env: CallbackEnv) -> bool:
"""Check, by name, if a given Dataset is the training data."""
# for lgb.cv() with eval_train_metric=True, evaluation is also done on the training set
# and those metrics are considered for early stopping
if ds_name == "cv_agg" and eval_name == "train":
return True

# for lgb.train(), it's possible to pass the training data via valid_sets with any eval_name
if isinstance(env.model, Booster):
if ds_name == env.model._train_data_name:
return True

return False

def _init(self, env: CallbackEnv) -> None:
if env.evaluation_result_list is None or env.evaluation_result_list == []:
raise ValueError(
"For early stopping, at least one dataset and eval metric is required for evaluation"
)

if self.stopping_rounds <= 0:
raise ValueError(f"stopping_rounds should be greater than zero. got: {self.stopping_rounds}")
Collaborator:

Since this doesn't need the env we could move it to __init__, WDYT?

Collaborator Author:

I agree! Having it be a loud error right when the callback is created, instead of deferred all the way until the first iteration of training, seems useful. And I'd be surprised to learn that there are other libraries or user code depending on initializing lgb.early_stopping() with a negative value for this parameter and then somehow updating the value before the first time it's called.

Moved into __init__() in bd3366a.

Collaborator Author:

Doing this broke this test:

# excerpt from LightGBM's Python test suite; lgb, pytest, and the
# make_synthetic_regression() helper come from the suite's existing imports
def test_train_raises_informative_error_for_params_of_wrong_type():
    X, y = make_synthetic_regression()
    params = {"early_stopping_round": "too-many"}
    dtrain = lgb.Dataset(X, label=y)
    with pytest.raises(lgb.basic.LightGBMError, match="Parameter early_stopping_round should be of type int, got \"too-many\""):
        lgb.train(params, dtrain)

Now the error from the early stopping callback gets thrown before this one from the C++ side:

Log::Fatal("Parameter %s should be of type int, got \"%s\"", key.c_str(), candidate);

So I pushed 7a98d82, which:

  • switches that test to use a different parameter, to keep covering that C++-side validation
  • adds a test in test_callback.py on this specific error from lgb.early_stopping()
  • adds an isinstance() check in the condition guarding that error in lgb.early_stopping(), so you get an informative error instead of something like TypeError: '<=' not supported between instances of 'str' and 'int'

Given all those changes, @jmoralez could you re-review? I don't want to sneak those in on your previous approval.
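
To make the effect of that move concrete: with a LightGBM build that includes this change, an invalid stopping_rounds now fails at callback construction rather than at the first training iteration, which is also the behavior the new test in test_callback.py can assert. A rough sketch (the exact error message is illustrative):

```python
import lightgbm as lgb

# Constructing the callback with a non-positive stopping_rounds should now raise
# immediately, with no Dataset or training loop involved.
try:
    lgb.early_stopping(stopping_rounds=0)
except ValueError as err:
    print(err)  # e.g. "stopping_rounds should be greater than zero. got: 0"
```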


is_dart = any(env.params.get(alias, "") == 'dart' for alias in _ConfigAliases.get("boosting"))
only_train_set = (
len(env.evaluation_result_list) == 1
and self._is_train_set(
ds_name=env.evaluation_result_list[0][0],
eval_name=env.evaluation_result_list[0][1].split(" ")[0],
train_name=env.model._train_data_name)
)
self.enabled = not is_dart and not only_train_set
if not self.enabled:
if is_dart:
_log_warning('Early stopping is not available in dart mode')
elif only_train_set:
_log_warning('Only training set found, disabling early stopping.')
if is_dart:
self.enabled = False
_log_warning('Early stopping is not available in dart mode')
return

if self.stopping_rounds <= 0:
raise ValueError("stopping_rounds should be greater than zero.")
# validation sets are guaranteed to not be identical to the training data in cv()
if isinstance(env.model, Booster):
only_train_set = (
len(env.evaluation_result_list) == 1
and self._is_train_set(
ds_name=env.evaluation_result_list[0][0],
eval_name=env.evaluation_result_list[0][1].split(" ")[0],
env=env
)
)
if only_train_set:
self.enabled = False
_log_warning('Only training set found, disabling early stopping.')
return

if self.verbose:
_log_info(f"Training until validation scores don't improve for {self.stopping_rounds} rounds")
@@ -395,7 +417,11 @@ def __call__(self, env: CallbackEnv) -> None:
eval_name_splitted = env.evaluation_result_list[i][1].split(" ")
if self.first_metric_only and self.first_metric != eval_name_splitted[-1]:
continue # use only the first metric for early stopping
if self._is_train_set(env.evaluation_result_list[i][0], eval_name_splitted[0], env.model._train_data_name):
if self._is_train_set(
ds_name=env.evaluation_result_list[i][0],
eval_name=eval_name_splitted[0],
env=env
):
continue # train data for lgb.cv or sklearn wrapper (underlying lgb.train)
elif env.iteration - self.best_iter[i] >= self.stopping_rounds:
if self.verbose:
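
Finally, for the lgb.train() path that _is_train_set() handles, the training data may also be passed through valid_sets under any name; its metrics are reported but skipped by the callback, so only the real validation set can trigger early stopping. A hedged sketch (data, names, and parameters are illustrative):

```python
import lightgbm as lgb
import numpy as np

rng = np.random.default_rng(1)
X = rng.normal(size=(1000, 5))
y = X[:, 0] + rng.normal(scale=0.1, size=1000)
dtrain = lgb.Dataset(X[:800], label=y[:800])
dvalid = lgb.Dataset(X[800:], label=y[800:], reference=dtrain)

# The training set appears in valid_sets too; _is_train_set() recognizes it by
# comparing its name to Booster._train_data_name, so only "valid" decides when to stop.
booster = lgb.train(
    params={"objective": "regression", "verbosity": -1},
    train_set=dtrain,
    num_boost_round=200,
    valid_sets=[dtrain, dvalid],
    valid_names=["train", "valid"],
    callbacks=[lgb.early_stopping(stopping_rounds=10, verbose=True)],
)
print(booster.best_iteration)
```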