Skip to content

Commit

Permalink
add docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
tchaton committed Mar 29, 2021
1 parent 2dcafd0 commit aa35583
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 2 deletions.
6 changes: 6 additions & 0 deletions pytorch_lightning/accelerators/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,12 @@ def connect_precision_plugin(self, plugin: PrecisionPlugin) -> None:
self.setup_precision_plugin(plugin)

def save_checkpoint(self, checkpoint: Dict[str, Any], filepath) -> None:
"""Save model/training states as a checkpoint file through state-dump and file-write.
Args:
checkpoint: dict containing model and trainer state
filepath: write-target file's path
"""
self.training_type_plugin.save_checkpoint(checkpoint, filepath)

@property
Expand Down
3 changes: 1 addition & 2 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,9 +300,8 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
"""Save model/training states as a checkpoint file through state-dump and file-write.
Args:
trainer: PyTorch Lightning Trainer
checkpoint: dict containing model and trainer state
filepath: write-target file's path
weights_only: saving model weights only
"""
# Todo: TypeError: 'mappingproxy' object does not support item assignment
self.save({k: v for k, v in checkpoint.items() if k != "callbacks"}, filepath)
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,12 @@ def setup_optimizers_in_pre_dispatch(self) -> bool:
return False

def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
"""Save model/training states as a checkpoint file through state-dump and file-write.
Args:
checkpoint: dict containing model and trainer state
filepath: write-target file's path
"""
# dump states as a checkpoint dictionary object
if self.is_global_zero:
checkpoint = self.on_save(checkpoint)
Expand Down

0 comments on commit aa35583

Please sign in to comment.