Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvement/statsforecastets: make sf_ets probabilistic + add future_covariate support for sf_ets + add AutoTheta #1476

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
3e571c3
StatsForecastETS now is probabilistic in the same way as StatsForecas…
Beerstabr Jan 5, 2023
0a2bd18
Merge pull request #1 from Beerstabr/master
Beerstabr Jan 5, 2023
60253a5
include future covariates in sf_ets
Beerstabr Jan 6, 2023
c7f20a3
sf_ets with future_covariates works.. probably it is underestimating …
Beerstabr Jan 7, 2023
ece2bad
Create separate file for StatsForecast models and extract some functi…
Beerstabr Jan 8, 2023
4131cb0
Added AutoTheta from the StatsForecast package.
Beerstabr Jan 9, 2023
2921d52
Deleted sf_auto_arima.py and sf_ets.py, because the code is now inclu…
Beerstabr Jan 9, 2023
049970c
Merge branch 'master' into improvement/statsforecastets_probabilistic
Beerstabr Jan 9, 2023
30aff7d
Merge branch 'master' into improvement/statsforecastets_probabilistic
hrzn Jan 20, 2023
1326065
Merge branch 'master' into improvement/statsforecastets_probabilistic
hrzn Jan 24, 2023
6a514d9
Update darts/models/forecasting/sf_models.py
Beerstabr Jan 26, 2023
b81fddb
Update darts/models/forecasting/sf_models.py
Beerstabr Jan 26, 2023
2e07d3f
Merge branch 'master' into improvement/statsforecastets_probabilistic
hrzn Jan 27, 2023
67d48fa
Moved all statsforecast models to their own .py file. Added some comm…
Beerstabr Jan 27, 2023
6a2f73f
Beginning of test for fit on residuals for statsforecast ets.
Beerstabr Jan 28, 2023
6ea16b4
- AutoCES not probablisitc anymore, because that is not yet released …
Beerstabr Jan 31, 2023
b74c49a
- AutoCES not probablisitc anymore, because that is not yet released …
Beerstabr Jan 31, 2023
91bbe59
Merge branch 'master' into improvement/statsforecastets_probabilistic
hrzn Feb 7, 2023
83f08da
Changed StatsForecastETS to StatsForecastAutoETS.
Beerstabr Feb 7, 2023
a0a322b
Merge remote-tracking branch 'origin/improvement/statsforecastets_pro…
Beerstabr Feb 7, 2023
fc20e77
Merge branch 'master' into improvement/statsforecastets_probabilistic
hrzn Feb 10, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions darts/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,15 @@ class NotImportedCatBoostModel:
try:
from darts.models.forecasting.croston import Croston
from darts.models.forecasting.sf_auto_arima import StatsForecastAutoARIMA
from darts.models.forecasting.sf_ets import StatsForecastETS
from darts.models.forecasting.sf_auto_ces import StatsForecastAutoCES
from darts.models.forecasting.sf_auto_ets import StatsForecastAutoETS
from darts.models.forecasting.sf_auto_theta import StatsForecastAutoTheta

except ImportError:
logger.warning(
"The statsforecast module could not be imported. "
"To enable support for the StatsForecastAutoARIMA, "
"StatsForecastETS and Croston models, please consider "
"StatsForecastAutoETS and Croston models, please consider "
"installing it."
)

Expand All @@ -104,10 +107,10 @@ class NotImportedStatsForecastAutoARIMA:

StatsForecastAutoARIMA = NotImportedStatsForecastAutoARIMA()

class NotImportedStatsForecastETS:
class NotImportedStatsForecastAutoETS:
usable = False

StatsForecastETS = NotImportedStatsForecastETS()
StatsForecastAutoETS = NotImportedStatsForecastAutoETS()

