Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

return metric score along with untrained best models & params #822

Merged
17 changes: 10 additions & 7 deletions darts/models/forecasting/forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,7 @@ def gridsearch(
verbose=False,
n_jobs: int = 1,
n_random_samples: Optional[Union[int, float]] = None,
) -> Tuple["ForecastingModel", Dict]:
) -> Tuple["ForecastingModel", Dict[str, Any], float]:
"""
Find the best hyper-parameters among a given set using a grid search.

Expand Down Expand Up @@ -660,9 +660,10 @@ def gridsearch(

Returns
-------
ForecastingModel, Dict
ForecastingModel, Dict, float
A tuple containing an untrained `model_class` instance created from the best-performing hyper-parameters,
along with a dictionary containing these best hyper-parameters.
along with a dictionary containing these best hyper-parameters,
and the metric score achieved with these best hyper-parameters.
"""
raise_if_not(
(forecast_horizon is not None)
Expand Down Expand Up @@ -707,7 +708,7 @@ def gridsearch(
zip(params_cross_product), verbose, total=len(params_cross_product)
)

def _evaluate_combination(param_combination):
def _evaluate_combination(param_combination) -> float:
param_combination_dict = dict(
list(zip(parameters.keys(), param_combination))
)
Expand Down Expand Up @@ -748,9 +749,11 @@ def _evaluate_combination(param_combination):
)
error = metric(pred, val_series)

return error
return float(error)

errors = _parallel_apply(iterator, _evaluate_combination, n_jobs, {}, {})
errors: List[float] = _parallel_apply(
iterator, _evaluate_combination, n_jobs, {}, {}
)

min_error = min(errors)

Expand All @@ -760,7 +763,7 @@ def _evaluate_combination(param_combination):

logger.info("Chosen parameters: " + str(best_param_combination))

return model_class(**best_param_combination), best_param_combination
return model_class(**best_param_combination), best_param_combination, min_error

def residuals(
self, series: TimeSeries, forecast_horizon: int = 1, verbose: bool = False
Expand Down
2 changes: 1 addition & 1 deletion darts/tests/models/forecasting/test_4theta.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def test_best_model(self):
series = sine_series + linear_series
train_series, val_series = series.split_before(series.time_index[-10])
thetas = np.linspace(-3, 3, 30)
best_model, _ = FourTheta.select_best_model(train_series, thetas)
best_model, _, _ = FourTheta.select_best_model(train_series, thetas)
model = FourTheta(
random.choice(thetas),
model_mode=random.choice(list(ModelMode)),
Expand Down
11 changes: 7 additions & 4 deletions darts/tests/models/forecasting/test_backtesting.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def compare_best_against_random(model_class, params, series, stride=1):

# instantiate best model in expanding window mode
np.random.seed(1)
best_model_1, _ = model_class.gridsearch(
best_model_1, _, _ = model_class.gridsearch(
params,
series,
forecast_horizon=10,
Expand All @@ -57,7 +57,9 @@ def compare_best_against_random(model_class, params, series, stride=1):

# instantiate best model in split mode
train, val = series.split_before(series.time_index[-10])
best_model_2, _ = model_class.gridsearch(params, train, val_series=val, metric=mape)
best_model_2, _, _ = model_class.gridsearch(
params, train, val_series=val, metric=mape
)

# instantiate model with random parameters from 'params'
random.seed(1)
Expand Down Expand Up @@ -343,6 +345,7 @@ def test_gridsearch_random_search(self):

self.assertEqual(type(result[0]), RandomForest)
self.assertEqual(type(result[1]["lags"]), int)
self.assertEqual(type(result[2]), float)
self.assertTrue(min(param_range) <= result[1]["lags"] <= max(param_range))

@unittest.skipUnless(TORCH_AVAILABLE, "requires torch")
Expand Down Expand Up @@ -433,12 +436,12 @@ def test_gridsearch_n_jobs(self):
parameters = test["parameters"]

np.random.seed(1)
_, best_params1 = model.gridsearch(
_, best_params1, _ = model.gridsearch(
parameters=parameters, series=ts_train, val_series=ts_val, n_jobs=1
)

np.random.seed(1)
_, best_params2 = model.gridsearch(
_, best_params2, _ = model.gridsearch(
parameters=parameters, series=ts_train, val_series=ts_val, n_jobs=-1
)

Expand Down