[CLI] Drop ArgumentParser when pickling and save before spawning #8017

Merged Jul 7, 2021 · 29 commits
Changes shown are from 27 of the 29 commits.

Commits
dcbf6fd
Drop `ArgumentParser` when pickling and save before spawning
carmocca Jun 17, 2021
aeaac3b
Minor changes
carmocca Jun 17, 2021
c6a596e
Merge branch 'master' into bugfix/cli-drop-parser-spawn
carmocca Jun 30, 2021
2540f97
Fix code
carmocca Jun 30, 2021
dbd3841
Merge branch 'master' into bugfix/cli-drop-parser-spawn
carmocca Jun 30, 2021
831f93b
Spawn changes
carmocca Jun 30, 2021
cb92eae
Add TPU tests and early exit
carmocca Jun 30, 2021
7f70052
Update CHANGELOG
carmocca Jun 30, 2021
4931912
Deepsource
carmocca Jun 30, 2021
cae3de1
mypy
carmocca Jun 30, 2021
d274a91
Python 3.7 friendly notation
carmocca Jun 30, 2021
5bad76d
makedirs
carmocca Jun 30, 2021
a82180e
Merge branch 'master' into bugfix/cli-drop-parser-spawn
carmocca Jul 1, 2021
86ce9b5
fmt
Borda Jul 2, 2021
c3d79e0
Overwrite config
carmocca Jul 2, 2021
4dc1cf8
Merge branch 'master' into bugfix/cli-drop-parser-spawn
carmocca Jul 2, 2021
fa0ea1e
Use new ddp on cpu
carmocca Jul 2, 2021
ad76825
Merge branch 'master' into bugfix/cli-drop-parser-spawn
carmocca Jul 6, 2021
e439d8f
move device
awaelchli Jul 6, 2021
c39dd57
debug
awaelchli Jul 6, 2021
64264df
missing super call
awaelchli Jul 6, 2021
8561edb
set_device in ddp plugin
awaelchli Jul 6, 2021
519d01c
redundant set device in single device plugin
awaelchli Jul 6, 2021
83d6bfa
remove redundant set_device in ddp subclasses
awaelchli Jul 6, 2021
9058ef7
Merge branch 'bugfix/set-device' into bugfix/cli-drop-parser-spawn
awaelchli Jul 7, 2021
b8507bf
Merge branch 'master' into bugfix/cli-drop-parser-spawn
carmocca Jul 7, 2021
62da070
Update pytorch_lightning/trainer/trainer.py
carmocca Jul 7, 2021
ef40533
Old windows fix
carmocca Jul 7, 2021
d2acbfa
Merge branch 'master' into bugfix/cli-drop-parser-spawn
carmocca Jul 7, 2021
9 changes: 9 additions & 0 deletions CHANGELOG.md
@@ -241,6 +241,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- `LightningCLI` now aborts with a clearer message if config already exists and disables save config during `fast_dev_run` ([#7963](https://github.com/PyTorchLightning/pytorch-lightning/pull/7963))


- Save the `LightningCLI` config on `setup` and only on the main process ([#8017](https://github.com/PyTorchLightning/pytorch-lightning/pull/8017))


- Drop the `LightningCLI` `ArgumentParser` when pickling ([#8017](https://github.com/PyTorchLightning/pytorch-lightning/pull/8017))


- Skip `broadcast` if distributed not initialized for the spawn plugins ([#8017](https://github.com/PyTorchLightning/pytorch-lightning/pull/8017))


- `Trainer(resume_from_checkpoint=...)` now restores the model directly after `LightningModule.setup()`, which is before `LightningModule.configure_sharded_model()` ([#7652](https://github.com/PyTorchLightning/pytorch-lightning/pull/7652))


2 changes: 1 addition & 1 deletion pl_examples/basic_examples/dali_image_classifier.py
@@ -221,7 +221,7 @@ def cli_main():
if not _DALI_AVAILABLE:
return

cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234)
cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234, save_config_overwrite=True)
cli.trainer.test(cli.model, datamodule=cli.datamodule)


2 changes: 2 additions & 0 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -326,6 +326,8 @@ def barrier(self, *args, **kwargs) -> None:
torch.distributed.barrier()

def broadcast(self, obj: object, src: int = 0) -> object:
if not distributed_available():
return obj
return self.dist.broadcast(obj)

def model_to_device(self):
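The early return added to `broadcast` above treats the collective as a no-op whenever no process group has been initialized. A minimal sketch of that pattern with plain `torch.distributed` (the `distributed_available` helper here is a stand-in for Lightning's utility of the same name, and `broadcast_object_list` requires PyTorch >= 1.8):

```python
import torch.distributed as dist


def distributed_available() -> bool:
    # Only true once `init_process_group` has actually been called.
    return dist.is_available() and dist.is_initialized()


def broadcast_object(obj, src: int = 0):
    # Single-process run: nothing to synchronize, so return the object as-is
    # instead of hanging (or crashing) on a collective call.
    if not distributed_available():
        return obj
    holder = [obj]
    dist.broadcast_object_list(holder, src=src)
    return holder[0]
```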
10 changes: 6 additions & 4 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -31,7 +31,7 @@
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.data import has_len
from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp, tpu_distributed
from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -127,7 +127,8 @@ def distributed_sampler_kwargs(self) -> Dict[str, int]:

@property
def is_distributed(self) -> bool:
return self.world_size != 1
# HOST_WORLD_SIZE is None outside the xmp.spawn process
return os.getenv(xenv.HOST_WORLD_SIZE, None) and self.world_size != 1

def process_dataloader(self, dataloader: DataLoader) -> MpDeviceLoader:
TPUSpawnPlugin._validate_dataloader(dataloader)
@@ -178,8 +179,7 @@ def model_to_device(self) -> None:
self.model = self.wrapped_model.to(self.root_device)

def barrier(self, name: Optional[str] = None) -> None:
# HOST_WORLD_SIZE is None outside the xmp.spawn process
if os.getenv(xenv.HOST_WORLD_SIZE, None) and tpu_distributed():
if self.is_distributed:
rendezvous(name)

def transfer_distrib_spawn_state_on_fit_end(self, results):
@@ -212,6 +212,8 @@ def save(self, state_dict: Dict, path: str) -> None:
xm.save(state_dict, path)

def broadcast(self, obj: object, src: int = 0) -> object:
if not self.is_distributed:
return obj
buffer = io.BytesIO()
torch.save(obj, buffer)
data = bytearray(buffer.getbuffer())
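The `broadcast` override above (its body is truncated in this view) starts by serializing the Python object into raw bytes so it can travel as a tensor over the XLA collective. A rough sketch of that serialize/deserialize round trip with plain `torch`, leaving the actual device transfer out:

```python
import io

import torch


def object_to_byte_tensor(obj) -> torch.Tensor:
    # torch.save handles arbitrary picklable objects; the resulting bytes
    # can be shipped across processes/devices as a plain uint8 tensor.
    buffer = io.BytesIO()
    torch.save(obj, buffer)
    return torch.tensor(list(bytearray(buffer.getbuffer())), dtype=torch.uint8)


def byte_tensor_to_object(data: torch.Tensor):
    buffer = io.BytesIO(data.cpu().numpy().tobytes())
    return torch.load(buffer)


payload = {"best_model_path": "checkpoints/epoch=0.ckpt"}
assert byte_tensor_to_object(object_to_byte_tensor(payload)) == payload
```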
20 changes: 18 additions & 2 deletions pytorch_lightning/utilities/cli.py
@@ -24,6 +24,7 @@
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.trainer.trainer import Trainer
from pytorch_lightning.utilities import _module_available
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.seed import seed_everything
@@ -162,7 +163,9 @@ def __init__(
self.config_filename = config_filename
self.overwrite = overwrite

def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None:
# save the config in `setup` because (1) we want it to save regardless of the trainer function run
# and (2) we want to save before processes are spawned
log_dir = trainer.log_dir or trainer.default_root_dir
config_path = os.path.join(log_dir, self.config_filename)
if not self.overwrite and os.path.isfile(config_path):
@@ -172,7 +175,20 @@ def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
' set `LightningCLI(save_config_callback=None)` to disable config saving,'
' or set `LightningCLI(save_config_overwrite=True)` to overwrite the config file.'
)
self.parser.save(self.config, config_path, skip_none=False, overwrite=self.overwrite)
if trainer.is_global_zero:
# save only on rank zero to avoid race conditions on DDP.
# the `log_dir` needs to be created as we rely on the logger to do it usually
# but it hasn't logged anything at this point
get_filesystem(log_dir).makedirs(log_dir, exist_ok=True)
self.parser.save(self.config, config_path, skip_none=False, overwrite=self.overwrite)

def __reduce__(self) -> Tuple[Type['SaveConfigCallback'], Tuple, Dict]:
# `ArgumentParser` is un-pickleable. Drop it
return (
self.__class__,
(None, self.config, self.config_filename),
{},
)


class LightningCLI:
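The `__reduce__` override above rebuilds the callback from its picklable pieces and simply drops the parser, which is what lets the callback cross a `spawn` boundary. A generic sketch of the same idiom (class and attribute names are illustrative, not Lightning's):

```python
import pickle


class ConfigSaver:
    def __init__(self, parser, config, filename):
        self.parser = parser      # potentially un-picklable (callables, file handles, ...)
        self.config = config
        self.filename = filename

    def __reduce__(self):
        # Reconstruct without the parser: the spawned child only needs
        # the parsed config and the target filename.
        return self.__class__, (None, self.config, self.filename), {}


saver = ConfigSaver(parser=object(), config={"seed_everything": 1234}, filename="config.yaml")
clone = pickle.loads(pickle.dumps(saver))
assert clone.parser is None and clone.config == {"seed_everything": 1234}
```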
4 changes: 1 addition & 3 deletions pytorch_lightning/utilities/distributed.py
@@ -342,6 +342,4 @@ def register_ddp_comm_hook(


def tpu_distributed() -> bool:
if _TPU_AVAILABLE:
return xm.xrt_world_size() > 1
return False
return _TPU_AVAILABLE and xm.xrt_world_size() > 1
37 changes: 37 additions & 0 deletions tests/utilities/test_cli.py
@@ -35,6 +35,7 @@
from pytorch_lightning.utilities.cli import instantiate_class, LightningArgumentParser, LightningCLI, SaveConfigCallback
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
from tests.helpers import BoringDataModule, BoringModel
from tests.helpers.runif import RunIf

torchvision_version = version.parse('0')
if _TORCHVISION_AVAILABLE:
@@ -605,6 +606,42 @@ def add_arguments_to_parser(self, parser):
assert cli.model.num_classes == 5


class EarlyExitTestModel(BoringModel):

def on_fit_start(self):
raise KeyboardInterrupt()


@pytest.mark.parametrize('logger', (False, True))
@pytest.mark.parametrize(
'trainer_kwargs', (
dict(accelerator='ddp_cpu'),
dict(accelerator='ddp_cpu', plugins="ddp_find_unused_parameters_false"),
pytest.param({'tpu_cores': 1}, marks=RunIf(tpu=True)),
)
)
def test_cli_ddp_spawn_save_config_callback(tmpdir, logger, trainer_kwargs):
with mock.patch('sys.argv', ['any.py']), pytest.raises(KeyboardInterrupt):
LightningCLI(
EarlyExitTestModel,
trainer_defaults={
'default_root_dir': str(tmpdir),
'logger': logger,
'max_steps': 1,
'max_epochs': 1,
**trainer_kwargs,
}
)
if logger:
config_dir = tmpdir / 'lightning_logs'
# no more version dirs should get created
assert os.listdir(config_dir) == ['version_0']
config_path = config_dir / 'version_0' / 'config.yaml'
else:
config_path = tmpdir / 'config.yaml'
assert os.path.isfile(config_path)


def test_cli_config_overwrite(tmpdir):
trainer_defaults = {'default_root_dir': str(tmpdir), 'logger': False, 'max_steps': 1, 'max_epochs': 1}

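For context, the new spawn test above corresponds to a workflow like this hypothetical stand-alone script (the model, file name, and CLI flags are illustrative, not from the PR); with a spawn-based backend, the saved `config.yaml` now lands in the log directory once, written by rank zero during `setup`:

```python
# tiny_cli.py -- a hypothetical script matching what the spawn test exercises.
# Example invocation (flags assume a Lightning version of this PR's era):
#   python tiny_cli.py --trainer.accelerator ddp_cpu --trainer.num_processes 2 --trainer.max_steps 1
import torch
from torch.utils.data import DataLoader

from pytorch_lightning import LightningModule
from pytorch_lightning.utilities.cli import LightningCLI


class TinyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def train_dataloader(self):
        return DataLoader(torch.randn(64, 32), batch_size=8)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    # the guard matters: spawn start methods re-import this module in the workers
    LightningCLI(TinyModel, save_config_overwrite=True)
```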