class NotImportedCroston:
usable = False
Expand Down
30 changes: 30 additions & 0 deletions darts/models/components/statsforecast_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""
StatsForecast utils
-----------
"""

import numpy as np

# In a normal distribution, 68.27 percentage of values lie within one standard deviation of the mean
one_sigma_rule = 68.27


def create_normal_samples(
mu: float,
std: float,
num_samples: int,
n: int,
) -> np.array:
"""Generate samples assuming a Normal distribution."""
samples = np.random.normal(loc=mu, scale=std, size=(num_samples, n)).T
samples = np.expand_dims(samples, axis=1)
return samples


def unpack_sf_dict(
forecast_dict: dict,
):
"""Unpack the dictionary that is returned by the StatsForecast 'predict()' method."""
mu = forecast_dict["mean"]
std = forecast_dict[f"hi-{one_sigma_rule}"] - mu
return mu, std
16 changes: 9 additions & 7 deletions darts/models/forecasting/sf_auto_arima.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,14 @@

from typing import Optional

import numpy as np
from statsforecast.models import AutoARIMA as SFAutoARIMA

from darts import TimeSeries
from darts.models.components.statsforecast_utils import (
create_normal_samples,
one_sigma_rule,
unpack_sf_dict,
)
from darts.models.forecasting.forecasting_model import (
FutureCovariatesLocalForecastingModel,
)
Expand Down Expand Up @@ -91,17 +95,15 @@ def _predict(
verbose: bool = False,
):
super()._predict(n, future_covariates, num_samples)
forecast_df = self.model.predict(
forecast_dict = self.model.predict(
h=n,
X=future_covariates.values(copy=False) if future_covariates else None,
level=(68.27,), # ask one std for the confidence interval.
level=(one_sigma_rule,), # ask one std for the confidence interval.
)

mu = forecast_df["mean"]
mu, std = unpack_sf_dict(forecast_dict)
if num_samples > 1:
std = forecast_df["hi-68.27"] - mu
samples = np.random.normal(loc=mu, scale=std, size=(num_samples, n)).T
samples = np.expand_dims(samples, axis=1)
samples = create_normal_samples(mu, std, num_samples, n)
else:
samples = mu

Expand Down
80 changes: 80 additions & 0 deletions darts/models/forecasting/sf_auto_ces.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""
StatsForecastAutoCES
-----------
"""

from statsforecast.models import AutoCES as SFAutoCES

from darts import TimeSeries
from darts.models.forecasting.forecasting_model import LocalForecastingModel


class StatsForecastAutoCES(LocalForecastingModel):
def __init__(self, *autoces_args, **autoces_kwargs):
"""Auto-CES based on `Statsforecasts package
<https://github.com/Nixtla/statsforecast>`_.

Automatically selects the best Complex Exponential Smoothing model using an information criterion.
<https://onlinelibrary.wiley.com/doi/full/10.1002/nav.22074>

We refer to the `statsforecast AutoCES documentation
<https://nixtla.github.io/statsforecast/models.html#autoces>`_
for the documentation of the arguments.

Parameters
----------
autoces_args
Positional arguments for ``statsforecasts.models.AutoCES``.
autoces_kwargs
Keyword arguments for ``statsforecasts.models.AutoCES``.

..

Examples
--------
>>> from darts.models import StatsForecastAutoCES
>>> from darts.datasets import AirPassengersDataset
>>> series = AirPassengersDataset().load()
>>> model = StatsForecastAutoCES(season_length=12)
>>> model.fit(series[:-36])
>>> pred = model.predict(36, num_samples=100)
"""
super().__init__()
self.model = SFAutoCES(*autoces_args, **autoces_kwargs)

def __str__(self):
return "Auto-CES-Statsforecasts"

def fit(self, series: TimeSeries):
super().fit(series)
self._assert_univariate(series)
series = self.training_series
self.model.fit(
series.values(copy=False).flatten(),
)
return self

def predict(
self,
n: int,
num_samples: int = 1,
verbose: bool = False,
):
super().predict(n, num_samples)
forecast_dict = self.model.predict(
h=n,
)

mu = forecast_dict["mean"]

return self._build_forecast_series(mu)

@property
def min_train_series_length(self) -> int:
return 10

def _supports_range_index(self) -> bool:
return True

def _is_probabilistic(self) -> bool:
return False
Original file line number Diff line number Diff line change
@@ -1,19 +1,25 @@
"""
StatsForecastETS
StatsForecastAutoETS
-----------
"""

from typing import Optional

from statsforecast.models import ETS
from statsforecast.models import AutoETS as SFAutoETS

