Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

return metric score along with untrained best models & params #822

Merged
17 changes: 10 additions & 7 deletions darts/models/forecasting/forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,7 @@ def gridsearch(
verbose=False,
n_jobs: int = 1,
n_random_samples: Optional[Union[int, float]] = None,
) -> Tuple["ForecastingModel", Dict]:
) -> Tuple["ForecastingModel", Dict[str, Any], float]:
"""
Find the best hyper-parameters among a given set using a grid search.

Expand Down Expand Up @@ -660,9 +660,10 @@ def gridsearch(

Returns
-------
ForecastingModel, Dict
ForecastingModel, Dict, float
A tuple containing an untrained `model_class` instance created from the best-performing hyper-parameters,
along with a dictionary containing these best hyper-parameters.
along with a dictionary containing these best hyper-parameters,
and the metric score achieved with these hyper-parameters.
"""
raise_if_not(
(forecast_horizon is not None)
Expand Down Expand Up @@ -707,7 +708,7 @@ def gridsearch(
zip(params_cross_product), verbose, total=len(params_cross_product)
)

def _evaluate_combination(param_combination):
def _evaluate_combination(param_combination) -> float:
param_combination_dict = dict(
list(zip(parameters.keys(), param_combination))
)
Expand Down Expand Up @@ -748,9 +749,11 @@ def _evaluate_combination(param_combination):
)
error = metric(pred, val_series)

return error
return float(error)

errors = _parallel_apply(iterator, _evaluate_combination, n_jobs, {}, {})
errors: List[float] = _parallel_apply(
iterator, _evaluate_combination, n_jobs, {}, {}
)

min_error = min(errors)

Expand All @@ -760,7 +763,7 @@ def _evaluate_combination(param_combination):

logger.info("Chosen parameters: " + str(best_param_combination))

return model_class(**best_param_combination), best_param_combination
return model_class(**best_param_combination), best_param_combination, min_error

