Skip to content

Commit

Permalink
Kmedved/master (#268)
Browse files Browse the repository at this point in the history
* Add HistGradientBoosting-Like Early Stopping

* Updates in response to Pylint and Flake8 suggestions

* fix linting and line length issues

* bump version

Co-authored-by: kmedved <[email protected]>
  • Loading branch information
ryan-wolbeck and kmedved authored Jul 30, 2021
1 parent cfc067b commit 51a2991
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 1 deletion.
11 changes: 11 additions & 0 deletions ngboost/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@ class NGBRegressor(NGBoost, BaseEstimator):
tol : numerical tolerance to be used in optimization
random_state : seed for reproducibility. See
https://stackoverflow.com/questions/28064634/random-state-pseudo-random-number-in-scikit-learn
validation_fraction: Proportion of training data to set
aside as validation data for early stopping.
early_stopping_rounds: The number of consecutive boosting iterations during which the
loss has to increase before the algorithm stops early.
Set to None (the default) to disable early stopping and validation;
the model is then fit on the full data set.
Output:
An NGBRegressor object that can be fit.
"""
Expand All @@ -61,6 +68,8 @@ def __init__(
verbose_eval=100,
tol=1e-4,
random_state=None,
validation_fraction=0.1,
early_stopping_rounds=None,
):
assert issubclass(
Dist, RegressionDistn
Expand All @@ -84,6 +93,8 @@ def __init__(
verbose_eval,
tol,
random_state,
validation_fraction,
early_stopping_rounds,
)

def __getstate__(self):
Expand Down
37 changes: 37 additions & 0 deletions ngboost/ngboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# pylint: disable=redundant-keyword-arg,protected-access
import numpy as np
from sklearn.base import clone
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor
from sklearn.utils import check_array, check_random_state, check_X_y

Expand Down Expand Up @@ -38,6 +39,13 @@ class NGBoost:
tol : numerical tolerance to be used in optimization
random_state : seed for reproducibility.
See https://stackoverflow.com/questions/28064634/random-state-pseudo-random-number-in-scikit-learn
validation_fraction: Proportion of training data to set aside as validation data for early stopping.
early_stopping_rounds: The number of consecutive boosting iterations during which the
loss has to increase before the algorithm stops early.
Set to None (the default) to disable early stopping and validation;
the model is then fit on the full data set.
Output:
An NGBoost object that can be fit.
"""
Expand All @@ -56,6 +64,8 @@ def __init__(
verbose_eval=100,
tol=1e-4,
random_state=None,
validation_fraction=0.1,
early_stopping_rounds=None,
):
self.Dist = Dist
self.Score = Score
Expand All @@ -76,6 +86,9 @@ def __init__(
self.tol = tol
self.random_state = check_random_state(random_state)
self.best_val_loss_itr = None
self.validation_fraction = validation_fraction
self.early_stopping_rounds = early_stopping_rounds

if hasattr(self.Dist, "multi_output"):
self.multi_output = self.Dist.multi_output
else:
Expand Down Expand Up @@ -227,6 +240,29 @@ def fit(
A fit NGBRegressor object
"""

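# If early stopping is enabled, split X, Y and sample weights (if given) into training and validation sets.
# This overwrites any X_val and Y_val (and, when sample weights are given, val_sample_weight) values
# passed by the user directly; early_stopping_rounds is likewise taken from the estimator attribute.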
if self.early_stopping_rounds is not None:

early_stopping_rounds = self.early_stopping_rounds

if sample_weight is None:
X, X_val, Y, Y_val = train_test_split(
X,
Y,
test_size=self.validation_fraction,
random_state=self.random_state,
)

else:
X, X_val, Y, Y_val, sample_weight, val_sample_weight = train_test_split(
X,
Y,
sample_weight,
test_size=self.validation_fraction,
random_state=self.random_state,
)

if Y is None:
raise ValueError("y cannot be None")

Expand All @@ -240,6 +276,7 @@ def fit(
self.fit_init_params_to_marginal(Y)

params = self.pred_param(X)

if X_val is not None and Y_val is not None:
X_val, Y_val = check_X_y(
X_val,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "ngboost"
version = "0.3.12dev"
version = "0.3.13dev"
description = "Library for probabilistic predictions via gradient boosting."
authors = ["Stanford ML Group <[email protected]>"]
readme = "README.md"
Expand Down

0 comments on commit 51a2991

Please sign in to comment.