Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/torchmetrics #996

Merged
merged 35 commits into from
Jun 15, 2022
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
23b531a
added dropout and batch_norm similar to nhits' implementation
May 11, 2022
156e3bb
activation param for nhits
May 11, 2022
4e7890e
fix
gdevos010 May 11, 2022
37cb53c
support any pytorch activation function. NHiTs AvgPool1d support
May 11, 2022
aee0b8f
PR
May 11, 2022
8aa1948
Update CHANGELOG.md
gdevos010 May 11, 2022
9053e1d
fixed typo
May 12, 2022
d6a2c1f
Merge branch 'fix/nbeats-nhits-TODOs' of https://github.com/gdevos010…
May 12, 2022
63329a6
pytorch lightning did not like saving nn.modules
May 12, 2022
fc7d55a
first pass
May 15, 2022
031e743
Merge remote-tracking branch 'origin/HEAD' into torchmetrics
May 18, 2022
8b62061
metrics now works with likelihood
May 19, 2022
0516e1d
rename var
gdevos010 May 28, 2022
68f46a1
made metrics a parameter. Added Tests
gdevos010 Jun 7, 2022
51bda3e
Merge branch 'unit8co:master' into feat/torchmetrics
gdevos010 Jun 7, 2022
08d74dd
Merge remote-tracking branch 'origin/feat/torchmetrics' into feat/tor…
gdevos010 Jun 7, 2022
337c55a
torchmetrics is already a dependency
gdevos010 Jun 7, 2022
56755b9
changelog
gdevos010 Jun 7, 2022
4231236
Merge branch 'master' into feat/torchmetrics
hrzn Jun 8, 2022
982c297
model now accepts torchmetrics and torchCollections
gdevos010 Jun 9, 2022
5d4c2f8
torchmetric example in early stop
gdevos010 Jun 9, 2022
2f31df1
handle no metrics
gdevos010 Jun 10, 2022
7e5bef0
made _calculate_metrics private
gdevos010 Jun 10, 2022
d1f93e3
Ray tune example
gdevos010 Jun 10, 2022
06c5510
Ray tune example
gdevos010 Jun 10, 2022
1510d36
Merge branch 'master' into feat/torchmetrics
hrzn Jun 11, 2022
fb20deb
changelog
gdevos010 Jun 11, 2022
88874f6
Merge branch 'master' into feat/torchmetrics
gdevos010 Jun 13, 2022
29fd0ab
Update darts/models/forecasting/pl_forecasting_module.py
gdevos010 Jun 13, 2022
5340091
Update darts/models/forecasting/pl_forecasting_module.py
gdevos010 Jun 13, 2022
4993a9c
added torch_metrics to doc strings in torch based models
gdevos010 Jun 13, 2022
6965fe3
added torch_metrics to doc strings in torch based models
gdevos010 Jun 13, 2022
38c395e
Update darts/models/forecasting/pl_forecasting_module.py
dennisbader Jun 13, 2022
5682fc3
Merge branch 'master' into feat/torchmetrics
dennisbader Jun 13, 2022
d6a0e24
black formatting
dennisbader Jun 13, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ Darts is still in an early development phase and we cannot always guarantee back
by [Greg DeVos](https://github.com/gdevos010)
- Implemented ["GLU Variants Improve Transformer"](https://arxiv.org/abs/2002.05202) for transformer based models (transformer and TFT). [#959](https://github.com/unit8co/darts/issues/959)
by [Greg DeVos](https://github.com/gdevos010)
- Added support for torch metrics during training and validation. [#996](https://github.com/unit8co/darts/pull/996) by [Greg DeVos](https://github.com/gdevos010)

## [0.19.0](https://github.com/unit8co/darts/tree/0.19.0) (2022-04-13)
### For users of the library:
Expand Down
70 changes: 68 additions & 2 deletions darts/models/forecasting/pl_forecasting_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
"""

from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Sequence, Tuple
from typing import Any, Dict, List, Optional, Sequence, Tuple

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torchmetrics
from joblib import Parallel, delayed

from darts.logging import get_logger, raise_if, raise_log
Expand All @@ -29,10 +30,12 @@ def __init__(
input_chunk_length: int,
output_chunk_length: int,
loss_fn: nn.modules.loss._Loss = nn.MSELoss(),
torch_metrics: Optional[List[str]] = None,
metrics_params: Optional[List[Dict]] = None,
likelihood: Optional[Likelihood] = None,
optimizer_cls: torch.optim.Optimizer = torch.optim.Adam,
optimizer_kwargs: Optional[Dict] = None,
lr_scheduler_cls: torch.optim.lr_scheduler._LRScheduler = None,
lr_scheduler_cls: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
lr_scheduler_kwargs: Optional[Dict] = None,
) -> None:
"""
Expand All @@ -58,6 +61,11 @@ def __init__(
PyTorch loss function used for training.
This parameter will be ignored for probabilistic models if the ``likelihood`` parameter is specified.
Default: ``torch.nn.MSELoss()``.
torch_metrics
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you copy paste the docstring into all our TorchForecastingModels (TFTModel, NBEATSModel, ...)? Otherwise this will not be shown in the model documentation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added

List of torch metrics to be used for evaluation. A full list of available metrics can be found at
https://torchmetrics.readthedocs.io/en/latest/. Default: ``None``.
metrics_params
List of dictionaries of parameters (one per metric, in the same order as ``torch_metrics``) to be passed to the corresponding metric. Default: ``None``.
likelihood
One of Darts' :meth:`Likelihood <darts.utils.likelihood_models.Likelihood>` models to be used for
probabilistic forecasts. Default: ``None``.
Expand Down Expand Up @@ -100,6 +108,23 @@ def __init__(
dict() if lr_scheduler_kwargs is None else lr_scheduler_kwargs
)

self.metrics = []
self.metrics_str = torch_metrics if torch_metrics else []
self.metrics_params = metrics_params if metrics_params else []

if metrics_params:
raise_if(
len(self.metrics_params) != len(self.metrics_str),
"Number of metrics parameters must be equal to number of metrics.",
logger,
)
# create empty dict for each metric
if self.metrics_str and metrics_params is None:
for _ in self.metrics_str:
self.metrics_params.append(dict())

self._setup_metrics()

# initialize prediction parameters
self.pred_n: Optional[int] = None
self.pred_num_samples: Optional[int] = None
Expand All @@ -126,6 +151,7 @@ def training_step(self, train_batch, batch_idx) -> torch.Tensor:
] # By convention target is always the last element returned by datasets
loss = self._compute_loss(output, target)
self.log("train_loss", loss, batch_size=train_batch[0].shape[0], prog_bar=True)
_ = self.calculate_metrics(output, target, tag="train")
return loss

def validation_step(self, val_batch, batch_idx) -> torch.Tensor:
Expand All @@ -134,6 +160,7 @@ def validation_step(self, val_batch, batch_idx) -> torch.Tensor:
target = val_batch[-1]
loss = self._compute_loss(output, target)
self.log("val_loss", loss, batch_size=val_batch[0].shape[0], prog_bar=True)
_ = self.calculate_metrics(output, target, tag="valid")
return loss

def predict_step(
Expand Down Expand Up @@ -230,6 +257,45 @@ def _compute_loss(self, output, target):
# last dimension of model output, for properly computing the loss.
return self.criterion(output.squeeze(dim=-1), target)

def _setup_metrics(self):
    """Resolve the configured metric names into ``torchmetrics.functional`` callables.

    Reads ``self.metrics_str`` (list of metric names) and fills ``self.metrics``
    with the matching functions from ``torchmetrics.functional``. If
    ``self.metrics_str`` is empty, ``self.metrics`` is left untouched.

    Raises
    ------
    ValueError
        (via ``raise_log``) if a name does not exist in ``torchmetrics.functional``.
    """
    if self.metrics_str:
        self.metrics = []
        task_module = torchmetrics.functional
        for metric in self.metrics_str:
            try:
                self.metrics.append(getattr(task_module, metric))
            except AttributeError:
                # raise_log both logs and raises the exception, so no
                # additional `raise` is needed afterwards (the original
                # trailing `raise e` was unreachable dead code).
                raise_log(
                    ValueError(
                        f"{metric} is not a valid functional metric defined in the torchmetrics.functional module"
                    )
                )

def calculate_metrics(self, y, y_hat, tag):
    """Compute and log all configured torch metrics for the current batch.

    Note on naming: call sites pass ``(output, target)`` positionally, so here
    ``y`` is the MODEL OUTPUT and ``y_hat`` the GROUND-TRUTH TARGET (reversed
    relative to the usual convention; names kept for call-site compatibility).

    Parameters
    ----------
    y
        Model output for the batch (predictions).
    y_hat
        Ground-truth target for the batch.
    tag
        Prefix for the logged metric names, e.g. ``"train"`` or ``"valid"``.

    Returns
    -------
    list
        The computed metric values, in the order of ``self.metrics``.
    """
    metrics = []
    for metric, metric_str, metric_params in zip(
        self.metrics, self.metrics_str, self.metrics_params
    ):
        # torchmetrics functional metrics expect (preds, target); the
        # prediction must therefore be the first argument.
        if self.likelihood:
            # Sample point predictions from the probabilistic model output.
            _metric = metric(self.likelihood.sample(y), y_hat, **metric_params)
        else:
            # If there's no likelihood, nr_params=1 and we need to squeeze out the
            # last dimension of model output, for properly computing the metric.
            _metric = metric(y.squeeze(dim=-1), y_hat, **metric_params)

        metrics.append(_metric)

        self.log(
            f"{tag}_{metric_str}",
            _metric,
            on_epoch=True,
            on_step=False,
            logger=True,
            prog_bar=True,
        )
    return metrics

def configure_optimizers(self):
"""configures optimizers and learning rate schedulers for model optimization."""

Expand Down
103 changes: 60 additions & 43 deletions darts/tests/models/forecasting/test_torch_forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,16 @@
logger.warning("Torch not available. RNN tests will be skipped.")
TORCH_AVAILABLE = False


if TORCH_AVAILABLE:

class TestTorchForecastingModel(DartsBaseTestClass):
def setUp(self):
    """Create a fresh working directory and a shared linear test series."""
    self.temp_work_dir = tempfile.mkdtemp(prefix="darts")

    # simple increasing series over a fixed date range, reused by all tests
    index = pd.date_range("20130101", "20130410")
    self.series = TimeSeries.from_series(pd.Series(range(100), index=index))

def tearDown(self):
    # Remove the temporary working directory created in setUp.
    shutil.rmtree(self.temp_work_dir)

Expand Down Expand Up @@ -61,11 +64,8 @@ def test_suppress_automatic_save(self, patch_save_model):
save_checkpoints=False,
)

times = pd.date_range("20130101", "20130410")
pd_series = pd.Series(range(100), index=times)
series = TimeSeries.from_series(pd_series)
model1.fit(series, epochs=1)
model2.fit(series, epochs=1)
model1.fit(self.series, epochs=1)
model2.fit(self.series, epochs=1)

model1.predict(n=1)
model2.predict(n=2)
Expand Down Expand Up @@ -101,12 +101,8 @@ def test_manual_save_and_load(self):
random_state=42,
)

times = pd.date_range("20130101", "20130410")
pd_series = pd.Series(range(100), index=times)
series = TimeSeries.from_series(pd_series)

model_manual_save.fit(series, epochs=1)
model_auto_save.fit(series, epochs=1)
model_manual_save.fit(self.series, epochs=1)
model_auto_save.fit(self.series, epochs=1)

model_dir = os.path.join(self.temp_work_dir)

Expand Down Expand Up @@ -215,10 +211,7 @@ def test_create_instance_existing_model_with_name_force_fit_with_reset(
)
# no exception is raised

times = pd.date_range("20130101", "20130410")
pd_series = pd.Series(range(100), index=times)
series = TimeSeries.from_series(pd_series)
model1.fit(series, epochs=1)
model1.fit(self.series, epochs=1)

RNNModel(
12,
Expand All @@ -242,10 +235,7 @@ def test_train_from_0_n_epochs_20_no_fit_epochs(self):
12, "RNN", 10, 10, n_epochs=20, work_dir=self.temp_work_dir
)

times = pd.date_range("20130101", "20130410")
pd_series = pd.Series(range(100), index=times)
series = TimeSeries.from_series(pd_series)
model1.fit(series)
model1.fit(self.series)

self.assertEqual(20, model1.epochs_trained)

Expand All @@ -255,13 +245,10 @@ def test_train_from_20_n_epochs_40_no_fit_epochs(self):
12, "RNN", 10, 10, n_epochs=20, work_dir=self.temp_work_dir
)

times = pd.date_range("20130101", "20130410")
pd_series = pd.Series(range(100), index=times)
series = TimeSeries.from_series(pd_series)
model1.fit(series)
model1.fit(self.series)
self.assertEqual(20, model1.epochs_trained)

model1.fit(series)
model1.fit(self.series)
self.assertEqual(20, model1.epochs_trained)

# n_epochs = 20, fit|epochs=None, epochs_trained=10 - train for another 20 epochs
Expand All @@ -270,14 +257,11 @@ def test_train_from_10_n_epochs_20_no_fit_epochs(self):
12, "RNN", 10, 10, n_epochs=20, work_dir=self.temp_work_dir
)

times = pd.date_range("20130101", "20130410")
pd_series = pd.Series(range(100), index=times)
series = TimeSeries.from_series(pd_series)
# simulate the case that user interrupted training with Ctrl-C after 10 epochs
model1.fit(series, epochs=10)
model1.fit(self.series, epochs=10)
self.assertEqual(10, model1.epochs_trained)

model1.fit(series)
model1.fit(self.series)
self.assertEqual(20, model1.epochs_trained)

# n_epochs = 20, fit|epochs=15, epochs_trained=10 - train for 15 epochs
Expand All @@ -286,20 +270,14 @@ def test_train_from_10_n_epochs_20_fit_15_epochs(self):
12, "RNN", 10, 10, n_epochs=20, work_dir=self.temp_work_dir
)

