Skip to content

Commit e4a22dd

Browse files
madtoinoudennisbader
authored and committed
Fix loading of metrics and loss in load_from_checkpoint (unit8co#1759)
* fix: loss_fn and torch_metrics are properly restored when calling load_from_checkpoint() * fix: moved fix to the PL on_save/on_load methods instead of load_from_checkpoint() * fix: address reviewer comments, loss and metrics objects are saved in the constructor * update changelog --------- Co-authored-by: Dennis Bader <[email protected]>
1 parent df67dec commit e4a22dd

File tree

4 files changed

+111
-18
lines changed

4 files changed

+111
-18
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
1111

1212
**Improved**
1313
- Added support for `PathLike` to the `save()` and `load()` functions of `ForecastingModel`. [#1754](https://github.com/unit8co/darts/pull/1754) by [Simon Sudrich](https://github.com/sudrich).
14+
- Fixed an issue with `TorchForecastingModel.load_from_checkpoint()` not properly loading the loss function and metrics. [#1749](https://github.com/unit8co/darts/pull/1749) by [Antoine Madrona](https://github.com/madtoinou).
1415

1516
**Fixed**
1617
- Fixed an issue not considering original component names for `TimeSeries.plot()` when providing a label prefix. [#1783](https://github.com/unit8co/darts/pull/1783) by [Simon Sudrich](https://github.com/sudrich).

darts/models/forecasting/pl_forecasting_module.py

+23-15
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,7 @@ def __init__(
8888
super().__init__()
8989

9090
# save hyper parameters for saving/loading
91-
# do not save type nn.Module params
92-
self.save_hyperparameters(ignore=["loss_fn", "torch_metrics"])
91+
self.save_hyperparameters()
9392

9493
raise_if(
9594
input_chunk_length is None or output_chunk_length is None,
@@ -116,19 +115,8 @@ def __init__(
116115
dict() if lr_scheduler_kwargs is None else lr_scheduler_kwargs
117116
)
118117

119-
if torch_metrics is None:
120-
torch_metrics = torchmetrics.MetricCollection([])
121-
elif isinstance(torch_metrics, torchmetrics.Metric):
122-
torch_metrics = torchmetrics.MetricCollection([torch_metrics])
123-
elif isinstance(torch_metrics, torchmetrics.MetricCollection):
124-
pass
125-
else:
126-
raise_log(
127-
AttributeError(
128-
"`torch_metrics` only accepts type torchmetrics.Metric or torchmetrics.MetricCollection"
129-
),
130-
logger,
131-
)
118+
# convert torch_metrics to torchmetrics.MetricCollection
119+
torch_metrics = self.configure_torch_metrics(torch_metrics)
132120
self.train_metrics = torch_metrics.clone(prefix="train_")
133121
self.val_metrics = torch_metrics.clone(prefix="val_")
134122

@@ -425,6 +413,26 @@ def epochs_trained(self):
425413

426414
return current_epoch
427415

416+
@staticmethod
417+
def configure_torch_metrics(
418+
torch_metrics: Union[torchmetrics.Metric, torchmetrics.MetricCollection]
419+
) -> torchmetrics.MetricCollection:
420+
"""process the torch_metrics parameter."""
421+
if torch_metrics is None:
422+
torch_metrics = torchmetrics.MetricCollection([])
423+
elif isinstance(torch_metrics, torchmetrics.Metric):
424+
torch_metrics = torchmetrics.MetricCollection([torch_metrics])
425+
elif isinstance(torch_metrics, torchmetrics.MetricCollection):
426+
pass
427+
else:
428+
raise_log(
429+
AttributeError(
430+
"`torch_metrics` only accepts type torchmetrics.Metric or torchmetrics.MetricCollection"
431+
),
432+
logger,
433+
)
434+
return torch_metrics
435+
428436

429437
class PLPastCovariatesModule(PLForecastingModule, ABC):
430438
def _produce_train_output(self, input_batch: Tuple):

darts/models/forecasting/torch_forecasting_model.py

+1
Original file line numberDiff line numberDiff line change
@@ -1682,6 +1682,7 @@ def load_from_checkpoint(
16821682
logger.info(f"loading {file_name}")
16831683

16841684
model.model = model._load_from_checkpoint(file_path, **kwargs)
1685+
16851686
# restore _fit_called attribute, set to False in load() if no .ckpt is found/provided
16861687
model._fit_called = True
16871688
model.load_ckpt_path = file_path

darts/tests/models/forecasting/test_torch_forecasting_model.py

+86-3
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
try:
1818
import torch
19+
from pytorch_lightning.loggers.logger import DummyLogger
1920
from pytorch_lightning.tuner.lr_finder import _LRFinder
2021
from torchmetrics import (
2122
MeanAbsoluteError,
@@ -471,6 +472,63 @@ def test_load_weights(self):
471472
f"respectively {retrained_mape} and {original_mape}",
472473
)
473474

475+
def test_load_from_checkpoint_w_custom_loss(self):
476+
model_name = "pretraining_custom_loss"
477+
# model with a custom loss
478+
model = RNNModel(
479+
12,
480+
"RNN",
481+
5,
482+
1,
483+
n_epochs=1,
484+
work_dir=self.temp_work_dir,
485+
model_name=model_name,
486+
save_checkpoints=True,
487+
force_reset=True,
488+
loss_fn=torch.nn.L1Loss(),
489+
)
490+
model.fit(self.series)
491+
492+
loaded_model = RNNModel.load_from_checkpoint(
493+
model_name, self.temp_work_dir, best=False
494+
)
495+
# custom loss function should be properly restored from ckpt
496+
self.assertTrue(isinstance(loaded_model.model.criterion, torch.nn.L1Loss))
497+
498+
loaded_model.fit(self.series, epochs=2)
499+
# calling fit() should not impact the loss function
500+
self.assertTrue(isinstance(loaded_model.model.criterion, torch.nn.L1Loss))
501+
502+
def test_load_from_checkpoint_w_metrics(self):
503+
model_name = "pretraining_metrics"
504+
# model with one torch_metrics
505+
model = RNNModel(
506+
12,
507+
"RNN",
508+
5,
509+
1,
510+
n_epochs=1,
511+
work_dir=self.temp_work_dir,
512+
model_name=model_name,
513+
save_checkpoints=True,
514+
force_reset=True,
515+
torch_metrics=MeanAbsolutePercentageError(),
516+
pl_trainer_kwargs={"logger": DummyLogger(), "log_every_n_steps": 1},
517+
)
518+
model.fit(self.series)
519+
# check train_metrics before loading
520+
self.assertTrue(isinstance(model.model.train_metrics, MetricCollection))
521+
self.assertEqual(len(model.model.train_metrics), 1)
522+
523+
loaded_model = RNNModel.load_from_checkpoint(
524+
model_name, self.temp_work_dir, best=False
525+
)
526+
# torch_metrics should be properly restored from ckpt as a torchmetrics.MetricCollection
527+
self.assertTrue(
528+
isinstance(loaded_model.model.train_metrics, MetricCollection)
529+
)
530+
self.assertEqual(len(loaded_model.model.train_metrics), 1)
531+
474532
def test_optimizers(self):
475533

476534
optimizers = [
@@ -531,17 +589,39 @@ def test_metrics(self):
531589
)
532590

533591
# test single metric
534-
model = RNNModel(12, "RNN", 10, 10, n_epochs=1, torch_metrics=metric)
592+
model = RNNModel(
593+
12,
594+
"RNN",
595+
10,
596+
10,
597+
n_epochs=1,
598+
torch_metrics=metric,
599+
pl_trainer_kwargs={"logger": DummyLogger(), "log_every_n_steps": 1},
600+
)
535601
model.fit(self.series)
536602

537603
# test metric collection
538604
model = RNNModel(
539-
12, "RNN", 10, 10, n_epochs=1, torch_metrics=metric_collection
605+
12,
606+
"RNN",
607+
10,
608+
10,
609+
n_epochs=1,
610+
torch_metrics=metric_collection,
611+
pl_trainer_kwargs={"logger": DummyLogger(), "log_every_n_steps": 1},
540612
)
541613
model.fit(self.series)
542614

543615
# test multivariate series
544-
model = RNNModel(12, "RNN", 10, 10, n_epochs=1, torch_metrics=metric)
616+
model = RNNModel(
617+
12,
618+
"RNN",
619+
10,
620+
10,
621+
n_epochs=1,
622+
torch_metrics=metric,
623+
pl_trainer_kwargs={"logger": DummyLogger(), "log_every_n_steps": 1},
624+
)
545625
model.fit(self.multivariate_series)
546626

547627
def test_metrics_w_likelihood(self):
@@ -559,6 +639,7 @@ def test_metrics_w_likelihood(self):
559639
n_epochs=1,
560640
likelihood=GaussianLikelihood(),
561641
torch_metrics=metric,
642+
pl_trainer_kwargs={"logger": DummyLogger(), "log_every_n_steps": 1},
562643
)
563644
model.fit(self.series)
564645

@@ -571,6 +652,7 @@ def test_metrics_w_likelihood(self):
571652
n_epochs=1,
572653
likelihood=GaussianLikelihood(),
573654
torch_metrics=metric_collection,
655+
pl_trainer_kwargs={"logger": DummyLogger(), "log_every_n_steps": 1},
574656
)
575657
model.fit(self.series)
576658

@@ -583,6 +665,7 @@ def test_metrics_w_likelihood(self):
583665
n_epochs=1,
584666
likelihood=GaussianLikelihood(),
585667
torch_metrics=metric_collection,
668+
pl_trainer_kwargs={"logger": DummyLogger(), "log_every_n_steps": 1},
586669
)
587670
model.fit(self.multivariate_series)
588671

0 commit comments

Comments
 (0)