Commit cb80cf3

khanetor and hrzn authored

return metric score along with untrained best models & params (#822)

* return metric score along with untrained best models & params
* update test and metric type
* fmt
* gridsearch score unittest

Co-authored-by: Julien Herzen <[email protected]>

1 parent 4920de0, commit cb80cf3
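Editor's note: with this change, callers unpack three values from `gridsearch` instead of two. A minimal usage sketch (not part of the commit); the `Theta` model and `mape` metric appear in the tests below, but the parameter grid and the `series` placeholder are illustrative only:

```python
# Sketch only: consuming the new 3-tuple returned by gridsearch after this change.
# `series` stands in for any darts TimeSeries; the theta grid is arbitrary.
from darts.metrics import mape
from darts.models import Theta

best_model, best_params, best_score = Theta.gridsearch(
    parameters={"theta": [2, 3, 4]},
    series=series,
    forecast_horizon=10,
    metric=mape,
)
# best_model is an *untrained* Theta built from best_params;
# best_score is the metric value (here MAPE) achieved by that configuration.
```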

File tree: 3 files changed, +57 −38 lines

darts/models/forecasting/forecasting_model.py (+10 −7)

```diff
@@ -569,7 +569,7 @@ def gridsearch(
         verbose=False,
         n_jobs: int = 1,
         n_random_samples: Optional[Union[int, float]] = None,
-    ) -> Tuple["ForecastingModel", Dict]:
+    ) -> Tuple["ForecastingModel", Dict[str, Any], float]:
         """
         Find the best hyper-parameters among a given set using a grid search.
@@ -660,9 +660,10 @@ def gridsearch(

         Returns
         -------
-        ForecastingModel, Dict
+        ForecastingModel, Dict, float
             A tuple containing an untrained `model_class` instance created from the best-performing hyper-parameters,
-            along with a dictionary containing these best hyper-parameters.
+            along with a dictionary containing these best hyper-parameters,
+            and metric score for the best hyper-parameters.
         """
         raise_if_not(
             (forecast_horizon is not None)
@@ -707,7 +708,7 @@ def gridsearch(
             zip(params_cross_product), verbose, total=len(params_cross_product)
         )

-        def _evaluate_combination(param_combination):
+        def _evaluate_combination(param_combination) -> float:
             param_combination_dict = dict(
                 list(zip(parameters.keys(), param_combination))
             )
@@ -748,9 +749,11 @@ def _evaluate_combination(param_combination):
                 )
             error = metric(pred, val_series)

-            return error
+            return float(error)

-        errors = _parallel_apply(iterator, _evaluate_combination, n_jobs, {}, {})
+        errors: List[float] = _parallel_apply(
+            iterator, _evaluate_combination, n_jobs, {}, {}
+        )

         min_error = min(errors)

@@ -760,7 +763,7 @@ def _evaluate_combination(param_combination):

         logger.info("Chosen parameters: " + str(best_param_combination))

-        return model_class(**best_param_combination), best_param_combination
+        return model_class(**best_param_combination), best_param_combination, min_error

     def residuals(
         self, series: TimeSeries, forecast_horizon: int = 1, verbose: bool = False
```
darts/tests/models/forecasting/test_4theta.py (+1 −1)

```diff
@@ -62,7 +62,7 @@ def test_best_model(self):
         series = sine_series + linear_series
         train_series, val_series = series.split_before(series.time_index[-10])
         thetas = np.linspace(-3, 3, 30)
-        best_model, _ = FourTheta.select_best_model(train_series, thetas)
+        best_model, _, _ = FourTheta.select_best_model(train_series, thetas)
         model = FourTheta(
             random.choice(thetas),
             model_mode=random.choice(list(ModelMode)),
```

darts/tests/models/forecasting/test_backtesting.py (+46 −30)

```diff
@@ -42,11 +42,21 @@
     TORCH_AVAILABLE = False


+def get_dummy_series(
+    ts_length: int, lt_end_value: int = 10, st_value_offset: int = 10
+) -> TimeSeries:
+    return (
+        lt(length=ts_length, end_value=lt_end_value)
+        + st(length=ts_length, value_y_offset=st_value_offset)
+        + rt(length=ts_length)
+    )
+
+
 def compare_best_against_random(model_class, params, series, stride=1):

     # instantiate best model in expanding window mode
     np.random.seed(1)
-    best_model_1, _ = model_class.gridsearch(
+    best_model_1, _, _ = model_class.gridsearch(
         params,
         series,
         forecast_horizon=10,
@@ -57,7 +67,9 @@ def compare_best_against_random(model_class, params, series, stride=1):

     # instantiate best model in split mode
     train, val = series.split_before(series.time_index[-10])
-    best_model_2, _ = model_class.gridsearch(params, train, val_series=val, metric=mape)
+    best_model_2, _, _ = model_class.gridsearch(
+        params, train, val_series=val, metric=mape
+    )

     # intantiate model with random parameters from 'params'
     random.seed(1)
@@ -297,12 +309,7 @@ def test_backtest_regression(self):
     def test_gridsearch(self):
         np.random.seed(1)

-        ts_length = 50
-        dummy_series = (
-            lt(length=ts_length, end_value=10)
-            + st(length=ts_length, value_y_offset=10)
-            + rt(length=ts_length)
-        )
+        dummy_series = get_dummy_series(ts_length=50)
         dummy_series_int_index = TimeSeries.from_values(dummy_series.values())

         theta_params = {"theta": list(range(3, 10))}
@@ -322,16 +329,34 @@ def test_gridsearch(self):
             compare_best_against_random(ExponentialSmoothing, es_params, dummy_series)
         )

+    def test_gridsearch_metric_score(self):
+        np.random.seed(1)
+
+        model_class = Theta
+        params = {"theta": list(range(3, 6))}
+        dummy_series = get_dummy_series(ts_length=50)
+
+        best_model, _, score = model_class.gridsearch(
+            params,
+            series=dummy_series,
+            forecast_horizon=10,
+            stride=1,
+            start=dummy_series.time_index[-21],
+        )
+        recalculated_score = best_model.backtest(
+            series=dummy_series,
+            start=dummy_series.time_index[-21],
+            forecast_horizon=10,
+            stride=1,
+        )
+
+        self.assertEqual(score, recalculated_score, "The metric scores should match")
+
     @unittest.skipUnless(TORCH_AVAILABLE, "requires torch")
     def test_gridsearch_random_search(self):
         np.random.seed(1)

-        ts_length = 50
-        dummy_series = (
-            lt(length=ts_length, end_value=10)
-            + st(length=ts_length, value_y_offset=10)
-            + rt(length=ts_length)
-        )
+        dummy_series = get_dummy_series(ts_length=50)

         param_range = list(range(10, 20))
         params = {"lags": param_range}
@@ -343,16 +368,12 @@ def test_gridsearch_random_search(self):

         self.assertEqual(type(result[0]), RandomForest)
         self.assertEqual(type(result[1]["lags"]), int)
+        self.assertEqual(type(result[2]), float)
         self.assertTrue(min(param_range) <= result[1]["lags"] <= max(param_range))

     @unittest.skipUnless(TORCH_AVAILABLE, "requires torch")
     def test_gridsearch_n_random_samples_bad_arguments(self):
-        ts_length = 50
-        dummy_series = (
-            lt(length=ts_length, end_value=10)
-            + st(length=ts_length, value_y_offset=10)
-            + rt(length=ts_length)
-        )
+        dummy_series = get_dummy_series(ts_length=50)

         params = {"lags": list(range(1, 11)), "past_covariates": list(range(1, 11))}

@@ -398,16 +419,11 @@ def test_gridsearch_n_jobs(self):
         """

         np.random.seed(1)
-        ts_length = 100

-        dummy_series = (
-            lt(length=ts_length, end_value=1)
-            + st(length=ts_length, value_y_offset=0)
-            + rt(length=ts_length)
+        dummy_series = get_dummy_series(
+            ts_length=100, lt_end_value=1, st_value_offset=0
         ).astype(np.float32)
-
-        ts_train = dummy_series[: round(ts_length * 0.8)]
-        ts_val = dummy_series[round(ts_length * 0.8) :]
+        ts_train, ts_val = dummy_series.split_before(split_point=0.8)

         test_cases = [
             {
@@ -433,12 +449,12 @@ def test_gridsearch_n_jobs(self):
             parameters = test["parameters"]

             np.random.seed(1)
-            _, best_params1 = model.gridsearch(
+            _, best_params1, _ = model.gridsearch(
                 parameters=parameters, series=ts_train, val_series=ts_val, n_jobs=1
             )

             np.random.seed(1)
-            _, best_params2 = model.gridsearch(
+            _, best_params2, _ = model.gridsearch(
                 parameters=parameters, series=ts_train, val_series=ts_val, n_jobs=-1
             )
```