diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index 44422cc49cddc..0f55100bf1ab9 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -203,12 +203,14 @@ def test_step(self, *args, **kwargs):
     def predict(self, *args, **kwargs):
         return self.lightning_module.predict(*args, **kwargs)
 
-    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
+    def save_checkpoint(self, filepath: str, weights_only: bool = False) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.
         Args:
-            checkpoint: dict containing model and trainer state
             filepath: write-target file's path
+            weights_only: saving model weights only
         """
+        # dump states as a checkpoint dictionary object
+        checkpoint = self.lightning_module.trainer.checkpoint_connector.dump_checkpoint(weights_only)
         # Todo: TypeError: 'mappingproxy' object does not support item assignment
         if _OMEGACONF_AVAILABLE:
             checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container)
diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py
index b64b22c66caa7..24c0b615b95bb 100644
--- a/tests/models/test_tpu.py
+++ b/tests/models/test_tpu.py
@@ -122,7 +122,7 @@ def test_model_16bit_tpu_cores_1(tmpdir):
         progress_bar_refresh_rate=0,
         max_epochs=2,
         tpu_cores=1,
-        limit_train_batches=8,
+        limit_train_batches=0.7,
         limit_val_batches=2,
     )
 
@@ -210,8 +210,8 @@ def test_tpu_grad_norm(tmpdir):
         progress_bar_refresh_rate=0,
         max_epochs=4,
         tpu_cores=1,
-        limit_train_batches=0.4,
-        limit_val_batches=0.4,
+        limit_train_batches=10,
+        limit_val_batches=10,
         gradient_clip_val=0.5,
     )
 
diff --git a/tests/test_profiler.py b/tests/test_profiler.py
index 667e153a9edd4..6abcf17a04893 100644
--- a/tests/test_profiler.py
+++ b/tests/test_profiler.py
@@ -20,6 +20,7 @@
 import pytest
 
 from pytorch_lightning.profiler import AdvancedProfiler, SimpleProfiler
+from tests.helpers.runif import RunIf
 
 PROFILER_OVERHEAD_MAX_TOLERANCE = 0.0005
 
@@ -165,6 +166,7 @@ def test_advanced_profiler_overhead(advanced_profiler, n_iter=5):
     assert average_duration < PROFILER_OVERHEAD_MAX_TOLERANCE
 
 
+@RunIf(max_torch="1.8.1")
 def test_advanced_profiler_describe(tmpdir, advanced_profiler):
     """
     ensure the profiler won't fail when reporting the summary
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index fd2b48a3fa140..306d38d2d651b 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -42,6 +42,7 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import EvalModelTemplate
 from tests.helpers import BoringModel, RandomDataset
+from tests.helpers.runif import RunIf
 
 
 @pytest.fixture
@@ -1499,6 +1500,7 @@ def test_trainer_predict_ddp_cpu(tmpdir):
     predict(tmpdir, "ddp_cpu", 0, 2)
 
 
+@RunIf(max_torch="1.8.1")
 def test_pytorch_profiler_describe(pytorch_profiler):
     """Ensure the profiler won't fail when reporting the summary."""
     with pytorch_profiler.profile("test_step"):
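
A minimal call-site sketch of the `save_checkpoint` signature change above; the `trainer` instance and the checkpoint path are illustrative assumptions, not part of the diff:

    # Old signature: the caller dumped the state dict and passed it in.
    checkpoint = trainer.checkpoint_connector.dump_checkpoint(weights_only=False)
    trainer.accelerator.training_type_plugin.save_checkpoint(checkpoint, "model.ckpt")

    # New signature: the plugin dumps the checkpoint itself via
    # `self.lightning_module.trainer.checkpoint_connector.dump_checkpoint(weights_only)`,
    # so the caller only supplies the target path and the weights_only flag.
    trainer.accelerator.training_type_plugin.save_checkpoint("model.ckpt", weights_only=False)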