Skip to content

Commit efa955a

Browse files
authored
Fix/ptl1.6.0 (#888)
* fix epochs trained count
* save PTL module and trainer using PTL checkpointing
* dynamically compute right number of epochs trained
* test checkpoint file existence
* restore model saving in tests
1 parent fb5a59e commit efa955a

File tree

3 files changed: +34 −3 lines changed

darts/models/forecasting/pl_forecasting_module.py

+8-2
@@ -17,6 +17,10 @@

 logger = get_logger(__name__)

+# Check whether we are running pytorch-lightning >= 1.6.0 or not:
+tokens = pl.__version__.split(".")
+pl_160_or_above = int(tokens[0]) >= 1 and int(tokens[1]) >= 6
+

 class PLForecastingModule(pl.LightningModule, ABC):
     @abstractmethod
@@ -324,10 +328,12 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:

     @property
     def epochs_trained(self):
-        # trained epochs are only 0 when global step and current epoch are 0, else current epoch + 1
         current_epoch = self.current_epoch
-        if self.current_epoch or self.global_step:
+
+        # For PTL < 1.6.0 we have to adjust:
+        if not pl_160_or_above and (self.current_epoch or self.global_step):
             current_epoch += 1
+
         return current_epoch


darts/models/forecasting/torch_forecasting_model.py

+17
@@ -1299,16 +1299,25 @@ def save_model(self, path: str) -> None:
         path
             Path under which to save the model at its current state.
         """
+        # TODO: the parameters are saved twice currently, once with complete
+        # object, and once with PTL checkpointing.

         raise_if_not(
             path.endswith(".pth.tar"),
             "The given path should end with '.pth.tar'.",
             logger,
         )

+        # We save the whole object to keep track of everything
         with open(path, "wb") as f_out:
             torch.save(self, f_out)

+        # In addition, we need to use PTL save_checkpoint() to properly save the trainer and model
+        if self.trainer is not None:
+            base_path = path[:-8]
+            path_ptl_ckpt = base_path + "_ptl-ckpt.pth.tar"
+            self.trainer.save_checkpoint(path_ptl_ckpt)
+
     @staticmethod
     def load_model(path: str) -> "TorchForecastingModel":
         """loads a model from a given file path. The file name should end with '.pth.tar'
@@ -1337,6 +1346,14 @@ def load_model(path: str) -> "TorchForecastingModel":

         with open(path, "rb") as fin:
             model = torch.load(fin)
+
+        # If a PTL checkpoint was saved, we also need to load it:
+        base_path = path[:-8]
+        path_ptl_ckpt = base_path + "_ptl-ckpt.pth.tar"
+        if os.path.exists(path_ptl_ckpt):
+            model.model = model.model.__class__.load_from_checkpoint(path_ptl_ckpt)
+            model.trainer = model.model.trainer
+
         return model

     @staticmethod

darts/tests/models/forecasting/test_torch_forecasting_model.py

+9-1
@@ -123,14 +123,22 @@ def test_manual_save_and_load(self):
             checkpoint_path_manual = os.path.join(model_dir, manual_name)
             os.mkdir(checkpoint_path_manual)

-            # save manually saved model
             checkpoint_file_name = "checkpoint_0.pth.tar"
             model_path_manual = os.path.join(
                 checkpoint_path_manual, checkpoint_file_name
             )
+            checkpoint_file_name_cpkt = "checkpoint_0_ptl-ckpt.pth.tar"
+            model_path_manual_ckpt = os.path.join(
+                checkpoint_path_manual, checkpoint_file_name_cpkt
+            )
+
+            # save manually saved model
             model_manual_save.save_model(model_path_manual)
             self.assertTrue(os.path.exists(model_path_manual))

+            # check that the PTL checkpoint path is also there
+            self.assertTrue(os.path.exists(model_path_manual_ckpt))
+
             # load manual save model and compare with automatic model results
             model_manual_save = RNNModel.load_model(model_path_manual)
             self.assertEqual(

0 commit comments

Comments (0)