def residuals(
self, series: TimeSeries, forecast_horizon: int = 1, verbose: bool = False
Expand Down
2 changes: 1 addition & 1 deletion darts/tests/models/forecasting/test_4theta.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def test_best_model(self):
series = sine_series + linear_series
train_series, val_series = series.split_before(series.time_index[-10])
thetas = np.linspace(-3, 3, 30)
best_model, _ = FourTheta.select_best_model(train_series, thetas)
best_model, _, _ = FourTheta.select_best_model(train_series, thetas)
model = FourTheta(
random.choice(thetas),
model_mode=random.choice(list(ModelMode)),
Expand Down
76 changes: 46 additions & 30 deletions darts/tests/models/forecasting/test_backtesting.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,21 @@
TORCH_AVAILABLE = False


def get_dummy_series(
    ts_length: int, lt_end_value: int = 10, st_value_offset: int = 10
) -> TimeSeries:
    """Build a synthetic test series by summing three generator components.

    Combines the `lt`, `st` and `rt` series generators (presumably linear
    trend, seasonal and random components — confirm against the test-file
    imports) into a single series of length ``ts_length``.

    Parameters
    ----------
    ts_length
        Number of time steps in the generated series.
    lt_end_value
        ``end_value`` forwarded to the ``lt`` generator.
    st_value_offset
        ``value_y_offset`` forwarded to the ``st`` generator.
    """
    trend_part = lt(length=ts_length, end_value=lt_end_value)
    seasonal_part = st(length=ts_length, value_y_offset=st_value_offset)
    noise_part = rt(length=ts_length)
    return trend_part + seasonal_part + noise_part


def compare_best_against_random(model_class, params, series, stride=1):

# instantiate best model in expanding window mode
np.random.seed(1)
best_model_1, _ = model_class.gridsearch(
best_model_1, _, _ = model_class.gridsearch(
params,
series,
forecast_horizon=10,
Expand All @@ -57,7 +67,9 @@ def compare_best_against_random(model_class, params, series, stride=1):

# instantiate best model in split mode
train, val = series.split_before(series.time_index[-10])
best_model_2, _ = model_class.gridsearch(params, train, val_series=val, metric=mape)
best_model_2, _, _ = model_class.gridsearch(
params, train, val_series=val, metric=mape
)

# instantiate model with random parameters from 'params'
random.seed(1)
Expand Down Expand Up @@ -297,12 +309,7 @@ def test_backtest_regression(self):
def test_gridsearch(self):
np.random.seed(1)

ts_length = 50
dummy_series = (
lt(length=ts_length, end_value=10)
+ st(length=ts_length, value_y_offset=10)
+ rt(length=ts_length)
)
dummy_series = get_dummy_series(ts_length=50)
dummy_series_int_index = TimeSeries.from_values(dummy_series.values())

theta_params = {"theta": list(range(3, 10))}
Expand All @@ -322,16 +329,34 @@ def test_gridsearch(self):
compare_best_against_random(ExponentialSmoothing, es_params, dummy_series)
)

def test_gridsearch_metric_score(self):
    """The score returned by gridsearch must equal a fresh backtest of the winner.

    Runs a small theta grid search, then re-evaluates the returned (untrained)
    best model with ``backtest`` using the exact same evaluation settings, and
    checks both scores coincide.
    """
    np.random.seed(1)

    search_space = {"theta": list(range(3, 6))}
    series = get_dummy_series(ts_length=50)
    # Shared evaluation settings so gridsearch and backtest score identically.
    eval_kwargs = dict(
        series=series,
        start=series.time_index[-21],
        forecast_horizon=10,
        stride=1,
    )

    best_model, _, score = Theta.gridsearch(search_space, **eval_kwargs)
    recalculated_score = best_model.backtest(**eval_kwargs)

    self.assertEqual(score, recalculated_score, "The metric scores should match")

@unittest.skipUnless(TORCH_AVAILABLE, "requires torch")
def test_gridsearch_random_search(self):
np.random.seed(1)

ts_length = 50
dummy_series = (
lt(length=ts_length, end_value=10)
+ st(length=ts_length, value_y_offset=10)
+ rt(length=ts_length)
)
dummy_series = get_dummy_series(ts_length=50)

param_range = list(range(10, 20))
params = {"lags": param_range}
Expand All @@ -343,16 +368,12 @@ def test_gridsearch_random_search(self):

self.assertEqual(type(result[0]), RandomForest)
self.assertEqual(type(result[1]["lags"]), int)
self.assertEqual(type(result[2]), float)
self.assertTrue(min(param_range) <= result[1]["lags"] <= max(param_range))

@unittest.skipUnless(TORCH_AVAILABLE, "requires torch")
def test_gridsearch_n_random_samples_bad_arguments(self):
ts_length = 50
dummy_series = (
lt(length=ts_length, end_value=10)
+ st(length=ts_length, value_y_offset=10)
+ rt(length=ts_length)
)
dummy_series = get_dummy_series(ts_length=50)

params = {"lags": list(range(1, 11)), "past_covariates": list(range(1, 11))}

Expand Down Expand Up @@ -398,16 +419,11 @@ def test_gridsearch_n_jobs(self):
"""

np.random.seed(1)
ts_length = 100

dummy_series = (
lt(length=ts_length, end_value=1)
+ st(length=ts_length, value_y_offset=0)
+ rt(length=ts_length)
dummy_series = get_dummy_series(
ts_length=100, lt_end_value=1, st_value_offset=0
).astype(np.float32)

ts_train = dummy_series[: round(ts_length * 0.8)]
ts_val = dummy_series[round(ts_length * 0.8) :]
ts_train, ts_val = dummy_series.split_before(split_point=0.8)

test_cases = [
{
Expand All @@ -433,12 +449,12 @@ def test_gridsearch_n_jobs(self):
parameters = test["parameters"]

np.random.seed(1)
_, best_params1 = model.gridsearch(
_, best_params1, _ = model.gridsearch(
parameters=parameters, series=ts_train, val_series=ts_val, n_jobs=1
)

np.random.seed(1)
_, best_params2 = model.gridsearch(
_, best_params2, _ = model.gridsearch(
parameters=parameters, series=ts_train, val_series=ts_val, n_jobs=-1
)

Expand Down