Commit 2c43352

gdevos010, Greg DeVos, hrzn, and dennisbader authored
Feat/torchmetrics (#996)
* added dropout and batch_norm similar to nhits' implementation
* activation param for nhits
* fix
* support any pytorch activation function. NHiTs AvgPool1d support
* PR
* Update CHANGELOG.md
* fixed typo
* pytorch lightning did not like saving nn.modules
* first pass
* metrics now works with likelihood
* rename var
* made metrics a parameter. Added Tests
* torchmetrics is already a dependency
* changelog
* model now accepts torchmetrics and torchCollections
* torchmetric example in early stop
* handle no metrics
* made _calculate_metrics private
* Ray tune example
* Ray tune example
* changelog
* Update darts/models/forecasting/pl_forecasting_module.py
  Co-authored-by: Dennis Bader <[email protected]>
* Update darts/models/forecasting/pl_forecasting_module.py
  Co-authored-by: Dennis Bader <[email protected]>
* added torch_metrics to doc strings in torch based models
* added torch_metrics to doc strings in torch based models
* Update darts/models/forecasting/pl_forecasting_module.py
* black formatting

Co-authored-by: Greg DeVos <[email protected]>
Co-authored-by: Julien Herzen <[email protected]>
Co-authored-by: Dennis Bader <[email protected]>
1 parent abf12da commit 2c43352

11 files changed: +303 -75 lines changed

CHANGELOG.md

+1
@@ -15,6 +15,7 @@ Darts is still in an early development phase and we cannot always guarantee back
   by [Greg DeVos](https://github.com/gdevos010)
 - Implemented ["GLU Variants Improve Transformer"](https://arxiv.org/abs/2002.05202) for transformer based models (transformer and TFT). [#959](https://github.com/unit8co/darts/issues/959)
   by [Greg DeVos](https://github.com/gdevos010)
+- Added support for torch metrics during training and validation. [#996](https://github.com/unit8co/darts/pull/996) by [Greg DeVos](https://github.com/gdevos010)

 ## [0.19.0](https://github.com/unit8co/darts/tree/0.19.0) (2022-04-13)
 ### For users of the library:

darts/models/forecasting/block_rnn_model.py

+3
@@ -178,6 +178,9 @@ def __init__(
             PyTorch loss function used for training.
             This parameter will be ignored for probabilistic models if the ``likelihood`` parameter is specified.
             Default: ``torch.nn.MSELoss()``.
+        torch_metrics
+            A torch metric or a ``MetricCollection`` used for evaluation. A full list of available metrics can be found
+            at https://torchmetrics.readthedocs.io/en/latest/. Default: ``None``.
         likelihood
             One of Darts' :meth:`Likelihood <darts.utils.likelihood_models.Likelihood>` models to be used for
             probabilistic forecasts. Default: ``None``.

darts/models/forecasting/nbeats.py

+3
@@ -598,6 +598,9 @@ def __init__(
             PyTorch loss function used for training.
             This parameter will be ignored for probabilistic models if the ``likelihood`` parameter is specified.
             Default: ``torch.nn.MSELoss()``.
+        torch_metrics
+            A torch metric or a ``MetricCollection`` used for evaluation. A full list of available metrics can be found
+            at https://torchmetrics.readthedocs.io/en/latest/. Default: ``None``.
         likelihood
             One of Darts' :meth:`Likelihood <darts.utils.likelihood_models.Likelihood>` models to be used for
             probabilistic forecasts. Default: ``None``.

darts/models/forecasting/nhits.py

+3
@@ -534,6 +534,9 @@ def __init__(
             PyTorch loss function used for training.
             This parameter will be ignored for probabilistic models if the ``likelihood`` parameter is specified.
             Default: ``torch.nn.MSELoss()``.
+        torch_metrics
+            A torch metric or a ``MetricCollection`` used for evaluation. A full list of available metrics can be found
+            at https://torchmetrics.readthedocs.io/en/latest/. Default: ``None``.
         likelihood
             One of Darts' :meth:`Likelihood <darts.utils.likelihood_models.Likelihood>` models to be used for
             probabilistic forecasts. Default: ``None``.

darts/models/forecasting/pl_forecasting_module.py

+55 -10
@@ -3,11 +3,12 @@
 """

 from abc import ABC, abstractmethod
-from typing import Any, Dict, Optional, Sequence, Tuple
+from typing import Any, Dict, Optional, Sequence, Tuple, Union

 import pytorch_lightning as pl
 import torch
 import torch.nn as nn
+import torchmetrics
 from joblib import Parallel, delayed

 from darts.logging import get_logger, raise_if, raise_log
@@ -29,10 +30,13 @@ def __init__(
         input_chunk_length: int,
         output_chunk_length: int,
         loss_fn: nn.modules.loss._Loss = nn.MSELoss(),
+        torch_metrics: Optional[
+            Union[torchmetrics.Metric, torchmetrics.MetricCollection]
+        ] = None,
         likelihood: Optional[Likelihood] = None,
         optimizer_cls: torch.optim.Optimizer = torch.optim.Adam,
         optimizer_kwargs: Optional[Dict] = None,
-        lr_scheduler_cls: torch.optim.lr_scheduler._LRScheduler = None,
+        lr_scheduler_cls: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
         lr_scheduler_kwargs: Optional[Dict] = None,
     ) -> None:
         """
@@ -58,6 +62,9 @@ def __init__(
             PyTorch loss function used for training.
             This parameter will be ignored for probabilistic models if the ``likelihood`` parameter is specified.
             Default: ``torch.nn.MSELoss()``.
+        torch_metrics
+            A torch metric or a ``MetricCollection`` used for evaluation. A full list of available metrics can be found
+            at https://torchmetrics.readthedocs.io/en/latest/. Default: ``None``.
         likelihood
             One of Darts' :meth:`Likelihood <darts.utils.likelihood_models.Likelihood>` models to be used for
             probabilistic forecasts. Default: ``None``.
@@ -76,7 +83,8 @@ def __init__(
         super().__init__()

         # save hyper parameters for saving/loading
-        self.save_hyperparameters()
+        # do not save type nn.Module params
+        self.save_hyperparameters(ignore=["loss_fn", "torch_metrics"])

         raise_if(
             input_chunk_length is None or output_chunk_length is None,
@@ -100,6 +108,22 @@ def __init__(
             dict() if lr_scheduler_kwargs is None else lr_scheduler_kwargs
         )

+        if torch_metrics is None:
+            torch_metrics = torchmetrics.MetricCollection([])
+        elif isinstance(torch_metrics, torchmetrics.Metric):
+            torch_metrics = torchmetrics.MetricCollection([torch_metrics])
+        elif isinstance(torch_metrics, torchmetrics.MetricCollection):
+            pass
+        else:
+            raise_log(
+                AttributeError(
+                    "`torch_metrics` only accepts type torchmetrics.Metric or torchmetrics.MetricCollection"
+                ),
+                logger,
+            )
+        self.train_metrics = torch_metrics.clone(prefix="train_")
+        self.val_metrics = torch_metrics.clone(prefix="val_")
+
         # initialize prediction parameters
         self.pred_n: Optional[int] = None
         self.pred_num_samples: Optional[int] = None
@@ -126,6 +150,7 @@ def training_step(self, train_batch, batch_idx) -> torch.Tensor:
         ]  # By convention target is always the last element returned by datasets
         loss = self._compute_loss(output, target)
         self.log("train_loss", loss, batch_size=train_batch[0].shape[0], prog_bar=True)
+        self._calculate_metrics(output, target, self.train_metrics)
         return loss

     def validation_step(self, val_batch, batch_idx) -> torch.Tensor:
@@ -134,6 +159,7 @@ def validation_step(self, val_batch, batch_idx) -> torch.Tensor:
         target = val_batch[-1]
         loss = self._compute_loss(output, target)
         self.log("val_loss", loss, batch_size=val_batch[0].shape[0], prog_bar=True)
+        self._calculate_metrics(output, target, self.val_metrics)
         return loss

     def predict_step(
@@ -154,7 +180,7 @@ def predict_step(
         # number of individual series to be predicted in current batch
         num_series = input_data_tuple[0].shape[0]

-        # number of of times the input tensor should be tiled to produce predictions for multiple samples
+        # number of times the input tensor should be tiled to produce predictions for multiple samples
         # this variable is larger than 1 only if the batch_size is at least twice as large as the number
         # of individual time series being predicted in current batch (`num_series`)
         batch_sample_size = min(
@@ -226,12 +252,31 @@ def _compute_loss(self, output, target):
         if self.likelihood:
             return self.likelihood.compute_loss(output, target)
         else:
-            # If there's no likelihood, nr_params=1 and we need to squeeze out the
+            # If there's no likelihood, nr_params=1, and we need to squeeze out the
             # last dimension of model output, for properly computing the loss.
             return self.criterion(output.squeeze(dim=-1), target)

+    def _calculate_metrics(self, output, target, metrics):
+        if not len(metrics):
+            return
+
+        if self.likelihood:
+            _metric = metrics(target, self.likelihood.sample(output))
+        else:
+            # If there's no likelihood, nr_params=1, and we need to squeeze out the
+            # last dimension of model output, for properly computing the metric.
+            _metric = metrics(target, output.squeeze(dim=-1))
+
+        self.log_dict(
+            _metric,
+            on_epoch=True,
+            on_step=False,
+            logger=True,
+            prog_bar=True,
+        )
+
     def configure_optimizers(self):
-        """configures optimizers and learning rate schedulers for for model optimization."""
+        """configures optimizers and learning rate schedulers for model optimization."""

         # A utility function to create optimizer and lr scheduler from desired classes
         def _create_from_cls_and_kwargs(cls, kws):
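
One detail worth checking in `_calculate_metrics`: torchmetrics metrics are called as `metric(preds, target)`, whereas the hunk above passes `target` first. For symmetric metrics such as MAE or MSE the order makes no difference, but for a percentage error it does. The plain-Python `mape` helper below is a hypothetical stand-in (not the torchmetrics implementation, and `eps` is an assumed guard value) that only illustrates the asymmetry.

```python
def mape(preds, target, eps=1e-8):
    """Mean absolute percentage error using the (preds, target) argument order."""
    total = sum(abs(p - t) / max(abs(t), eps) for p, t in zip(preds, target))
    return total / len(target)


preds = [2.0, 4.0]
target = [1.0, 2.0]

# Errors are scaled by the true values: (1/1 + 2/2) / 2
correct = mape(preds, target)

# With the arguments swapped, errors are scaled by the predictions: (1/2 + 2/4) / 2
swapped = mape(target, preds)

print(correct, swapped)
```

If asymmetric metrics are used with this commit, the logged values may therefore differ from what the same metric reports when called with predictions first.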
@@ -365,7 +410,7 @@ def _get_batch_prediction(
         self, n: int, input_batch: Tuple, roll_size: int
     ) -> torch.Tensor:
         """
-        Feeds PastCovariatesTorchModel with input and output chunks of a PastCovariatesSequentialDataset to farecast
+        Feeds PastCovariatesTorchModel with input and output chunks of a PastCovariatesSequentialDataset to forecast
         the next ``n`` target values per target variable.

         Parameters:
@@ -416,7 +461,7 @@ def _get_batch_prediction(
                 batch_prediction[-1] = batch_prediction[-1][:, :roll_size, :]

             # ==========> PAST INPUT <==========
-            # roll over input series to contain latest target and covariate
+            # roll over input series to contain the latest target and covariate
             input_past = torch.roll(input_past, -roll_size, 1)

             # update target input to include next `roll_size` predictions
@@ -532,7 +577,7 @@ def _get_batch_prediction(
         self, n: int, input_batch: Tuple, roll_size: int
     ) -> torch.Tensor:
         """
-        Feeds MixedCovariatesModel with input and output chunks of a MixedCovariatesSequentialDataset to farecast
+        Feeds MixedCovariatesModel with input and output chunks of a MixedCovariatesSequentialDataset to forecast
         the next ``n`` target values per target variable.

         Parameters
@@ -598,7 +643,7 @@ def _get_batch_prediction(
                 batch_prediction[-1] = batch_prediction[-1][:, :roll_size, :]

             # ==========> PAST INPUT <==========
-            # roll over input series to contain latest target and covariate
+            # roll over input series to contain the latest target and covariate
             input_past = torch.roll(input_past, -roll_size, 1)

             # update target input to include next `roll_size` predictions

darts/models/forecasting/rnn_model.py

+3
@@ -257,6 +257,9 @@ def __init__(
             PyTorch loss function used for training.
             This parameter will be ignored for probabilistic models if the ``likelihood`` parameter is specified.
             Default: ``torch.nn.MSELoss()``.
+        torch_metrics
+            A torch metric or a ``MetricCollection`` used for evaluation. A full list of available metrics can be found
+            at https://torchmetrics.readthedocs.io/en/latest/. Default: ``None``.
         likelihood
             One of Darts' :meth:`Likelihood <darts.utils.likelihood_models.Likelihood>` models to be used for
             probabilistic forecasts. Default: ``None``.

darts/models/forecasting/tcn_model.py

+3
@@ -297,6 +297,9 @@ def __init__(
             PyTorch loss function used for training.
             This parameter will be ignored for probabilistic models if the ``likelihood`` parameter is specified.
             Default: ``torch.nn.MSELoss()``.
+        torch_metrics
+            A torch metric or a ``MetricCollection`` used for evaluation. A full list of available metrics can be found
+            at https://torchmetrics.readthedocs.io/en/latest/. Default: ``None``.
         likelihood
             One of Darts' :meth:`Likelihood <darts.utils.likelihood_models.Likelihood>` models to be used for
             probabilistic forecasts. Default: ``None``.

darts/models/forecasting/tft_model.py

+3
@@ -629,6 +629,9 @@ def __init__(
             PyTorch loss function used for training. By default, the TFT model is probabilistic and uses a
             ``likelihood`` instead (``QuantileRegression``). To make the model deterministic, you can set the `
             `likelihood`` to None and give a ``loss_fn`` argument.
+        torch_metrics
+            A torch metric or a ``MetricCollection`` used for evaluation. A full list of available metrics can be found
+            at https://torchmetrics.readthedocs.io/en/latest/. Default: ``None``.
         likelihood
             The likelihood model to be used for probabilistic forecasts. By default, the TFT uses
             a ``QuantileRegression`` likelihood.

darts/models/forecasting/transformer_model.py

+3
@@ -273,6 +273,9 @@ def __init__(
             PyTorch loss function used for training.
             This parameter will be ignored for probabilistic models if the ``likelihood`` parameter is specified.
             Default: ``torch.nn.MSELoss()``.
+        torch_metrics
+            A torch metric or a ``MetricCollection`` used for evaluation. A full list of available metrics can be found
+            at https://torchmetrics.readthedocs.io/en/latest/. Default: ``None``.
         likelihood
             One of Darts' :meth:`Likelihood <darts.utils.likelihood_models.Likelihood>` models to be used for
             probabilistic forecasts. Default: ``None``.
