Fix TPU tests for checkpoint
Skip advanced profiler for torch > 1.8

Skip pytorch profiler for torch > 1.8

Fix save checkpoint logic for TPUs
kaushikb11 authored and lexierule committed Apr 7, 2021
1 parent 123e20d commit f5f4f03
Showing 4 changed files with 11 additions and 5 deletions.
6 changes: 4 additions & 2 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -203,12 +203,14 @@ def test_step(self, *args, **kwargs):
     def predict(self, *args, **kwargs):
         return self.lightning_module.predict(*args, **kwargs)
 
-    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
+    def save_checkpoint(self, filepath: str, weights_only: bool = False) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.
         Args:
-            checkpoint: dict containing model and trainer state
             filepath: write-target file's path
+            weights_only: saving model weights only
         """
+        # dump states as a checkpoint dictionary object
+        checkpoint = self.lightning_module.trainer.checkpoint_connector.dump_checkpoint(weights_only)
         # Todo: TypeError: 'mappingproxy' object does not support item assignment
         if _OMEGACONF_AVAILABLE:
             checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container)
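The signature change above inverts the data flow: the caller no longer builds the checkpoint dict, and the plugin instead asks the trainer's checkpoint connector for the state dump itself, so state is gathered in the process that performs the file write. A minimal usage sketch, assuming a TPU runtime is available and using the test suite's BoringModel as a stand-in model:

    import pytorch_lightning as pl
    from tests.helpers import BoringModel  # the repo's toy model, imported in the tests below

    model = BoringModel()
    trainer = pl.Trainer(tpu_cores=8, max_epochs=1)
    trainer.fit(model)

    # `save_checkpoint` receives only the path and `weights_only`; the TPU
    # plugin builds the checkpoint dict via the checkpoint connector before
    # writing the file.
    trainer.save_checkpoint("model.ckpt", weights_only=True)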
6 changes: 3 additions & 3 deletions tests/models/test_tpu.py
@@ -122,7 +122,7 @@ def test_model_16bit_tpu_cores_1(tmpdir):
         progress_bar_refresh_rate=0,
         max_epochs=2,
         tpu_cores=1,
-        limit_train_batches=8,
+        limit_train_batches=0.7,
         limit_val_batches=2,
     )
 
@@ -210,8 +210,8 @@ def test_tpu_grad_norm(tmpdir):
         progress_bar_refresh_rate=0,
         max_epochs=4,
         tpu_cores=1,
-        limit_train_batches=0.4,
-        limit_val_batches=0.4,
+        limit_train_batches=10,
+        limit_val_batches=10,
         gradient_clip_val=0.5,
     )
 
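For context on the swapped values: Lightning's `limit_train_batches`/`limit_val_batches` accept a float as a fraction of the available batches or an int as an absolute batch count, and the tests above move between the two forms. A minimal illustration:

    from pytorch_lightning import Trainer

    # Float: run 70% of the train batches each epoch.
    trainer_fraction = Trainer(limit_train_batches=0.7)

    # Int: run exactly 10 train and 10 val batches each epoch.
    trainer_count = Trainer(limit_train_batches=10, limit_val_batches=10)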
2 changes: 2 additions & 0 deletions tests/test_profiler.py
@@ -20,6 +20,7 @@
 import pytest
 
 from pytorch_lightning.profiler import AdvancedProfiler, SimpleProfiler
+from tests.helpers.runif import RunIf
 
 PROFILER_OVERHEAD_MAX_TOLERANCE = 0.0005
 
@@ -165,6 +166,7 @@ def test_advanced_profiler_overhead(advanced_profiler, n_iter=5):
     assert average_duration < PROFILER_OVERHEAD_MAX_TOLERANCE
 
 
+@RunIf(max_torch="1.8.1")
 def test_advanced_profiler_describe(tmpdir, advanced_profiler):
     """
     ensure the profiler won't fail when reporting the summary
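`RunIf` is the test suite's conditional-skip helper (tests/helpers/runif.py). A rough sketch of the kind of gate `max_torch` applies, assuming it reduces to a `pytest.mark.skipif` over a version comparison; the exact boundary condition lives in the helper, not here:

    import pytest
    import torch
    from packaging.version import Version

    def require_max_torch(max_torch: str):
        # Skip when the installed torch is newer than `max_torch`,
        # approximating @RunIf(max_torch="1.8.1") above.
        return pytest.mark.skipif(
            Version(torch.__version__) > Version(max_torch),
            reason=f"test requires torch <= {max_torch}",
        )

    @require_max_torch("1.8.1")
    def test_profiler_describe_smoke():
        ...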
2 changes: 2 additions & 0 deletions tests/trainer/test_trainer.py
@@ -42,6 +42,7 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import EvalModelTemplate
 from tests.helpers import BoringModel, RandomDataset
+from tests.helpers.runif import RunIf
 
 
 @pytest.fixture
@@ -1499,6 +1500,7 @@ def test_trainer_predict_ddp_cpu(tmpdir):
     predict(tmpdir, "ddp_cpu", 0, 2)
 
 
+@RunIf(max_torch="1.8.1")
 def test_pytorch_profiler_describe(pytorch_profiler):
     """Ensure the profiler won't fail when reporting the summary."""
     with pytorch_profiler.profile("test_step"):