times = pd.date_range("20130101", "20130410")
pd_series = pd.Series(range(100), index=times)
series = TimeSeries.from_series(pd_series)
# simulate the case that user interrupted training with Ctrl-C after 10 epochs
model1.fit(series, epochs=10)
model1.fit(self.series, epochs=10)
self.assertEqual(10, model1.epochs_trained)

model1.fit(series, epochs=15)
model1.fit(self.series, epochs=15)
self.assertEqual(15, model1.epochs_trained)

def test_optimizers(self):
times = pd.date_range("20130101", "20130410")
pd_series = pd.Series(range(100), index=times)
series = TimeSeries.from_series(pd_series)

optimizers = [
(torch.optim.Adam, {"lr": 0.001}),
Expand All @@ -316,12 +294,9 @@ def test_optimizers(self):
optimizer_kwargs=optim_kwargs,
)
# should not raise an error
model.fit(series, epochs=1)
model.fit(self.series, epochs=1)

def test_lr_schedulers(self):
times = pd.date_range("20130101", "20130410")
pd_series = pd.Series(range(100), index=times)
series = TimeSeries.from_series(pd_series)

lr_schedulers = [
(torch.optim.lr_scheduler.StepLR, {"step_size": 10}),
Expand All @@ -342,7 +317,7 @@ def test_lr_schedulers(self):
lr_scheduler_kwargs=lr_scheduler_kwargs,
)
# should not raise an error
model.fit(series, epochs=1)
model.fit(self.series, epochs=1)

