From aa35583afb3eae9d309aec4cf01b22802e2342f7 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 29 Mar 2021 20:15:19 +0100 Subject: [PATCH] add docstring --- pytorch_lightning/accelerators/accelerator.py | 6 ++++++ pytorch_lightning/plugins/training_type/tpu_spawn.py | 3 +-- .../plugins/training_type/training_type_plugin.py | 6 ++++++ 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 31736c13c6351..7d16d91e3bf82 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -481,6 +481,12 @@ def connect_precision_plugin(self, plugin: PrecisionPlugin) -> None: self.setup_precision_plugin(plugin) def save_checkpoint(self, checkpoint: Dict[str, Any], filepath) -> None: + """Save model/training states as a checkpoint file through state-dump and file-write. + + Args: + checkpoint: dict containing model and trainer state + filepath: write-target file's path + """ self.training_type_plugin.save_checkpoint(checkpoint, filepath) @property diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index aee2b8914b579..ba074e7cfb206 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -300,9 +300,8 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None: """Save model/training states as a checkpoint file through state-dump and file-write. Args: - trainer: PyTorch Lightning Trainer + checkpoint: dict containing model and trainer state filepath: write-target file's path - weights_only: saving model weights only """ # Todo: TypeError: 'mappingproxy' object does not support item assignment self.save({k: v for k, v in checkpoint.items() if k != "callbacks"}, filepath) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 0250933c9da3c..1eac88212e0fb 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -198,6 +198,12 @@ def setup_optimizers_in_pre_dispatch(self) -> bool: return False def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None: + """Save model/training states as a checkpoint file through state-dump and file-write. + + Args: + checkpoint: dict containing model and trainer state + filepath: write-target file's path + """ # dump states as a checkpoint dictionary object if self.is_global_zero: checkpoint = self.on_save(checkpoint)