From 04af624d63a08ded70387a6a200c9b3297936fc0 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Wed, 30 Sep 2020 13:42:19 -0400 Subject: [PATCH] ref: decoupled ddp spawn --- pytorch_lightning/accelerators/ddp_backend.py | 4 ++-- pytorch_lightning/trainer/trainer.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/accelerators/ddp_backend.py b/pytorch_lightning/accelerators/ddp_backend.py index 5cf6edc52f200c..0ff263166006ba 100644 --- a/pytorch_lightning/accelerators/ddp_backend.py +++ b/pytorch_lightning/accelerators/ddp_backend.py @@ -25,7 +25,7 @@ from pytorch_lightning import _logger as log from pytorch_lightning.utilities.distributed import find_free_network_port -from pytorch_lightning.accelerators.ddp_base_backend import DDPBase +from pytorch_lightning.accelerators.base_backend import Accelerator from pytorch_lightning.utilities.distributed import rank_zero_only from pytorch_lightning.utilities import AMPType @@ -39,7 +39,7 @@ HYDRA_AVAILABLE = True -class DDPBackend(DDPBase): +class DDPBackend(Accelerator): def __init__(self, trainer, mode: str = 'ddp'): super().__init__(trainer) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index b1b27e821a1fcc..d312e562476480 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -712,8 +712,6 @@ def test( # Attach datamodule to get setup/prepare_data added to model before the call to it below self.data_connector.attach_datamodule(model or self.get_model(), datamodule, 'test') - print('-' * 100, f'\n {self.accelerator_backend.task_idx} TEST-DM \n', '-' * 100) - if model is not None: results = self.__test_given_model(model, test_dataloaders) else: @@ -726,6 +724,8 @@ def test( def __test_using_best_weights(self, ckpt_path, test_dataloaders): model = self.get_model() + print('-' * 100, f'\n {self.accelerator_backend.task_idx} TEST-DM \n', '-' * 100) + # if user requests the best checkpoint but we don't have it, error if ckpt_path == 'best' and not self.checkpoint_callback.best_model_path: raise MisconfigurationException(