
Commit fc09e50

Replaced the new LinearLR scheduler (only available with torch >= 1.10.0) with StepLR (#928)
1 parent: 0d1a208

9 files changed: +4, -22 lines
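For context, here is a minimal sketch (not part of the commit) of how the replacement scheduler can be passed to a Darts torch-based model through the `lr_scheduler_cls`/`lr_scheduler_kwargs` constructor arguments exercised by the test change below; the model choice and hyperparameter values are illustrative only.

import torch
from darts.models import NBEATSModel

# Hypothetical example values; only the scheduler wiring mirrors the tested API.
model = NBEATSModel(
    input_chunk_length=24,
    output_chunk_length=12,
    n_epochs=5,
    lr_scheduler_cls=torch.optim.lr_scheduler.StepLR,
    # StepLR is available in torch < 1.10.0 as well, unlike LinearLR
    # (added in torch 1.10.0), and only requires a step_size.
    lr_scheduler_kwargs={"step_size": 10},
)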

darts/models/forecasting/block_rnn_model.py (-3 lines)

@@ -78,9 +78,6 @@ def __init__(
 
         super().__init__(**kwargs)
 
-        # required for all modules -> saves hparams for checkpoints
-        self.save_hyperparameters()
-
         # Defining parameters
         self.hidden_dim = hidden_dim
         self.n_layers = num_layers

darts/models/forecasting/nbeats.py (-3 lines)

@@ -358,9 +358,6 @@ def __init__(
         """
         super().__init__(**kwargs)
 
-        # required for all modules -> saves hparams for checkpoints
-        self.save_hyperparameters()
-
         self.input_dim = input_dim
         self.output_dim = output_dim
         self.nr_params = nr_params

darts/models/forecasting/nhits.py (-3 lines)

@@ -347,9 +347,6 @@ def __init__(
         """
         super().__init__(**kwargs)
 
-        # required for all modules -> saves hparams for checkpoints
-        self.save_hyperparameters()
-
         self.input_dim = input_dim
         self.output_dim = output_dim
         self.nr_params = nr_params

darts/models/forecasting/pl_forecasting_module.py (+3 lines)

@@ -75,6 +75,9 @@ def __init__(
         """
         super().__init__()
 
+        # save hyper parameters for saving/loading
+        self.save_hyperparameters()
+
         raise_if(
             input_chunk_length is None or output_chunk_length is None,
             "Both `input_chunk_length` and `output_chunk_length` must be passed to `PLForecastingModule`",

darts/models/forecasting/rnn_model.py (-3 lines)

@@ -73,9 +73,6 @@ def __init__(
         # RNNModule doesn't really need input and output_chunk_length for PLModule
         super().__init__(**kwargs)
 
-        # required for all modules -> saves hparams for checkpoints
-        self.save_hyperparameters()
-
         # Defining parameters
         self.target_size = target_size
         self.nr_params = nr_params

darts/models/forecasting/tcn_model.py (-3 lines)

@@ -183,9 +183,6 @@ def __init__(
 
         super().__init__(**kwargs)
 
-        # required for all modules -> saves hparams for checkpoints
-        self.save_hyperparameters()
-
         # Defining parameters
         self.input_size = input_size
         self.n_filters = num_filters

darts/models/forecasting/tft_model.py (-3 lines)

@@ -88,9 +88,6 @@ def __init__(
 
         super().__init__(**kwargs)
 
-        # required for all modules -> saves hparams for checkpoints
-        self.save_hyperparameters()
-
         self.n_targets, self.loss_size = output_dim
         self.variables_meta = variables_meta
         self.hidden_size = hidden_size

darts/models/forecasting/transformer_model.py (-3 lines)

@@ -124,9 +124,6 @@ def __init__(
 
         super().__init__(**kwargs)
 
-        # required for all modules -> saves hparams for checkpoints
-        self.save_hyperparameters()
-
         self.input_size = input_size
         self.target_size = output_size
         self.nr_params = nr_params

darts/tests/models/forecasting/test_torch_forecasting_model.py (+1, -1 lines)

@@ -324,7 +324,7 @@ def test_lr_schedulers(self):
         series = TimeSeries.from_series(pd_series)
 
         lr_schedulers = [
-            (torch.optim.lr_scheduler.LinearLR, {}),
+            (torch.optim.lr_scheduler.StepLR, {"step_size": 10}),
            (
                torch.optim.lr_scheduler.ReduceLROnPlateau,
                {"threshold": 0.001, "monitor": "train_loss"},
