Skip to content

Commit 8b88b0d

Browse files
Docs/general improvements (#1904)
* udpate dtw example and add dtw to rendered documentation * make all windows inherit from Window * clean up windows * improve dtw documentation * improved forecasting model module documentation * update models and add model links in covariates user guide * add model links to README * update changelog * fix typo in dtw example notebook * remove outdated lines from tide model from before probabilistic support * apply suggestions from PR review * update readme model table * Feat/fit predict encodings (#1925) * added encode_train_inference to encoders * added generate_fit_predict_encodings to ForecastingModel * simplify TransferrableFut..Model.generatice_predict_encodings * update changelog * Apply suggestions from code review Co-authored-by: madtoinou <[email protected]> * apply suggestions from PR review part 2 --------- Co-authored-by: madtoinou <[email protected]> * update readme * apply suggestions from code review and improve README.md * Update README.md Co-authored-by: madtoinou <[email protected]> --------- Co-authored-by: madtoinou <[email protected]>
1 parent 1af057a commit 8b88b0d

File tree

10 files changed

+282
-164
lines changed

10 files changed

+282
-164
lines changed

CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
3535
- Improvements to `Explainability` module:
3636
- 🚀🚀 New forecasting model explainer: `TFTExplainer` for `TFTModel`. You can now access and visualize the trained model's feature importances and self attention. [#1392](https://github.com/unit8co/darts/issues/1392) by [Sebastian Cattes](https://github.com/Cattes) and [Dennis Bader](https://github.com/dennisbader).
3737
- Added static covariates support to `ShapeExplainer`. [#1803](https://github.com/unit8co/darts/pull/1803) by [Anne de Vries](https://github.com/anne-devries) and [Dennis Bader](https://github.com/dennisbader).
38+
- Improvements to documentation [#1904](https://github.com/unit8co/darts/pull/1904) by [Dennis Bader](https://github.com/dennisbader):
39+
- made model sections in README.md, covariates user guide and forecasting model API Reference more user friendly by adding model links and reorganizing them into model categories.
40+
- added the Dynamic Time Warping (DTW) module and improved its appearance.
3841
- Other improvements:
3942
- Improved static covariates column naming when using `StaticCovariatesTransformer` with a `sklearn.preprocessing.OneHotEncoder`. [#1863](https://github.com/unit8co/darts/pull/1863) by [Anne de Vries](https://github.com/anne-devries).
4043
- Added `MSTL` (Season-Trend decomposition using LOESS for multiple seasonalities) as a `method` option for `extract_trend_and_seasonality()`. [#1879](https://github.com/unit8co/darts/pull/1879) by [Alex Colpitts](https://github.com/alexcolpitts96).

README.md

+47-38
Large diffs are not rendered by default.

darts/dataprocessing/dtw/__init__.py

+5
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
"""
2+
Dynamic Time Warping (DTW)
3+
--------------------------
4+
"""
5+
16
from .cost_matrix import CostMatrix
27
from .dtw import DTWAlignment, dtw
38
from .window import CRWindow, Itakura, NoWindow, SakoeChiba, Window

darts/dataprocessing/dtw/dtw.py

+52-27
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
1+
"""
2+
Dynamic Time Warping (DTW)
3+
--------------------------
4+
"""
5+
16
import copy
2-
from typing import Callable, Union
7+
from typing import Callable, Optional, Union
38

49
import numpy as np
510
import pandas as pd
@@ -21,7 +26,7 @@
2126
# CORE ALGORITHM
2227
def _dtw_cost_matrix(
2328
x: np.ndarray, y: np.ndarray, dist: DistanceFunc, window: Window
24-
) -> np.ndarray:
29+
) -> CostMatrix:
2530

2631
dtw = CostMatrix._from_window(window)
2732

@@ -138,16 +143,40 @@ def _fast_dtw(
138143
return cost
139144

140145

146+
def _default_distance_multi(x_values: np.ndarray, y_values: np.ndarray):
147+
return np.sum(np.abs(x_values - y_values))
148+
149+
150+
def _default_distance_uni(x_value: float, y_value: float):
151+
return abs(x_value - y_value)
152+
153+
141154
# Public API Functions
142155
class DTWAlignment:
156+
"""
157+
Dynamic Time Warping (DTW) Alignment.
158+
159+
Attributes
160+
----------
161+
n
162+
The length of `series1`
163+
m
164+
The length of `series2`
165+
series1
166+
A `TimeSeries` to align with `series2`.
167+
series2
168+
A `TimeSeries` to align with `series1`.
169+
cost
170+
The `CostMatrix` for DTW.
171+
"""
172+
143173
n: int
144174
m: int
145175
series1: TimeSeries
146176
series2: TimeSeries
147177
cost: CostMatrix
148178

149179
def __init__(self, series1: TimeSeries, series2: TimeSeries, cost: CostMatrix):
150-
151180
self.n = len(series1)
152181
self.m = len(series2)
153182
self.series1 = series1
@@ -157,7 +186,8 @@ def __init__(self, series1: TimeSeries, series2: TimeSeries, cost: CostMatrix):
157186
from ._plot import plot, plot_alignment
158187

159188
def path(self) -> np.ndarray:
160-
"""
189+
"""Gives the index paths from `series1` to `series2`.
190+
161191
Returns
162192
-------
163193
np.ndarray of shape `(len(path), 2)`
@@ -172,7 +202,8 @@ def path(self) -> np.ndarray:
172202
return self._path
173203

174204
def distance(self) -> float:
175-
"""
205+
"""Gives the total distance between pair-wise elements in the two series after warping.
206+
176207
Returns
177208
-------
178209
float
@@ -181,7 +212,8 @@ def distance(self) -> float:
181212
return self.cost[(self.n, self.m)]
182213

183214
def mean_distance(self) -> float:
184-
"""
215+
"""Gives the mean distance between pair-wise elements in the two series after warping.
216+
185217
Returns
186218
-------
187219
float
@@ -195,9 +227,8 @@ def mean_distance(self) -> float:
195227
return self._mean_distance
196228

197229
def warped(self) -> (TimeSeries, TimeSeries):
198-
"""
199-
Warps the two time series according to the warp path returned by .path(), which minimizes
200-
the pair-wise distance.
230+
"""Warps the two time series according to the warp path returned by `DTWAlignment.path()`, which minimizes the
231+
pair-wise distance.
201232
This will bring two time series that are out-of-phase back into phase.
202233
203234
Returns
@@ -254,24 +285,16 @@ def warped(self) -> (TimeSeries, TimeSeries):
254285
)
255286

256287

257-
def default_distance_multi(x_values: np.ndarray, y_values: np.ndarray):
258-
return np.sum(np.abs(x_values - y_values))
259-
260-
261-
def default_distance_uni(x_value: float, y_value: float):
262-
return abs(x_value - y_value)
263-
264-
265288
def dtw(
266289
series1: TimeSeries,
267290
series2: TimeSeries,
268-
window: Window = NoWindow(),
291+
window: Optional[Window] = None,
269292
distance: Union[DistanceFunc, None] = None,
270293
multi_grid_radius: int = -1,
271294
) -> DTWAlignment:
272295
"""
273-
Determines the optimal alignment between two time series series1 and series2,
274-
according to the Dynamic Time Warping algorithm.
296+
Determines the optimal alignment between two time series `series1` and `series2`, according to the Dynamic Time
297+
Warping algorithm.
275298
The alignment minimizes the distance between pair-wise elements after warping.
276299
All elements in the two series are matched and are in strictly monotonically increasing order.
277300
Considers only the values in the series, ignoring the time axis.
@@ -282,24 +305,24 @@ def dtw(
282305
Parameters
283306
----------
284307
series1
285-
`TimeSeries`
308+
A `TimeSeries` to align with `series2`.
286309
series2
287-
A `TimeSeries`
310+
A `TimeSeries` to align with `series1`.
288311
window
289-
Used to constrain the search for the optimal alignment: see SakoeChiba and Itakura.
290-
Default considers all possible alignments.
312+
Optionally, a `Window` used to constrain the search for the optimal alignment: see `SakoeChiba` and `Itakura`.
313+
Default considers all possible alignments (`NoWindow`).
291314
distance
292315
Function taking as input either two `floats` for univariate series or two `np.ndarray`,
293316
and returning the distance between them.
294317
295318
Defaults to the abs difference for univariate-data and the
296319
sum of the abs difference for multi-variate series.
297320
multi_grid_radius
298-
Default radius of -1 results in an exact evaluation of the dynamic time warping algorithm.
321+
Default radius of `-1` results in an exact evaluation of the dynamic time warping algorithm.
299322
Without constraints DTW runs in O(nxm) time where n,m are the size of the series.
300323
Exact evaluation with no constraints, will result in a performance warning on large datasets.
301324
302-
Setting multi_grid_radius to a value other than -1, will enable the approximate multi-grid solver,
325+
Setting `multi_grid_radius` to a value other than `-1`, will enable the approximate multi-grid solver,
303326
which executes in linear time, vs quadratic time for exact evaluation.
304327
Increasing radius trades solution accuracy for performance.
305328
@@ -308,6 +331,8 @@ def dtw(
308331
DTWAlignment
309332
Helper object for getting warp path, mean_distance, distance and warped time series
310333
"""
334+
if window is None:
335+
window = NoWindow()
311336

312337
if (
313338
multi_grid_radius == -1
@@ -328,7 +353,7 @@ def dtw(
328353
logger,
329354
)
330355

331-
distance = default_distance_uni if both_univariate else default_distance_multi
356+
distance = _default_distance_uni if both_univariate else _default_distance_multi
332357

333358
if both_univariate:
334359
values_x = series1.univariate_values(copy=False)

0 commit comments

Comments
 (0)