def test_devices(self):
torch_devices = [
Expand Down Expand Up @@ -373,3 +348,45 @@ def test_wrong_model_creation_params(self):
# invalid params should raise an error
with self.assertRaises(ValueError):
_ = RNNModel(12, "RNN", 10, 10, **invalid_kwarg)

def test_metrics(self):
    """Training with functional torch metrics should run end to end."""
    metric_names = ["mean_squared_error", "mean_absolute_percentage_error"]
    model = RNNModel(12, "RNN", 10, 10, n_epochs=1, torch_metrics=metric_names)
    model.fit(self.series)

def test_metrics_w_params(self):
    """Training with metrics plus matching per-metric parameter dicts should work."""
    model = RNNModel(
        12,
        "RNN",
        10,
        10,
        n_epochs=1,
        torch_metrics=[
            "mean_squared_error",
            "mean_absolute_percentage_error",
        ],
        # one (here empty) kwargs dict per metric
        metrics_params=[{}, {}],
    )
    model.fit(self.series)

def test_invalid_metrics(self):
    """An unknown metric name must raise a ValueError."""
    with self.assertRaises(ValueError):
        model = RNNModel(
            12, "RNN", 10, 10, n_epochs=1, torch_metrics=["invalid"]
        )
        model.fit(self.series)

def test_wrong_metrics_param_count(self):
    """Two metrics with only one params dict must raise a ValueError."""
    with self.assertRaises(ValueError):
        model = RNNModel(
            12,
            "RNN",
            10,
            10,
            n_epochs=1,
            torch_metrics=[
                "mean_squared_error",
                "mean_absolute_percentage_error",
            ],
            # deliberately fewer dicts than metrics
            metrics_params=[{}],
        )
        model.fit(self.series)