diff --git a/CHANGELOG.md b/CHANGELOG.md index fa602543ed..b0a4c4e66d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co ### For users of the library: **Improved** +- Added `darts.utils.statistics.plot_ccf` that can be used to plot the cross correlation between a time series (e.g. target series) and the lagged values of another time series (e.g. covariates series). [#2122](https://github.com/unit8co/darts/pull/2122) by [Dennis Bader](https://github.com/dennisbader). - Improvements to `TimeSeries`: - Improved the time series frequency inference when using slices or pandas DatetimeIndex as keys for `__getitem__`. [#2152](https://github.com/unit8co/darts/pull/2152) by [DavidKleindienst](https://github.com/DavidKleindienst). diff --git a/darts/tests/utils/test_statistics.py b/darts/tests/utils/test_statistics.py index 6b3ace96e3..e7775645c9 100644 --- a/darts/tests/utils/test_statistics.py +++ b/darts/tests/utils/test_statistics.py @@ -9,6 +9,8 @@ check_seasonality, extract_trend_and_seasonality, granger_causality_tests, + plot_acf, + plot_ccf, plot_pacf, plot_residuals_analysis, remove_seasonality, @@ -235,5 +237,7 @@ def test_statistics_plot(self): plt.close() plot_residuals_analysis(self.series[:10]) plt.close() + plot_acf(self.series) plot_pacf(self.series) + plot_ccf(self.series, self.series) plt.close() diff --git a/darts/utils/statistics.py b/darts/utils/statistics.py index 7b1fbdf601..faf4d1304c 100644 --- a/darts/utils/statistics.py +++ b/darts/utils/statistics.py @@ -10,8 +10,16 @@ import numpy as np from scipy.signal import argrelmax from scipy.stats import norm +from statsmodels.compat.python import lzip from statsmodels.tsa.seasonal import MSTL, STL, seasonal_decompose -from statsmodels.tsa.stattools import acf, adfuller, grangercausalitytests, kpss, pacf +from statsmodels.tsa.stattools import ( + acf, + adfuller, + ccovf, + grangercausalitytests, + kpss, + pacf, +) from darts import TimeSeries from darts.logging import get_logger, raise_if, raise_if_not, raise_log @@ -599,8 +607,8 @@ def plot_acf( default_formatting: bool = True, ) -> None: """ - Plots the ACF of `ts`, highlighting it at lag `m`, with corresponding significance interval. - Uses :func:`statsmodels.tsa.stattools.acf` [1]_ + Plots the Autocorrelation Function (ACF) of `ts`, highlighting it at lag `m`, with corresponding significance + interval. Uses :func:`statsmodels.tsa.stattools.acf` [1]_ Parameters ---------- @@ -695,8 +703,8 @@ def plot_pacf( default_formatting: bool = True, ) -> None: """ - Plots the Partial ACF of `ts`, highlighting it at lag `m`, with corresponding significance interval. - Uses :func:`statsmodels.tsa.stattools.pacf` [1]_ + Plots the Partial Autocorrelation Function (PACF) of `ts`, highlighting it at lag `m`, with corresponding + significance interval. Uses :func:`statsmodels.tsa.stattools.pacf` [1]_ Parameters ---------- @@ -785,6 +793,124 @@ def plot_pacf( axis.plot((0, max_lag + 1), (0, 0), color="black" if default_formatting else None) +def plot_ccf( + ts: TimeSeries, + ts_other: TimeSeries, + m: Optional[int] = None, + max_lag: int = 24, + alpha: float = 0.05, + bartlett_confint: bool = True, + fig_size: Tuple[int, int] = (10, 5), + axis: Optional[plt.axis] = None, + default_formatting: bool = True, +) -> None: + """ + Plots the Cross Correlation Function (CCF) between `ts` and `ts_other`, highlighting it at lag `m`, with + corresponding significance interval. Uses :func:`statsmodels.tsa.stattools.ccf` [1]_ + + This can be used to find the cross correlation between the target and different covariates lags. + If `ts_other` is identical `ts`, it corresponds to `plot_acf()`. + + Parameters + ---------- + ts + The TimeSeries whose CCF with `ts_other` should be plotted. + ts_other + The TimeSeries which to compare against `ts` in the CCF. E.g. check the cross correlation of different + covariate lags with the target. + m + Optionally, a time lag to highlight on the plot. + max_lag + The maximal lag order to consider. + alpha + The confidence interval to display. + bartlett_confint + The boolean value indicating whether the confidence interval should be + calculated using Bartlett's formula. + fig_size + The size of the figure to be displayed. + axis + Optionally, an axis object to plot the CCF on. + default_formatting + Whether to use the darts default scheme. + + References + ---------- + .. [1] https://www.statsmodels.org/dev/generated/statsmodels.tsa.stattools.ccf.html + """ + + ts._assert_univariate() + ts_other._assert_univariate() + raise_if( + max_lag is None or not (1 <= max_lag < len(ts)), + "max_lag must be greater than or equal to 1 and less than len(ts).", + ) + raise_if( + m is not None and not (0 <= m <= max_lag), + "m must be greater than or equal to 0 and less than or equal to max_lag.", + ) + raise_if( + alpha is None or not (0 < alpha < 1), + "alpha must be greater than 0 and less than 1.", + ) + ts_other = ts_other.slice_intersect(ts) + if len(ts_other) != len(ts): + raise_log( + ValueError("`ts_other` must contain at least the full time index of `ts`."), + logger=logger, + ) + + x = ts.values() + y = ts_other.values() + cvf = ccovf(x=x, y=y, adjusted=True, demean=True, fft=False) + + ccf = cvf / (np.std(x) * np.std(y)) + ccf = ccf[: max_lag + 1] + + n_obs = len(x) + if bartlett_confint: + varccf = np.ones_like(ccf) / n_obs + varccf[0] = 0 + varccf[1] = 1.0 / n_obs + varccf[2:] *= 1 + 2 * np.cumsum(ccf[1:-1] ** 2) + else: + varccf = 1.0 / n_obs + + interval = norm.ppf(1.0 - alpha / 2.0) * np.sqrt(varccf) + confint = np.array(lzip(ccf - interval, ccf + interval)) + + if axis is None: + plt.figure(figsize=fig_size) + axis = plt + + for i in range(len(ccf)): + axis.plot( + (i, i), + (0, ccf[i]), + color=("#b512b8" if m is not None and i == m else "black") + if default_formatting + else None, + lw=(1 if m is not None and i == m else 0.5), + ) + + # Adjusts the upper band of the confidence interval to center it on the x axis. + upp_band = [confint[lag][1] - ccf[lag] for lag in range(1, max_lag + 1)] + + # Setting color t0 None overrides custom settings + extra_arguments = {} + if default_formatting: + extra_arguments["color"] = "#003DFD" + + axis.fill_between( + np.arange(1, max_lag + 1), + upp_band, + [-x for x in upp_band], + alpha=0.25 if default_formatting else None, + **extra_arguments, + ) + axis.plot((0, max_lag + 1), (0, 0), color="black" if default_formatting else None) + + def plot_hist( data: Union[TimeSeries, List[float], np.ndarray], bins: Optional[Union[int, np.ndarray, List[float]]] = None,