From 18cadc152d6ae3d3516337e4331c6a5dbb8c43ed Mon Sep 17 00:00:00 2001
From: dennisbader
Date: Sun, 24 Apr 2022 10:59:37 -0600
Subject: [PATCH] replaced new LinearLR scheduler (only available with torch >= 1.10.0) with StepLR

---
 darts/models/forecasting/block_rnn_model.py                    | 3 ---
 darts/models/forecasting/nbeats.py                             | 3 ---
 darts/models/forecasting/nhits.py                              | 3 ---
 darts/models/forecasting/pl_forecasting_module.py              | 3 +++
 darts/models/forecasting/rnn_model.py                          | 3 ---
 darts/models/forecasting/tcn_model.py                          | 3 ---
 darts/models/forecasting/tft_model.py                          | 3 ---
 darts/models/forecasting/transformer_model.py                  | 3 ---
 darts/tests/models/forecasting/test_torch_forecasting_model.py | 2 +-
 9 files changed, 4 insertions(+), 22 deletions(-)

diff --git a/darts/models/forecasting/block_rnn_model.py b/darts/models/forecasting/block_rnn_model.py
index cd45edd0fe..eccfc656e8 100644
--- a/darts/models/forecasting/block_rnn_model.py
+++ b/darts/models/forecasting/block_rnn_model.py
@@ -78,9 +78,6 @@ def __init__(
 
         super().__init__(**kwargs)
 
-        # required for all modules -> saves hparams for checkpoints
-        self.save_hyperparameters()
-
         # Defining parameters
         self.hidden_dim = hidden_dim
         self.n_layers = num_layers
diff --git a/darts/models/forecasting/nbeats.py b/darts/models/forecasting/nbeats.py
index 6b4d611bd2..5656c179f5 100644
--- a/darts/models/forecasting/nbeats.py
+++ b/darts/models/forecasting/nbeats.py
@@ -358,9 +358,6 @@ def __init__(
         """
         super().__init__(**kwargs)
 
-        # required for all modules -> saves hparams for checkpoints
-        self.save_hyperparameters()
-
         self.input_dim = input_dim
         self.output_dim = output_dim
         self.nr_params = nr_params
diff --git a/darts/models/forecasting/nhits.py b/darts/models/forecasting/nhits.py
index 5f7cad1773..a5fb894e73 100644
--- a/darts/models/forecasting/nhits.py
+++ b/darts/models/forecasting/nhits.py
@@ -347,9 +347,6 @@ def __init__(
         """
         super().__init__(**kwargs)
 
-        # required for all modules -> saves hparams for checkpoints
-        self.save_hyperparameters()
-
         self.input_dim = input_dim
         self.output_dim = output_dim
         self.nr_params = nr_params
diff --git a/darts/models/forecasting/pl_forecasting_module.py b/darts/models/forecasting/pl_forecasting_module.py
index a5fd3f7330..0f85847108 100644
--- a/darts/models/forecasting/pl_forecasting_module.py
+++ b/darts/models/forecasting/pl_forecasting_module.py
@@ -75,6 +75,9 @@ def __init__(
         """
         super().__init__()
 
+        # save hyper parameters for saving/loading
+        self.save_hyperparameters()
+
         raise_if(
             input_chunk_length is None or output_chunk_length is None,
             "Both `input_chunk_length` and `output_chunk_length` must be passed to `PLForecastingModule`",
diff --git a/darts/models/forecasting/rnn_model.py b/darts/models/forecasting/rnn_model.py
index 1135eef9e2..424671250b 100644
--- a/darts/models/forecasting/rnn_model.py
+++ b/darts/models/forecasting/rnn_model.py
@@ -73,9 +73,6 @@ def __init__(
         # RNNModule doesn't really need input and output_chunk_length for PLModule
         super().__init__(**kwargs)
 
-        # required for all modules -> saves hparams for checkpoints
-        self.save_hyperparameters()
-
         # Defining parameters
         self.target_size = target_size
         self.nr_params = nr_params
diff --git a/darts/models/forecasting/tcn_model.py b/darts/models/forecasting/tcn_model.py
index bd8006d1c0..da96cafe1f 100644
--- a/darts/models/forecasting/tcn_model.py
+++ b/darts/models/forecasting/tcn_model.py
@@ -183,9 +183,6 @@ def __init__(
 
         super().__init__(**kwargs)
 
-        # required for all modules -> saves hparams for checkpoints
-        self.save_hyperparameters()
-
         # Defining parameters
         self.input_size = input_size
         self.n_filters = num_filters
diff --git a/darts/models/forecasting/tft_model.py b/darts/models/forecasting/tft_model.py
index 0282e3dc97..0190de948e 100644
--- a/darts/models/forecasting/tft_model.py
+++ b/darts/models/forecasting/tft_model.py
@@ -88,9 +88,6 @@ def __init__(
 
         super().__init__(**kwargs)
 
-        # required for all modules -> saves hparams for checkpoints
-        self.save_hyperparameters()
-
         self.n_targets, self.loss_size = output_dim
         self.variables_meta = variables_meta
         self.hidden_size = hidden_size
diff --git a/darts/models/forecasting/transformer_model.py b/darts/models/forecasting/transformer_model.py
index 023ad6991b..c94a4db244 100644
--- a/darts/models/forecasting/transformer_model.py
+++ b/darts/models/forecasting/transformer_model.py
@@ -124,9 +124,6 @@ def __init__(
 
         super().__init__(**kwargs)
 
-        # required for all modules -> saves hparams for checkpoints
-        self.save_hyperparameters()
-
         self.input_size = input_size
         self.target_size = output_size
         self.nr_params = nr_params
diff --git a/darts/tests/models/forecasting/test_torch_forecasting_model.py b/darts/tests/models/forecasting/test_torch_forecasting_model.py
index 93a099f87e..fb420a4c47 100644
--- a/darts/tests/models/forecasting/test_torch_forecasting_model.py
+++ b/darts/tests/models/forecasting/test_torch_forecasting_model.py
@@ -324,7 +324,7 @@ def test_lr_schedulers(self):
         series = TimeSeries.from_series(pd_series)
 
         lr_schedulers = [
-            (torch.optim.lr_scheduler.LinearLR, {}),
+            (torch.optim.lr_scheduler.StepLR, {"step_size": 10}),
             (
                 torch.optim.lr_scheduler.ReduceLROnPlateau,
                 {"threshold": 0.001, "monitor": "train_loss"},
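A minimal sketch of what the updated test exercises from the user side: configuring a learning-rate scheduler on a Darts torch model. The constructor argument names lr_scheduler_cls and lr_scheduler_kwargs, the choice of NBEATSModel, and the toy series are assumptions for illustration; the patch itself only confirms that schedulers are supplied as a (class, kwargs) pair and that StepLR needs an explicit step_size.

import torch
from darts.models import NBEATSModel
from darts.utils.timeseries_generation import linear_timeseries

# toy series, just to make the sketch runnable
series = linear_timeseries(length=100)

# StepLR is available on all torch versions Darts supports,
# unlike LinearLR (torch >= 1.10.0 only)
model = NBEATSModel(
    input_chunk_length=12,
    output_chunk_length=6,
    n_epochs=2,
    lr_scheduler_cls=torch.optim.lr_scheduler.StepLR,  # assumed kwarg name
    lr_scheduler_kwargs={"step_size": 10},  # decay the LR every 10 scheduler steps
)
model.fit(series)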