
Commit ab2df1a

fix deep copying of TFM trainer parameters (#1459)

* fix deep copying of TFM trainer parameters
* fix failing tests

1 parent fab7ddf commit ab2df1a

File tree

darts/models/forecasting/torch_forecasting_model.py
darts/tests/models/forecasting/test_ptl_trainer.py

2 files changed: +10 -3 lines changed

darts/models/forecasting/torch_forecasting_model.py (+7 -2)

```diff
@@ -475,11 +475,16 @@ def _init_trainer(
         trainer_params: dict, max_epochs: Optional[int] = None
     ) -> pl.Trainer:
         """Initializes a PyTorch-Lightning trainer for training or prediction from `trainer_params`."""
-        trainer_params_copy = {param: val for param, val in trainer_params.items()}
+        trainer_params_copy = {key: val for key, val in trainer_params.items()}
         if max_epochs is not None:
             trainer_params_copy["max_epochs"] = max_epochs
 
-        return pl.Trainer(**trainer_params_copy)
+        # prevent lightning from adding callbacks to the callbacks list in `self.trainer_params`
+        callbacks = trainer_params_copy.pop("callbacks", None)
+        return pl.Trainer(
+            callbacks=[cb for cb in callbacks] if callbacks is not None else callbacks,
+            **trainer_params_copy,
+        )
 
     @abstractmethod
     def _create_model(self, train_sample: Tuple[Tensor]) -> torch.nn.Module:
```
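The bug this hunk fixes is plain shallow-copy aliasing: the dict comprehension builds a new dict, but its `"callbacks"` value is still the very list object stored in `self.trainer_params`, so anything the trainer appends lands in the model's saved parameters. Below is a minimal self-contained sketch of the failure mode and of the copy pattern used above; `add_own_callbacks` is a hypothetical stand-in for `pl.Trainer`, which per the commit message appends its own callbacks to the list it receives:

```python
def add_own_callbacks(params: dict) -> None:
    # stand-in for pl.Trainer, which appends its own callbacks
    # (e.g. a progress bar) to the list it is given
    params.setdefault("callbacks", []).append("progress_bar")

# before the fix: shallow dict copy, the "callbacks" list is still shared
trainer_params = {"callbacks": ["counter_a", "counter_b"]}
params_copy = {key: val for key, val in trainer_params.items()}
add_own_callbacks(params_copy)
print(len(trainer_params["callbacks"]))  # 3 -- the stored list was mutated

# after the fix: pop the list and hand over a fresh copy of it
trainer_params = {"callbacks": ["counter_a", "counter_b"]}
params_copy = {key: val for key, val in trainer_params.items()}
callbacks = params_copy.pop("callbacks", None)
add_own_callbacks(
    {
        "callbacks": [cb for cb in callbacks] if callbacks is not None else [],
        **params_copy,
    }
)
print(len(trainer_params["callbacks"]))  # 2 -- the stored list is untouched
```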

darts/tests/models/forecasting/test_ptl_trainer.py (+3 -1)

```diff
@@ -151,7 +151,9 @@ def on_train_epoch_end(self, *args, **kwargs):
 
         # check if callbacks were added
         self.assertEqual(len(model.trainer_params["callbacks"]), 2)
-        model.fit(self.series, epochs=2)
+        model.fit(self.series, epochs=2, verbose=True)
+        # check that lightning did not mutate callbacks (verbosity adds a progress bar callback)
+        self.assertEqual(len(model.trainer_params["callbacks"]), 2)
 
         self.assertEqual(my_counter_0.counter, model.epochs_trained)
         self.assertEqual(my_counter_2.counter, model.epochs_trained + 2)
```
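For context on the test change: as the new comment notes, enabling verbosity makes Lightning add a progress bar callback, which is exactly the kind of addition that previously leaked into `model.trainer_params["callbacks"]`. A minimal sketch of the guarantee the new assertion checks, outside of Darts (assumes `pytorch_lightning` is installed; `enable_progress_bar=True` stands in for what `verbose=True` toggles on the model):

```python
import pytorch_lightning as pl

my_callbacks = []  # the list a user would keep around, as in trainer_params
trainer = pl.Trainer(
    # pass a copy, as in the fix above, so Lightning's own additions
    # (progress bar, model summary, ...) never reach `my_callbacks`
    callbacks=[cb for cb in my_callbacks],
    enable_progress_bar=True,
    max_epochs=1,
)
assert len(my_callbacks) == 0      # the caller's list is unchanged
assert len(trainer.callbacks) > 0  # the trainer holds its own augmented list
```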
