
Clean up #2467

Merged 49 commits on Jul 3, 2020
Commits (49)
b32f6d6  Fixes #2455 (williamFalcon, Jul 2, 2020)
59dff54  Fixes #2455 (williamFalcon, Jul 2, 2020)
4d2c127  Fixes #2455 (williamFalcon, Jul 2, 2020)
2ab5928  Fixes #2455 (williamFalcon, Jul 2, 2020)
0264783  Fixes #2455 (williamFalcon, Jul 2, 2020)
80988a3  Fixes #2455 (williamFalcon, Jul 2, 2020)
f30358c  Fixes #2455 (williamFalcon, Jul 2, 2020)
c2afd05  Fixes #2455 (williamFalcon, Jul 2, 2020)
b3e5cfb  Fixes #2455 (williamFalcon, Jul 2, 2020)
e399545  Fixes #2455 (williamFalcon, Jul 2, 2020)
77c5daa  Fixes #2455 (williamFalcon, Jul 2, 2020)
9874b5e  Fixes #2455 (williamFalcon, Jul 2, 2020)
beeee3a  Fixes #2455 (williamFalcon, Jul 2, 2020)
f8736b5  Fixes #2455 (williamFalcon, Jul 2, 2020)
0f70120  Fixes #2455 (williamFalcon, Jul 2, 2020)
26936bb  Fixes #2455 (williamFalcon, Jul 2, 2020)
4610f68  Fixes #2455 (williamFalcon, Jul 2, 2020)
e0ddc90  Fixes #2455 (williamFalcon, Jul 2, 2020)
f113088  Fixes #2455 (williamFalcon, Jul 2, 2020)
cc8d1cd  Fixes #2455 (williamFalcon, Jul 2, 2020)
fc1254b  Fixes #2455 (williamFalcon, Jul 2, 2020)
4492804  Fixes #2455 (williamFalcon, Jul 2, 2020)
c59df13  Fixes #2455 (williamFalcon, Jul 2, 2020)
bea5171  Fixes #2455 (williamFalcon, Jul 2, 2020)
6d2e0c5  Fixes #2455 (williamFalcon, Jul 2, 2020)
ffa65ad  Fixes #2455 (williamFalcon, Jul 2, 2020)
7d5af1c  Fixes #2455 (williamFalcon, Jul 2, 2020)
c907e36  added early stop tpu test (williamFalcon, Jul 2, 2020)
ce37587  added early stop tpu test (williamFalcon, Jul 2, 2020)
6c77aef  added early stop tpu test (williamFalcon, Jul 2, 2020)
6cd4fdc  added early stop tpu test (williamFalcon, Jul 2, 2020)
7fdc7ec  added early stop tpu test (williamFalcon, Jul 2, 2020)
7879fe2  added early stop tpu test (williamFalcon, Jul 2, 2020)
3d77c36  added early stop tpu test (williamFalcon, Jul 2, 2020)
2ff19ba  added early stop tpu test (williamFalcon, Jul 2, 2020)
57b601b  added early stop tpu test (williamFalcon, Jul 2, 2020)
9dcc73e  added early stop tpu test (williamFalcon, Jul 2, 2020)
82df22e  added early stop tpu test (williamFalcon, Jul 2, 2020)
fafe7af  added early stop tpu test (williamFalcon, Jul 2, 2020)
5af7f69  added early stop tpu test (williamFalcon, Jul 2, 2020)
fef08e2  added early stop tpu test (williamFalcon, Jul 2, 2020)
b4bbe1c  added early stop tpu test (williamFalcon, Jul 2, 2020)
7f711a4  added early stop tpu test (williamFalcon, Jul 2, 2020)
51d4740  added early stop tpu test (williamFalcon, Jul 3, 2020)
58b66bc  added early stop tpu test (williamFalcon, Jul 3, 2020)
50b5874  added early stop tpu test (williamFalcon, Jul 3, 2020)
c75e71b  added early stop tpu test (williamFalcon, Jul 3, 2020)
71ab1f6  added early stop tpu test (williamFalcon, Jul 3, 2020)
43fa463  added early stop tpu test (williamFalcon, Jul 3, 2020)
36 changes: 33 additions & 3 deletions pytorch_lightning/callbacks/early_stopping.py
@@ -9,13 +9,22 @@

 import numpy as np
 import torch
+import torch.distributed as dist

 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.utilities import rank_zero_warn

 torch_inf = torch.tensor(np.Inf)

+try:
+    import torch_xla
+    import torch_xla.core.xla_model as xm
+except ImportError:
+    XLA_AVAILABLE = False
+else:
+    XLA_AVAILABLE = True
+

 class EarlyStopping(Callback):
     r"""
@@ -138,17 +147,38 @@ def _run_early_stopping_check(self, trainer, pl_module):

         current = logs.get(self.monitor)
         if not isinstance(current, torch.Tensor):
-            current = torch.tensor(current)
+            current = torch.tensor(current, device=pl_module.device)

-        if self.monitor_op(current - self.min_delta, self.best_score):
+        if self.monitor_op(current - self.min_delta, self.best_score.to(pl_module.device)):
             self.best_score = current
             self.wait_count = 0
         else:
             self.wait_count += 1
-            if self.wait_count >= self.patience:
+            should_stop = self.wait_count >= self.patience
+
+            if bool(should_stop):
                 self.stopped_epoch = trainer.current_epoch
                 trainer.should_stop = True
+
+        # stop every ddp process if any world process decides to stop
+        self._stop_distributed_training(trainer, pl_module)

Review comment:
The function name is misleading. This does not stop training, it just updates the trainer.should_stop state.
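
Following up on that naming point, a minimal standalone sketch (hypothetical, not code from this PR) of the same logic under a name that describes what it actually does:

import torch
import torch.distributed as dist


def sync_should_stop(trainer, pl_module):
    """Reconcile trainer.should_stop across DDP ranks; does not itself stop training."""
    if trainer.use_ddp or trainer.use_ddp2:
        stop = torch.tensor(int(trainer.should_stop), device=pl_module.device)
        # SUM of the 0/1 flags tells every rank how many ranks requested a stop.
        dist.all_reduce(stop, op=dist.ReduceOp.SUM)
        trainer.should_stop = stop.item() == trainer.world_size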


+    def _stop_distributed_training(self, trainer, pl_module):
+
+        # in ddp make sure all processes stop when one is flagged
+        if trainer.use_ddp or trainer.use_ddp2:
+            stop = torch.tensor(int(trainer.should_stop), device=pl_module.device)
+            dist.all_reduce(stop, op=dist.reduce_op.SUM)
+            dist.barrier()

Review comment:
@williamFalcon Is a barrier needed after an all reduce?
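
On the question above: in the default synchronous mode, all_reduce is itself a collective that every rank must enter before any rank receives the result, so the reduced value is already consistent across ranks when the call returns, and the extra barrier() should not be needed for correctness on CPU backends (CUDA collectives are stream-ordered, which is a separate subtlety). A standalone sketch, not part of this PR:

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Each rank contributes its local stop flag; pretend only rank 0 wants to stop.
    stop = torch.tensor(int(rank == 0))
    dist.all_reduce(stop, op=dist.ReduceOp.SUM)

    # all_reduce blocks until every rank has contributed, so by this point
    # all ranks hold the same sum; no additional barrier is required.
    print(f"rank {rank}: reduced stop flag = {stop.item()}")
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(worker, args=(world_size,), nprocs=world_size)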

+            trainer.should_stop = stop == trainer.world_size
+
+        # if trainer.use_tpu:
+        #     stop = torch.tensor(int(trainer.should_stop), device=pl_module.device)
+        #     xm.all_reduce('sum', [stop])
+        #     print(type(stop))
+        #     torch_xla.core.xla_model.rendezvous("pl.EarlyStoppingCallback.stop_distributed_training_check")
+        #     trainer.should_stop = stop.item() == trainer.world_size

     def on_train_end(self, trainer, pl_module):
         if self.stopped_epoch > 0 and self.verbose > 0:
             rank_zero_warn('Displayed epoch numbers by `EarlyStopping` start from "1" until v0.6.x,'
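
One detail worth flagging in the DDP branch above: the in-code comment says processes stop "when one is flagged", but stop == trainer.world_size only becomes true when every rank has flagged. A small standalone sketch of the two policies, using a plain tensor to stand in for the all-reduced flag:

import torch

world_size = 4
flags = torch.tensor([1, 0, 0, 1])  # per-rank should_stop values; two of four ranks want to stop
stop = flags.sum()                  # every rank sees this same sum after all_reduce(SUM)

stop_if_any = stop.item() > 0            # "any rank stops all", what the comment describes
stop_if_all = stop.item() == world_size  # what the diff's condition implements
print(stop_if_any, stop_if_all)          # True False
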
24 changes: 24 additions & 0 deletions tests/models/test_gpu.py
@@ -58,6 +58,30 @@ def test_multi_gpu_model(tmpdir, backend):
     memory.get_memory_profile('min_max')


+@pytest.mark.spawn
+@pytest.mark.parametrize("backend", ['dp', 'ddp', 'ddp2'])
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+def test_multi_gpu_early_stop(tmpdir, backend):
+    """Make sure DDP works with early stopping."""
+    tutils.set_random_master_port()
+
+    trainer_options = dict(
+        default_root_dir=tmpdir,
+        early_stop_callback=True,
+        max_epochs=50,
+        limit_train_batches=10,
+        limit_val_batches=10,
+        gpus=[0, 1],
+        distributed_backend=backend,
+    )
+
+    model = EvalModelTemplate()
+    # tutils.run_model_test(trainer_options, model)
+    trainer = Trainer(**trainer_options)
+    result = trainer.fit(model)
+    assert result


 @pytest.mark.spawn
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 def test_ddp_all_dataloaders_passed_to_fit(tmpdir):
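
A side note on early_stop_callback=True as used above: in the Lightning API of this era it enabled a default EarlyStopping callback. A hedged sketch of the roughly equivalent explicit form; the monitor and patience values here are illustrative assumptions, not taken from the PR:

# Sketch only; assumes the 0.8-era API where Trainer's early_stop_callback
# accepted either a bool or an EarlyStopping instance.
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

early_stop = EarlyStopping(monitor='val_loss', patience=3)  # illustrative values
trainer = Trainer(
    max_epochs=50,
    limit_train_batches=10,
    limit_val_batches=10,
    early_stop_callback=early_stop,
)
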
21 changes: 21 additions & 0 deletions tests/models/test_tpu.py
@@ -19,6 +19,27 @@
     TPU_AVAILABLE = True


+@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
+@pytest.mark.parametrize(['tpu_cores', 'expected_device'], [
+    pytest.param([1], 'xla:1'),
+    pytest.param([8], 'xla:8'),
+])
+def test_early_stop_checkpoints_on_tpu(tmpdir, tpu_cores, expected_device):
+    """Test that early stopping works on single- and multi-core TPU."""
+    model = EvalModelTemplate()
+    trainer = Trainer(
+        early_stop_callback=True,
+        default_root_dir=tmpdir,
+        progress_bar_refresh_rate=0,
+        max_epochs=50,
+        limit_train_batches=10,
+        limit_val_batches=10,
+        tpu_cores=tpu_cores,
+    )
+    trainer.fit(model)
+    assert torch_xla._XLAC._xla_get_default_device() == expected_device


 @pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
 @pytest.mark.parametrize(['tpu_cores', 'expected_device'], [
     pytest.param([1], 'xla:1'),
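
A closing note on the assertion style in the TPU test above: torch_xla._XLAC._xla_get_default_device() reaches into a private module. A sketch of the same check through the public helper, assuming a TPU/XLA runtime is present:

import torch_xla.core.xla_model as xm

# xm.xla_device() returns the current default XLA device as a torch.device,
# e.g. 'xla:1' on a single-core run.
device = xm.xla_device()
assert str(device).startswith('xla:')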