from darts import TimeSeries
from darts.models import LinearRegressionModel
from darts.models.components.statsforecast_utils import (
create_normal_samples,
one_sigma_rule,
unpack_sf_dict,
)
from darts.models.forecasting.forecasting_model import (
FutureCovariatesLocalForecastingModel,
)


class StatsForecastETS(FutureCovariatesLocalForecastingModel):
class StatsForecastAutoETS(FutureCovariatesLocalForecastingModel):
def __init__(self, *ets_args, add_encoders: Optional[dict] = None, **ets_kwargs):
"""ETS based on `Statsforecasts package
<https://github.com/Nixtla/statsforecast>`_.
Expand All @@ -25,6 +31,12 @@ def __init__(self, *ets_args, add_encoders: Optional[dict] = None, **ets_kwargs)
This model accepts the same arguments as the `statsforecast ETS
<https://nixtla.github.io/statsforecast/models.html#ets>`_. package.

In addition to the StatsForecast implementation, this model can handle future covariates. It does so by first
regressing the series against the future covariates using the :class:'LinearRegressionModel' model and then
running StatsForecast's AutoETS on the in-sample residuals from this original regression. This approach was
inspired by 'this post of Stephan Kolassa< https://stats.stackexchange.com/q/220885>'_.


Parameters
----------
season_length
Expand Down Expand Up @@ -64,14 +76,15 @@ def __init__(self, *ets_args, add_encoders: Optional[dict] = None, **ets_kwargs)
Examples
--------
>>> from darts.datasets import AirPassengersDataset
>>> from darts.models import StatsForecastETS
>>> from darts.models import StatsForecastAutoETS
>>> series = AirPassengersDataset().load()
>>> model = StatsForecastETS(season_length=12, model="AZZ")
>>> model = StatsForecastAutoETS(season_length=12, model="AZZ")
>>> model.fit(series[:-36])
>>> pred = model.predict(36)
"""
super().__init__(add_encoders=add_encoders)
self.model = ETS(*ets_args, **ets_kwargs)
self.model = SFAutoETS(*ets_args, **ets_kwargs)
self._linreg = None

def __str__(self):
return "ETS-Statsforecasts"
Expand All @@ -80,9 +93,25 @@ def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = Non
super()._fit(series, future_covariates)
self._assert_univariate(series)
series = self.training_series

if future_covariates is not None:
# perform OLS and get in-sample residuals
linreg = LinearRegressionModel(lags_future_covariates=[0])
linreg.fit(series, future_covariates=future_covariates)
fitted_values = linreg.model.predict(
X=future_covariates.slice_intersect(series).values(copy=False)
)
fitted_values_ts = TimeSeries.from_times_and_values(
times=series.time_index, values=fitted_values
)
resids = series - fitted_values_ts
self._linreg = linreg
target = resids
else:
target = series

self.model.fit(
series.values(copy=False).flatten(),
X=future_covariates.values(copy=False) if future_covariates else None,
target.values(copy=False).flatten(),
)
return self

Expand All @@ -94,12 +123,27 @@ def _predict(
verbose: bool = False,
):
super()._predict(n, future_covariates, num_samples)
forecast_df = self.model.predict(
forecast_dict = self.model.predict(
h=n,
X=future_covariates.values(copy=False) if future_covariates else None,
level=(one_sigma_rule,), # ask one std for the confidence interval
)

return self._build_forecast_series(forecast_df["mean"])
mu_ets, std = unpack_sf_dict(forecast_dict)

if future_covariates is not None:
mu_linreg = self._linreg.predict(n, future_covariates=future_covariates)
mu_linreg_values = mu_linreg.values(copy=False).reshape(
n,
)
mu = mu_ets + mu_linreg_values
else:
mu = mu_ets

if num_samples > 1:
samples = create_normal_samples(mu, std, num_samples, n)
else:
samples = mu
return self._build_forecast_series(samples)

@property
def min_train_series_length(self) -> int:
Expand All @@ -109,4 +153,4 @@ def _supports_range_index(self) -> bool:
return True

def _is_probabilistic(self) -> bool:
return False
return True
Loading