Skip to content

Commit

Permalink
fix: time index intersection for coefficient of variation (unit8co#2202)
Browse files Browse the repository at this point in the history
* fix: properly take the intersected time indexes for the coefficient of variation

* fix: computing rmse on ndarray directly

* fix: forgot sqrt for rmse in coef of variation

* fix: update type of return in docstring, taking into consideration the multi_ts and multivariate decorator, which convert arrays into list

* update changelog

* update changelog

---------

Co-authored-by: dennisbader <[email protected]>
  • Loading branch information
madtoinou and dennisbader authored Feb 5, 2024
1 parent 8cb04f6 commit 1d7d854
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 20 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
- Added option to exclude some `group_cols` from being added as static covariates when using `TimeSeries.from_group_dataframe()` with parameter `drop_group_cols`.

**Fixed**
- Fixed a bug in `coefficient_of_variaton()` with `intersect=True`, where the coefficient was not computed on the intersection. [#2202](https://github.com/unit8co/darts/pull/2202) by [Antoine Madrona](https://github.com/madtoinou).

### For developers of the library:

Expand Down
40 changes: 20 additions & 20 deletions darts/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from functools import wraps
from inspect import signature
from typing import Callable, Optional, Sequence, Tuple, Union
from typing import Callable, List, Optional, Sequence, Tuple, Union
from warnings import warn

import numpy as np
Expand All @@ -27,7 +27,7 @@
# care of dealing with Sequence[TimeSeries] and multivariate TimeSeries on its own (See mase() implementation).


def multi_ts_support(func):
def multi_ts_support(func) -> Union[float, List[float]]:
"""
This decorator further adapts the metrics that took as input two univariate/multivariate ``TimeSeries`` instances,
adding support for equally-sized sequences of ``TimeSeries`` instances. The decorator computes the pairwise metric
Expand Down Expand Up @@ -107,7 +107,7 @@ def wrapper_multi_ts_support(*args, **kwargs):
return wrapper_multi_ts_support


def multivariate_support(func):
def multivariate_support(func) -> Union[float, List[float]]:
"""
This decorator transforms a metric function that takes as input two univariate TimeSeries instances
into a function that takes two equally-sized multivariate TimeSeries instances, computes the pairwise univariate
Expand Down Expand Up @@ -279,7 +279,7 @@ def mae(
Returns
-------
float
Union[float, List[float]]
The Mean Absolute Error (MAE)
"""

Expand Down Expand Up @@ -336,7 +336,7 @@ def mse(
Returns
-------
float
Union[float, List[float]]
The Mean Squared Error (MSE)
"""

Expand Down Expand Up @@ -393,7 +393,7 @@ def rmse(
Returns
-------
float
Union[float, List[float]]
The Root Mean Squared Error (RMSE)
"""
return np.sqrt(mse(actual_series, pred_series, intersect))
Expand Down Expand Up @@ -448,7 +448,7 @@ def rmsle(
Returns
-------
float
Union[float, List[float]]
The Root Mean Squared Log Error (RMSLE)
"""

Expand Down Expand Up @@ -510,15 +510,15 @@ def coefficient_of_variation(
Returns
-------
float
Union[float, List[float]]
The Coefficient of Variation
"""

return (
100
* rmse(actual_series, pred_series, intersect)
/ actual_series.pd_dataframe(copy=False).mean().mean()
y_true, y_pred = _get_values_or_raise(
actual_series, pred_series, intersect, remove_nan_union=True
)
# not calling rmse as y_true and y_pred are np.ndarray
return 100 * np.sqrt(np.mean((y_true - y_pred) ** 2)) / y_true.mean()


@multi_ts_support
Expand Down Expand Up @@ -577,7 +577,7 @@ def mape(
Returns
-------
float
Union[float, List[float]]
The Mean Absolute Percentage Error (MAPE)
"""

Expand Down Expand Up @@ -650,7 +650,7 @@ def smape(
Returns
-------
float
Union[float, List[float]]
The symmetric Mean Absolute Percentage Error (sMAPE)
"""

Expand Down Expand Up @@ -725,7 +725,7 @@ def mase(
Returns
-------
float
Union[float, List[float]]
The Mean Absolute Scaled Error (MASE)
"""

Expand Down Expand Up @@ -907,7 +907,7 @@ def ope(
Returns
-------
float
Union[float, List[float]]
The Overall Percentage Error (OPE)
"""

Expand Down Expand Up @@ -977,7 +977,7 @@ def marre(
Returns
-------
float
Union[float, List[float]]
The Mean Absolute Ranged Relative Error (MARRE)
"""

Expand Down Expand Up @@ -1042,7 +1042,7 @@ def r2_score(
Returns
-------
float
Union[float, List[float]]
The Coefficient of Determination :math:`R^2`
"""
y1, y2 = _get_values_or_raise(
Expand Down Expand Up @@ -1185,7 +1185,7 @@ def rho_risk(
Returns
-------
float
Union[float, List[float]]
The rho-risk metric
"""

Expand Down Expand Up @@ -1263,7 +1263,7 @@ def quantile_loss(
Returns
-------
float
Union[float, List[float]]
The quantile loss metric
"""

Expand Down

0 comments on commit 1d7d854

Please sign in to comment.