From b32f6d6d22491545c1f9c2e4f85dcd1d0ef38274 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 2 Jul 2020 08:08:51 -0400
Subject: [PATCH 01/49] Fixes #2455

---
 pytorch_lightning/callbacks/early_stopping.py | 4 ++++
 pytorch_lightning/trainer/training_loop.py | 3 +++
 2 files changed, 7 insertions(+)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index b4113ebeccfca..7a7ae4ca50162 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -148,6 +148,10 @@ def _run_early_stopping_check(self, trainer, pl_module):
             if self.wait_count >= self.patience:
                 self.stopped_epoch = trainer.current_epoch
                 trainer.should_stop = True
+                print('-' * 100)
+                print('RUNNING EARLY STOP CHECK', trainer.global_rank)
+                print('stop:', trainer.should_stop)
+                print('-' * 100)

     def on_train_end(self, trainer, pl_module):
         if self.stopped_epoch > 0 and self.verbose > 0:
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index aa41087d29e8e..8d1ff035b5e10 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -389,6 +389,9 @@ def train(self):
                     f' ({self.min_epochs}) or minimum steps ({self.min_steps}) has'
                     ' not been met. Training will continue...')

+            print('-' * 100)
+            print('SHUTTING DOWN', self.global_rank)
+            print('-' * 100)
             self.run_training_teardown()

         except KeyboardInterrupt:

From 59dff545be27e635e2058c72c6455b3e8082bb1e Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 2 Jul 2020 08:25:47 -0400
Subject: [PATCH 02/49] Fixes #2455

---
 pytorch_lightning/callbacks/early_stopping.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 7a7ae4ca50162..af73e98c8a5e6 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -151,6 +151,7 @@ def _run_early_stopping_check(self, trainer, pl_module):
                 print('-' * 100)
                 print('RUNNING EARLY STOP CHECK', trainer.global_rank)
                 print('stop:', trainer.should_stop)
+                print('epoch', trainer.current_epoch)
                 print('-' * 100)

     def on_train_end(self, trainer, pl_module):

From 4d2c12752cc3c337c3b325dad4cb488c2f3d3906 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 2 Jul 2020 08:28:55 -0400
Subject: [PATCH 03/49] Fixes #2455

---
 pytorch_lightning/callbacks/early_stopping.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index af73e98c8a5e6..570a60f78a710 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -152,6 +152,7 @@ def _run_early_stopping_check(self, trainer, pl_module):
                 print('RUNNING EARLY STOP CHECK', trainer.global_rank)
                 print('stop:', trainer.should_stop)
                 print('epoch', trainer.current_epoch)
+                print('metric value', self.best_score)
                 print('-' * 100)

     def on_train_end(self, trainer, pl_module):

From 2ab592828bd6e05327568454c5db4ecf4d93c4e0 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 2 Jul 2020 08:41:30 -0400
Subject: [PATCH 04/49] Fixes #2455

---
 pytorch_lightning/callbacks/early_stopping.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 570a60f78a710..f9851e429c709 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -9,6 +9,7 @@

 import numpy as np
 import torch
+import torch.distributed as dist

 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
@@ -145,9 +146,18 @@ def _run_early_stopping_check(self, trainer, pl_module):
             self.wait_count = 0
         else:
             self.wait_count += 1
-            if self.wait_count >= self.patience:
+            should_stop = self.wait_count >= self.patience
+
+            # check flag across all GPUs
+            should_stop = torch.tensor(should_stop)
+            if trainer.use_ddp or trainer.use_ddp2:
+                dist.all_reduce(should_stop, op=dist.ReduceOp.Max)
+
+            # do actual stop
+            if should_stop:
                 self.stopped_epoch = trainer.current_epoch
                 trainer.should_stop = True
+
                 print('-' * 100)
                 print('RUNNING EARLY STOP CHECK', trainer.global_rank)
                 print('stop:', trainer.should_stop)

From 0264783e317cb9949d698ba963b0df693d8ec07c Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 2 Jul 2020 08:42:08 -0400
Subject: [PATCH 05/49] Fixes #2455

---
 pytorch_lightning/callbacks/early_stopping.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index f9851e429c709..348fe42e92422 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -151,7 +151,7 @@ def _run_early_stopping_check(self, trainer, pl_module):
             # check flag across all GPUs
             should_stop = torch.tensor(should_stop)
             if trainer.use_ddp or trainer.use_ddp2:
-                dist.all_reduce(should_stop, op=dist.ReduceOp.Max)
+                dist.all_reduce(should_stop, op=dist.ReduceOp.MAX)

             # do actual stop
             if should_stop:

From 80988a3ac1216b6579b2493f3c8e3982e85edd4e Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 2 Jul 2020 08:43:02 -0400
Subject: [PATCH 06/49] Fixes #2455

---
 pytorch_lightning/callbacks/early_stopping.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 348fe42e92422..49695e77fe408 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -149,8 +149,8 @@ def _run_early_stopping_check(self, trainer, pl_module):
             should_stop = self.wait_count >= self.patience

             # check flag across all GPUs
-            should_stop = torch.tensor(should_stop)
             if trainer.use_ddp or trainer.use_ddp2:
+                should_stop = torch.tensor(should_stop, device=pl_module.device)
                 dist.all_reduce(should_stop, op=dist.ReduceOp.MAX)

             # do actual stop

From f30358ca26337aba67cfcb03ae7b2bd95a34eeb5 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 2 Jul 2020 08:43:42 -0400
Subject: [PATCH 07/49] Fixes #2455

---
 pytorch_lightning/callbacks/early_stopping.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 49695e77fe408..d10150cf56eba 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -150,7 +150,7 @@ def _run_early_stopping_check(self, trainer, pl_module):
             # check flag across all GPUs
             if trainer.use_ddp or trainer.use_ddp2:
-                should_stop = torch.tensor(should_stop, device=pl_module.device)
+                should_stop = torch.tensor(int(should_stop), device=pl_module.device)
                 dist.all_reduce(should_stop, op=dist.ReduceOp.MAX)

             # do actual stop
             if should_stop:
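Patches 04-07 converge on the core idiom of this series: each rank turns its local early-stop decision into a tensor, and an all_reduce with MAX makes the decision global, so any rank that runs out of patience stops every process. A minimal standalone sketch of that idiom (the process group is assumed to be initialized already; the function name is illustrative, not part of the patches):

    import torch
    import torch.distributed as dist

    def sync_stop_flag(local_should_stop: bool, device: torch.device) -> bool:
        # 0/1 tensor on the module's device; the NCCL backend requires CUDA tensors
        flag = torch.tensor(int(local_should_stop), device=device)
        # MAX semantics: stop as soon as ANY rank wants to stop
        dist.all_reduce(flag, op=dist.ReduceOp.MAX)
        # bring the reduced value back to a plain Python bool
        return bool(flag.item())

Patch 05 matters here because torch.distributed has no ReduceOp.Max member; the enum value is spelled ReduceOp.MAX.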
From c2afd053c83001a7588209578070c118e997f2e6 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 2 Jul 2020 08:45:25 -0400
Subject: [PATCH 08/49] Fixes #2455

---
 pytorch_lightning/callbacks/early_stopping.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index d10150cf56eba..19631753540d6 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -151,7 +151,9 @@ def _run_early_stopping_check(self, trainer, pl_module):
             # check flag across all GPUs
             if trainer.use_ddp or trainer.use_ddp2:
                 should_stop = torch.tensor(int(should_stop), device=pl_module.device)
+                print(should_stop)
                 dist.all_reduce(should_stop, op=dist.ReduceOp.MAX)
+                print(should_stop)

             # do actual stop

From b3e5cfbbccfc1f741da40699cf82611533f48532 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 2 Jul 2020 08:45:47 -0400
Subject: [PATCH 09/49] Fixes #2455

---
 pytorch_lightning/callbacks/early_stopping.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 19631753540d6..d5b43b1c22dc6 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -133,6 +133,8 @@ def on_validation_end(self, trainer, pl_module):
         self._run_early_stopping_check(trainer, pl_module)

     def _run_early_stopping_check(self, trainer, pl_module):
+        print('-' * 100)
+        print('RUNNING EARLY STOP CHECK', trainer.global_rank)
         logs = trainer.callback_metrics
         if not self._validate_condition_metric(logs):
             return  # short circuit if metric not present
@@ -160,8 +162,6 @@ def _run_early_stopping_check(self, trainer, pl_module):
                 self.stopped_epoch = trainer.current_epoch
                 trainer.should_stop = True

-                print('-' * 100)
-                print('RUNNING EARLY STOP CHECK', trainer.global_rank)
                 print('stop:', trainer.should_stop)
                 print('epoch', trainer.current_epoch)
                 print('metric value', self.best_score)

From e39954553f48cd8ccb5abd5adc21db8fdc0f5662 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 2 Jul 2020 08:47:28 -0400
Subject: [PATCH 10/49] Fixes #2455

---
 pytorch_lightning/callbacks/early_stopping.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index d5b43b1c22dc6..a9cbc39bb2810 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -133,7 +133,7 @@ def on_validation_end(self, trainer, pl_module):
         self._run_early_stopping_check(trainer, pl_module)

     def _run_early_stopping_check(self, trainer, pl_module):
-        print('-' * 100)
+        print(f'{trainer.global_rank}' * 100)
         print('RUNNING EARLY STOP CHECK', trainer.global_rank)
         logs = trainer.callback_metrics
         if not self._validate_condition_metric(logs):
@@ -165,7 +165,7 @@ def _run_early_stopping_check(self, trainer, pl_module):
                 print('stop:', trainer.should_stop)
                 print('epoch', trainer.current_epoch)
                 print('metric value', self.best_score)
-                print('-' * 100)
+                print(f'{trainer.global_rank}' * 100)

     def on_train_end(self, trainer, pl_module):
         if self.stopped_epoch > 0 and self.verbose > 0:

From 77c5daa3fc3361c2279aebd6e6aa51deb32d8355 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 2 Jul 2020 08:50:50 -0400
Subject: [PATCH 11/49] Fixes #2455

---
 pytorch_lightning/callbacks/early_stopping.py | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index a9cbc39bb2810..82e34c5692731 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -134,7 +134,6 @@ def on_validation_end(self, trainer, pl_module):

     def _run_early_stopping_check(self, trainer, pl_module):
         print(f'{trainer.global_rank}' * 100)
-        print('RUNNING EARLY STOP CHECK', trainer.global_rank)
         logs = trainer.callback_metrics
         if not self._validate_condition_metric(logs):
             return  # short circuit if metric not present
@@ -153,18 +152,15 @@ def _run_early_stopping_check(self, trainer, pl_module):
             # check flag across all GPUs
             if trainer.use_ddp or trainer.use_ddp2:
                 should_stop = torch.tensor(int(should_stop), device=pl_module.device)
-                print(should_stop)
                 dist.all_reduce(should_stop, op=dist.ReduceOp.MAX)
-                print(should_stop)
+
+            print(f'RANK: {trainer.global_rank} SHOULD STOP: {should_stop} BEST: {self.best_score}')

             # do actual stop
             if should_stop:
                 self.stopped_epoch = trainer.current_epoch
                 trainer.should_stop = True

-                print('stop:', trainer.should_stop)
-                print('epoch', trainer.current_epoch)
-                print('metric value', self.best_score)
                 print(f'{trainer.global_rank}' * 100)

     def on_train_end(self, trainer, pl_module):

From 9874b5e9402afa0587f94e0a299289faffda34e4 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 2 Jul 2020 08:53:25 -0400
Subject: [PATCH 12/49] Fixes #2455

---
 pytorch_lightning/callbacks/early_stopping.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 82e34c5692731..6d4ac8d232a16 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -133,7 +133,6 @@ def on_validation_end(self, trainer, pl_module):
         self._run_early_stopping_check(trainer, pl_module)

     def _run_early_stopping_check(self, trainer, pl_module):
-        print(f'{trainer.global_rank}' * 100)
         logs = trainer.callback_metrics
         if not self._validate_condition_metric(logs):
             return  # short circuit if metric not present
@@ -161,8 +160,6 @@ def _run_early_stopping_check(self, trainer, pl_module):
                 self.stopped_epoch = trainer.current_epoch
                 trainer.should_stop = True

-                print(f'{trainer.global_rank}' * 100)
-
     def on_train_end(self, trainer, pl_module):
         if self.stopped_epoch > 0 and self.verbose > 0:
             rank_zero_warn('Displayed epoch numbers by `EarlyStopping` start from "1" until v0.6.x,'

From beeee3acd77d956ce799725a34b36e56fa09ce91 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 2 Jul 2020 08:53:44 -0400
Subject: [PATCH 13/49] Fixes #2455

---
 pytorch_lightning/callbacks/early_stopping.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 6d4ac8d232a16..1e1eec4f9e680 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -157,6 +157,7 @@ def _run_early_stopping_check(self, trainer, pl_module):

             # do actual stop
             if should_stop:
+                print(f'RANK: {trainer.global_rank}, STOPPING...')
                 self.stopped_epoch = trainer.current_epoch
                 trainer.should_stop = True

From f8736b5f3e67d27c141ad8d8219fd53146741a3b Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 2 Jul 2020 08:55:14 -0400
Subject: [PATCH 14/49] Fixes #2455

---
 pytorch_lightning/callbacks/early_stopping.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 1e1eec4f9e680..2569c85ff3bbc 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -149,9 +149,9 @@ def _run_early_stopping_check(self, trainer, pl_module):
             should_stop = self.wait_count >= self.patience

             # check flag across all GPUs
-            if trainer.use_ddp or trainer.use_ddp2:
-                should_stop = torch.tensor(int(should_stop), device=pl_module.device)
-                dist.all_reduce(should_stop, op=dist.ReduceOp.MAX)
+            # if trainer.use_ddp or trainer.use_ddp2:
+            #     should_stop = torch.tensor(int(should_stop), device=pl_module.device)
+            #     dist.all_reduce(should_stop, op=dist.ReduceOp.MAX)

             print(f'RANK: {trainer.global_rank} SHOULD STOP: {should_stop} BEST: {self.best_score}')

From 0f701207f0a4dacbd7c60e9fb021ca0abe90d13c Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 2 Jul 2020 09:03:24 -0400
Subject: [PATCH 15/49] Fixes #2455

---
 pytorch_lightning/callbacks/early_stopping.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 2569c85ff3bbc..48994256fe30f 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -149,13 +149,14 @@ def _run_early_stopping_check(self, trainer, pl_module):
             should_stop = self.wait_count >= self.patience

             # check flag across all GPUs
-            # if trainer.use_ddp or trainer.use_ddp2:
-            #     should_stop = torch.tensor(int(should_stop), device=pl_module.device)
-            #     dist.all_reduce(should_stop, op=dist.ReduceOp.MAX)
+            if trainer.use_ddp or trainer.use_ddp2:
+                should_stop = torch.tensor(int(should_stop), device=pl_module.device)
+                dist.all_reduce(should_stop, op=dist.ReduceOp.MAX)

             print(f'RANK: {trainer.global_rank} SHOULD STOP: {should_stop} BEST: {self.best_score}')

             # do actual stop
+            print(f'RANK: {trainer.global_rank}, SHOULD STOP: {should_stop}, EPOCH: {trainer.current_epoch}')
             if should_stop:
                 print(f'RANK: {trainer.global_rank}, STOPPING...')
                 self.stopped_epoch = trainer.current_epoch

From 26936bb93f4be7f77e69e90c4e468e52d0630a99 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 2 Jul 2020 09:39:03 -0400
Subject: [PATCH 16/49] Fixes #2455

---
 pytorch_lightning/callbacks/early_stopping.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 48994256fe30f..fe6b0bc6ed815 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -152,6 +152,7 @@ def _run_early_stopping_check(self, trainer, pl_module):
             if trainer.use_ddp or trainer.use_ddp2:
                 should_stop = torch.tensor(int(should_stop), device=pl_module.device)
                 dist.all_reduce(should_stop, op=dist.ReduceOp.MAX)
+                dist.barrier()

             print(f'RANK: {trainer.global_rank} SHOULD STOP: {should_stop} BEST: {self.best_score}')

From 4610f68f8fa6b49861acf4187e9734529ec3e790 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 2 Jul 2020 09:40:50 -0400
Subject: [PATCH 17/49] Fixes #2455

---
 pytorch_lightning/callbacks/early_stopping.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index fe6b0bc6ed815..813a2e6c50b2f 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -151,8 +151,8 @@ def _run_early_stopping_check(self, trainer, pl_module):
             # check flag across all GPUs
             if trainer.use_ddp or trainer.use_ddp2:
                 should_stop = torch.tensor(int(should_stop), device=pl_module.device)
-                dist.all_reduce(should_stop, op=dist.ReduceOp.MAX)
                 dist.barrier()
+                dist.all_reduce(should_stop, op=dist.ReduceOp.MAX)

             print(f'RANK: {trainer.global_rank} SHOULD STOP: {should_stop} BEST: {self.best_score}')

From e0ddc9084a8d0d6cb9bf71c114f3adf65d3d55fd Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 2 Jul 2020 09:43:04 -0400
Subject: [PATCH 18/49] Fixes #2455

---
 pytorch_lightning/callbacks/early_stopping.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 813a2e6c50b2f..22d75e7a7b92b 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -149,8 +149,8 @@ def _run_early_stopping_check(self, trainer, pl_module):
             should_stop = self.wait_count >= self.patience

             # check flag across all GPUs
+            should_stop = torch.tensor(int(should_stop), device=pl_module.device)
             if trainer.use_ddp or trainer.use_ddp2:
-                should_stop = torch.tensor(int(should_stop), device=pl_module.device)
                 dist.barrier()
                 dist.all_reduce(should_stop, op=dist.ReduceOp.MAX)
@@ -158,7 +158,7 @@ def _run_early_stopping_check(self, trainer, pl_module):

             # do actual stop
             print(f'RANK: {trainer.global_rank}, SHOULD STOP: {should_stop}, EPOCH: {trainer.current_epoch}')
-            if should_stop:
+            if should_stop.item():
                 print(f'RANK: {trainer.global_rank}, STOPPING...')
                 self.stopped_epoch = trainer.current_epoch
                 trainer.should_stop = True

From f11308808d0f697eeea42e61c78ce4e52faa0cee Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 2 Jul 2020 09:43:16 -0400
Subject: [PATCH 19/49] Fixes #2455

---
 pytorch_lightning/callbacks/early_stopping.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 22d75e7a7b92b..355c88abf3ea1 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -158,7 +158,7 @@ def _run_early_stopping_check(self, trainer, pl_module):

             # do actual stop
             print(f'RANK: {trainer.global_rank}, SHOULD STOP: {should_stop}, EPOCH: {trainer.current_epoch}')
-            if should_stop.item():
+            if bool(should_stop.item()):
                 print(f'RANK: {trainer.global_rank}, STOPPING...')
                 self.stopped_epoch = trainer.current_epoch
                 trainer.should_stop = True

From cc8d1cdcb3d6672af22bfabd19f5b1eace1711ae Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 2 Jul 2020 09:46:35 -0400
Subject: [PATCH 20/49] Fixes #2455

---
 pytorch_lightning/callbacks/early_stopping.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 355c88abf3ea1..ae810a12753ce 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -152,7 +152,7 @@ def _run_early_stopping_check(self, trainer, pl_module):
             should_stop = torch.tensor(int(should_stop), device=pl_module.device)
             if trainer.use_ddp or trainer.use_ddp2:
                 dist.barrier()
-                dist.all_reduce(should_stop, op=dist.ReduceOp.MAX)
+                dist.all_reduce(should_stop, op=dist.reduce_op.MAX)

             print(f'RANK: {trainer.global_rank} SHOULD STOP: {should_stop} BEST: {self.best_score}')

From fc1254b3080bdec45b93d0e47c4e8fd4b80b4ced Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 2 Jul 2020 09:46:42 -0400
Subject: [PATCH 21/49] Fixes #2455

---
 pytorch_lightning/callbacks/early_stopping.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index ae810a12753ce..ca8db11a2a9bf 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -151,7 +151,6 @@ def _run_early_stopping_check(self, trainer, pl_module):
             # check flag across all GPUs
             should_stop = torch.tensor(int(should_stop), device=pl_module.device)
             if trainer.use_ddp or trainer.use_ddp2:
-                dist.barrier()
                 dist.all_reduce(should_stop, op=dist.reduce_op.MAX)

             print(f'RANK: {trainer.global_rank} SHOULD STOP: {should_stop} BEST: {self.best_score}')

From 44928043dc528e74d57f455cbe8809bc86e061d5 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 2 Jul 2020 09:54:09 -0400
Subject: [PATCH 22/49] Fixes #2455

---
 pytorch_lightning/callbacks/early_stopping.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index ca8db11a2a9bf..76f87b8fa036e 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -152,6 +152,7 @@ def _run_early_stopping_check(self, trainer, pl_module):
             should_stop = torch.tensor(int(should_stop), device=pl_module.device)
             if trainer.use_ddp or trainer.use_ddp2:
                 dist.all_reduce(should_stop, op=dist.reduce_op.MAX)
+                dist.barrier()

             print(f'RANK: {trainer.global_rank} SHOULD STOP: {should_stop} BEST: {self.best_score}')

From c59df132072a02026a5a7acbec697114a25fa824 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 2 Jul 2020 09:55:34 -0400
Subject: [PATCH 23/49] Fixes #2455

---
 pytorch_lightning/callbacks/early_stopping.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 76f87b8fa036e..51e31f7aba05f 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -151,7 +151,9 @@ def _run_early_stopping_check(self, trainer, pl_module):
             # check flag across all GPUs
             should_stop = torch.tensor(int(should_stop), device=pl_module.device)
             if trainer.use_ddp or trainer.use_ddp2:
+                print(f'RANK: {trainer.global_rank} REDUCING...')
                 dist.all_reduce(should_stop, op=dist.reduce_op.MAX)
+                print(f'RANK: {trainer.global_rank} REDUCED...')
                 dist.barrier()

             print(f'RANK: {trainer.global_rank} SHOULD STOP: {should_stop} BEST: {self.best_score}')

From bea517169e0d7ef8c26238d95d146dd3ce07682a Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 2 Jul 2020 09:57:11 -0400
Subject: [PATCH 24/49] Fixes #2455

---
 pytorch_lightning/callbacks/early_stopping.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 51e31f7aba05f..e9937089b750b 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -154,7 +154,6 @@ def _run_early_stopping_check(self, trainer, pl_module):
                 print(f'RANK: {trainer.global_rank} REDUCING...')
                 dist.all_reduce(should_stop, op=dist.reduce_op.MAX)
                 print(f'RANK: {trainer.global_rank} REDUCED...')
-                dist.barrier()

             print(f'RANK: {trainer.global_rank} SHOULD STOP: {should_stop} BEST: {self.best_score}')
@@ -165,6 +164,8 @@ def _run_early_stopping_check(self, trainer, pl_module):
                 self.stopped_epoch = trainer.current_epoch
                 trainer.should_stop = True

+        dist.barrier()
+
     def on_train_end(self, trainer, pl_module):
         if self.stopped_epoch > 0 and self.verbose > 0:
             rank_zero_warn('Displayed epoch numbers by `EarlyStopping` start from "1" until v0.6.x,'

From 6d2e0c5e8ca8e7296867eeba46e4521dc74dbe12 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 2 Jul 2020 10:02:00 -0400
Subject: [PATCH 25/49] Fixes #2455

---
 pytorch_lightning/callbacks/early_stopping.py | 20 ++++++++-----------
 1 file changed, 8 insertions(+), 12 deletions(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index e9937089b750b..2fc1e17eb7eea 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -141,6 +141,13 @@ def _run_early_stopping_check(self, trainer, pl_module):
         if not isinstance(current, torch.Tensor):
             current = torch.tensor(current)

+        # in ddp, reduce the stopping metric so every process conditions the same
+        if trainer.use_ddp or trainer.use_ddp2:
+            print(f'RANK: {trainer.global_rank}, BEFORE: {current}')
+            current = current.to(pl_module.device)
+            dist.all_reduce(current, op=dist.reduce_op.MAX)
+            print(f'RANK: {trainer.global_rank}, AFTER: {current}')
+
         if self.monitor_op(current - self.min_delta, self.best_score):
             self.best_score = current
             self.wait_count = 0
@@ -148,18 +155,7 @@ def _run_early_stopping_check(self, trainer, pl_module):
             self.wait_count += 1
             should_stop = self.wait_count >= self.patience

-            # check flag across all GPUs
-            should_stop = torch.tensor(int(should_stop), device=pl_module.device)
-            if trainer.use_ddp or trainer.use_ddp2:
-                print(f'RANK: {trainer.global_rank} REDUCING...')
-                dist.all_reduce(should_stop, op=dist.reduce_op.MAX)
-                print(f'RANK: {trainer.global_rank} REDUCED...')
-
-            print(f'RANK: {trainer.global_rank} SHOULD STOP: {should_stop} BEST: {self.best_score}')
-
-            # do actual stop
-            print(f'RANK: {trainer.global_rank}, SHOULD STOP: {should_stop}, EPOCH: {trainer.current_epoch}')
-            if bool(should_stop.item()):
+            if bool(should_stop):
                 print(f'RANK: {trainer.global_rank}, STOPPING...')
                 self.stopped_epoch = trainer.current_epoch
                 trainer.should_stop = True

From ffa65adf29ddfa85785ca740421d647738ee7cd5 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 2 Jul 2020 10:12:55 -0400
Subject: [PATCH 26/49] Fixes #2455

---
 pytorch_lightning/callbacks/early_stopping.py | 27 ++++++++++++++-----
 1 file changed, 21 insertions(+), 6 deletions(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 2fc1e17eb7eea..217f7ae5f29a8 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -17,6 +17,12 @@

 torch_inf = torch.tensor(np.Inf)

+try:
+    import torch_xla.core.xla_model as xm
+except ImportError:
+    XLA_AVAILABLE = False
+else:
+    XLA_AVAILABLE = True

 class EarlyStopping(Callback):
     r"""
@@ -141,12 +147,7 @@ def _run_early_stopping_check(self, trainer, pl_module):
         if not isinstance(current, torch.Tensor):
             current = torch.tensor(current)

-        # in ddp, reduce the stopping metric so every process conditions the same
-        if trainer.use_ddp or trainer.use_ddp2:
-            print(f'RANK: {trainer.global_rank}, BEFORE: {current}')
-            current = current.to(pl_module.device)
-            dist.all_reduce(current, op=dist.reduce_op.MAX)
-            print(f'RANK: {trainer.global_rank}, AFTER: {current}')
+        current = self._reduce_metric(trainer, pl_module, current)

         if self.monitor_op(current - self.min_delta, self.best_score):
             self.best_score = current
@@ -162,6 +163,20 @@ def _run_early_stopping_check(self, trainer, pl_module):

         dist.barrier()

+    def _reduce_metric(self, trainer, pl_module, metric):
+
+        # in ddp, reduce the stopping metric so every process conditions the same
+        if trainer.use_ddp or trainer.use_ddp2:
+            metric = metric.to(pl_module.device)
+            dist.all_reduce(metric, op=dist.reduce_op.AVG)
+
+        if trainer.use_tpu:
+            metric = metric.to(pl_module.device)
+            xm.all_reduce('sum', [metric])
+            metric = metric / trainer.world_size
+
+        return metric
+
     def on_train_end(self, trainer, pl_module):
         if self.stopped_epoch > 0 and self.verbose > 0:
             rank_zero_warn('Displayed epoch numbers by `EarlyStopping` start from "1" until v0.6.x,'

From 7d5af1cfd0eee9673eea3cab19936fb97e4b6741 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 2 Jul 2020 10:15:43 -0400
Subject: [PATCH 27/49] Fixes #2455

---
 pytorch_lightning/callbacks/early_stopping.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 217f7ae5f29a8..881f92527ca65 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -168,7 +168,8 @@ def _reduce_metric(self, trainer, pl_module, metric):
         # in ddp, reduce the stopping metric so every process conditions the same
         if trainer.use_ddp or trainer.use_ddp2:
             metric = metric.to(pl_module.device)
-            dist.all_reduce(metric, op=dist.reduce_op.AVG)
+            dist.all_reduce(metric, op=dist.reduce_op.SUM)
+            metric = metric / trainer.world_size

         if trainer.use_tpu:
             metric = metric.to(pl_module.device)
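Patches 25-27 try a different tack: instead of syncing the yes/no decision, they sync the monitored metric itself so every rank conditions on the same number. The dist.reduce_op enum of this PyTorch era has no AVG member, which is why patch 27 rewrites the AVG call as a SUM followed by a division. A short self-contained sketch of that averaging idiom (the function name is chosen here for illustration and assumes an initialized process group):

    import torch
    import torch.distributed as dist

    def mean_reduce_metric(metric: torch.Tensor, device: torch.device, world_size: int) -> torch.Tensor:
        # SUM across ranks, then divide by world size: the mean of the monitored value
        metric = metric.to(device)
        dist.all_reduce(metric, op=dist.ReduceOp.SUM)
        return metric / world_size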
From c907e3606fec12b0c32655695df20548e1e5e6e8 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 2 Jul 2020 10:21:05 -0400
Subject: [PATCH 28/49] added early stop tpu test

---
 pytorch_lightning/callbacks/early_stopping.py | 2 +-
 pytorch_lightning/trainer/training_loop.py | 3 ---
 tests/models/test_tpu.py | 21 +++++++++++++++++++
 3 files changed, 22 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 881f92527ca65..ab5f45d8e02b9 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -147,6 +147,7 @@ def _run_early_stopping_check(self, trainer, pl_module):
         if not isinstance(current, torch.Tensor):
             current = torch.tensor(current)

+        # make sure the metric is consistent across processes
         current = self._reduce_metric(trainer, pl_module, current)

         if self.monitor_op(current - self.min_delta, self.best_score):
@@ -157,7 +158,6 @@ def _run_early_stopping_check(self, trainer, pl_module):
             should_stop = self.wait_count >= self.patience

             if bool(should_stop):
-                print(f'RANK: {trainer.global_rank}, STOPPING...')
                 self.stopped_epoch = trainer.current_epoch
                 trainer.should_stop = True
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 8d1ff035b5e10..aa41087d29e8e 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -389,9 +389,6 @@ def train(self):
                     f' ({self.min_epochs}) or minimum steps ({self.min_steps}) has'
                     ' not been met. Training will continue...')

-            print('-' * 100)
-            print('SHUTTING DOWN', self.global_rank)
-            print('-' * 100)
             self.run_training_teardown()

         except KeyboardInterrupt:
diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py
index d4ab9c340ec5c..5fa60311a38c4 100644
--- a/tests/models/test_tpu.py
+++ b/tests/models/test_tpu.py
@@ -19,6 +19,27 @@
     TPU_AVAILABLE = True


+@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
+@pytest.mark.parametrize(['tpu_cores', 'expected_device'], [
+    pytest.param([1], 'xla:1'),
+    pytest.param([8], 'xla:8'),
+])
+def test_early_stop_checkpoints_on_tpu(tmpdir, tpu_cores, expected_device):
+    """Test if single TPU core training works"""
+    model = EvalModelTemplate()
+    trainer = Trainer(
+        early_stop_callback=True,
+        default_root_dir=tmpdir,
+        progress_bar_refresh_rate=0,
+        max_epochs=50,
+        limit_train_batches=10,
+        limit_val_batches=10,
+        tpu_cores=tpu_cores,
+    )
+    trainer.fit(model)
+    assert torch_xla._XLAC._xla_get_default_device() == expected_device
+
+
 @pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
 @pytest.mark.parametrize(['tpu_cores', 'expected_device'], [
     pytest.param([1], 'xla:1'),

From ce37587a3e1f7ea50328243878eb89e75a9db60f Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 2 Jul 2020 10:23:16 -0400
Subject: [PATCH 29/49] added early stop tpu test

---
 tests/models/test_gpu.py | 24 ++++++++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/tests/models/test_gpu.py b/tests/models/test_gpu.py
index 734478f26a7a7..8401f62070564 100644
--- a/tests/models/test_gpu.py
+++ b/tests/models/test_gpu.py
@@ -58,6 +58,30 @@ def test_multi_gpu_model(tmpdir, backend):
     memory.get_memory_profile('min_max')


+@pytest.mark.spawn
+@pytest.mark.parametrize("backend", ['dp', 'ddp', 'ddp2'])
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+def test_multi_gpu_early_stop(tmpdir, backend):
+    """Make sure DDP works. with early stopping"""
+    tutils.set_random_master_port()
+
+    trainer_options = dict(
+        default_root_dir=tmpdir,
+        early_stop_callback=True,
+        max_epochs=50,
+        limit_train_batches=10,
+        limit_val_batches=10,
+        gpus=[0, 1],
+        distributed_backend=backend,
+    )
+
+    model = EvalModelTemplate()
+    # tutils.run_model_test(trainer_options, model)
+    trainer = Trainer(**trainer_options)
+    result = trainer.fit(model)
+    assert result
+
+
 @pytest.mark.spawn
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 def test_ddp_all_dataloaders_passed_to_fit(tmpdir):
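Patches 28 and 29 add end-to-end regression coverage for the fix on TPU and multi-GPU respectively. Assuming a checkout of the repository on a machine with the required hardware, either test can be selected directly by its pytest node id (standard pytest selection syntax; the paths follow the diffs above):

    pytest tests/models/test_gpu.py::test_multi_gpu_early_stop -v
    pytest tests/models/test_tpu.py::test_early_stop_checkpoints_on_tpu -v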
From 6c77aef439f0c83ae1c6a8ab36db0d353b14a358 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 2 Jul 2020 10:58:14 -0400
Subject: [PATCH 30/49] added early stop tpu test

---
 pytorch_lightning/callbacks/early_stopping.py | 27 ++++++++++---------
 1 file changed, 14 insertions(+), 13 deletions(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index ab5f45d8e02b9..2ec44fc20ac93 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -147,9 +147,6 @@ def _run_early_stopping_check(self, trainer, pl_module):
         if not isinstance(current, torch.Tensor):
             current = torch.tensor(current)

-        # make sure the metric is consistent across processes
-        current = self._reduce_metric(trainer, pl_module, current)
-
         if self.monitor_op(current - self.min_delta, self.best_score):
             self.best_score = current
             self.wait_count = 0
@@ -161,22 +158,26 @@ def _run_early_stopping_check(self, trainer, pl_module):
                 self.stopped_epoch = trainer.current_epoch
                 trainer.should_stop = True

-        dist.barrier()
+        # stop every ddp process
+        if trainer.should_stop:
+            print(f'EARLY STOPPING. RANK {trainer.global_rank}, TRAINER: {trainer.should_stop}')
+            self._stop_distributed_training(trainer)
+            print(f'EARLY STOPPING. RANK {trainer.global_rank}, TRAINER: {trainer.should_stop}')

-    def _reduce_metric(self, trainer, pl_module, metric):
+    def _stop_distributed_training(self, trainer):

         # in ddp, reduce the stopping metric so every process conditions the same
         if trainer.use_ddp or trainer.use_ddp2:
-            metric = metric.to(pl_module.device)
-            dist.all_reduce(metric, op=dist.reduce_op.SUM)
-            metric = metric / trainer.world_size
+            stop = torch.tensor(1, device=trainer.device)
+            dist.all_reduce(stop, op=dist.reduce_op.MAX)
+            trainer.should_stop = stop.item()
+            dist.barrier()

-        if trainer.use_tpu:
-            metric = metric.to(pl_module.device)
-            xm.all_reduce('sum', [metric])
-            metric = metric / trainer.world_size
+        # if trainer.use_tpu:
+        #     metric = metric.to(pl_module.device)
+        #     xm.all_reduce('sum', [metric])
+        #     metric = metric / trainer.world_size

-        return metric

From 6cd4fdc407113de7edc39d32f0a340a8225638b8 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 2 Jul 2020 10:59:05 -0400
Subject: [PATCH 31/49] added early stop tpu test

---
 pytorch_lightning/callbacks/early_stopping.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 2ec44fc20ac93..b23bf0ba1e068 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -161,14 +161,14 @@ def _run_early_stopping_check(self, trainer, pl_module):
         # stop every ddp process
         if trainer.should_stop:
             print(f'EARLY STOPPING. RANK {trainer.global_rank}, TRAINER: {trainer.should_stop}')
-            self._stop_distributed_training(trainer)
+            self._stop_distributed_training(trainer, pl_module)
             print(f'EARLY STOPPING. RANK {trainer.global_rank}, TRAINER: {trainer.should_stop}')

-    def _stop_distributed_training(self, trainer):
+    def _stop_distributed_training(self, trainer, pl_module):

         # in ddp, reduce the stopping metric so every process conditions the same
         if trainer.use_ddp or trainer.use_ddp2:
-            stop = torch.tensor(1, device=trainer.device)
+            stop = torch.tensor(1, device=pl_module.device)
             dist.all_reduce(stop, op=dist.reduce_op.MAX)
             trainer.should_stop = stop.item()
             dist.barrier()

From 7fdc7ec1abc45a6c471770df956ac34e4c91e188 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 2 Jul 2020 10:59:57 -0400
Subject: [PATCH 32/49] added early stop tpu test

---
 pytorch_lightning/callbacks/early_stopping.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index b23bf0ba1e068..ec3910daa0011 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -159,10 +159,9 @@ def _run_early_stopping_check(self, trainer, pl_module):
                 trainer.should_stop = True

         # stop every ddp process
-        if trainer.should_stop:
-            print(f'EARLY STOPPING. RANK {trainer.global_rank}, TRAINER: {trainer.should_stop}')
-            self._stop_distributed_training(trainer, pl_module)
-            print(f'EARLY STOPPING. RANK {trainer.global_rank}, TRAINER: {trainer.should_stop}')
+        print(f'EARLY STOPPING. RANK {trainer.global_rank}, TRAINER: {trainer.should_stop}')
+        self._stop_distributed_training(trainer, pl_module)
+        print(f'EARLY STOPPING. RANK {trainer.global_rank}, TRAINER: {trainer.should_stop}')

     def _stop_distributed_training(self, trainer, pl_module):

From 7879fe2986ed555548d297748788f119fbfc46e5 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 2 Jul 2020 11:00:06 -0400
Subject: [PATCH 33/49] added early stop tpu test

---
 pytorch_lightning/callbacks/early_stopping.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index ec3910daa0011..b8d63e98fd286 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -167,7 +167,7 @@ def _stop_distributed_training(self, trainer, pl_module):
         # in ddp, reduce the stopping metric so every process conditions the same
         if trainer.use_ddp or trainer.use_ddp2:
-            stop = torch.tensor(1, device=pl_module.device)
+            stop = torch.tensor(trainer.should_stop, device=pl_module.device)
             dist.all_reduce(stop, op=dist.reduce_op.MAX)
             trainer.should_stop = stop.item()
             dist.barrier()

From 3d77c365bc4e53d7af50b3696388b88d8f29ed6a Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 2 Jul 2020 11:01:43 -0400
Subject: [PATCH 34/49] added early stop tpu test

---
 pytorch_lightning/callbacks/early_stopping.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index b8d63e98fd286..4e543befcd89e 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -167,7 +167,7 @@ def _stop_distributed_training(self, trainer, pl_module):
         # in ddp, reduce the stopping metric so every process conditions the same
         if trainer.use_ddp or trainer.use_ddp2:
-            stop = torch.tensor(trainer.should_stop, device=pl_module.device)
+            stop = torch.tensor(int(trainer.should_stop), device=pl_module.device)
             dist.all_reduce(stop, op=dist.reduce_op.MAX)
             trainer.should_stop = stop.item()
             dist.barrier()

From 2ff19ba0bd600d39b5110f483a53cda96e437100 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 2 Jul 2020 11:01:54 -0400
Subject: [PATCH 35/49] added early stop tpu test

---
 pytorch_lightning/callbacks/early_stopping.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 4e543befcd89e..7cd80bbad5b9c 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -169,7 +169,7 @@ def _stop_distributed_training(self, trainer, pl_module):
             stop = torch.tensor(int(trainer.should_stop), device=pl_module.device)
             dist.all_reduce(stop, op=dist.reduce_op.MAX)
-            trainer.should_stop = stop.item()
+            trainer.should_stop = int(stop.item())
             dist.barrier()
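Patches 30-35 move the synchronization out of the metric path and into a dedicated _stop_distributed_training step that runs after trainer.should_stop is set, and most of the churn is type plumbing: a Python bool must become an int tensor for the collective (patches 33-34), and the reduced tensor must come back as a plain value (.item(), patch 35). The round trip, as a self-contained sketch assuming an initialized process group (names illustrative):

    import torch
    import torch.distributed as dist

    def sync_should_stop(should_stop: bool, device: torch.device) -> bool:
        # bool -> int -> tensor: collectives reduce numeric tensors, not Python bools
        stop = torch.tensor(int(should_stop), device=device)
        dist.all_reduce(stop, op=dist.ReduceOp.MAX)
        dist.barrier()  # keep ranks together before any of them leaves the loop
        # .item() returns a Python number; bool() keeps a tensor out of trainer state,
        # so later `if trainer.should_stop:` checks do not rely on tensor truthiness
        return bool(stop.item())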
From 57b601bb72133f771af3fc89f42566d4f190b38f Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 2 Jul 2020 11:03:22 -0400
Subject: [PATCH 36/49] added early stop tpu test

---
 pytorch_lightning/callbacks/early_stopping.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 7cd80bbad5b9c..1ebfdaf5ab991 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -159,13 +159,11 @@ def _run_early_stopping_check(self, trainer, pl_module):
                 trainer.should_stop = True

         # stop every ddp process
-        print(f'EARLY STOPPING. RANK {trainer.global_rank}, TRAINER: {trainer.should_stop}')
         self._stop_distributed_training(trainer, pl_module)
-        print(f'EARLY STOPPING. RANK {trainer.global_rank}, TRAINER: {trainer.should_stop}')

     def _stop_distributed_training(self, trainer, pl_module):

-        # in ddp, reduce the stopping metric so every process conditions the same
+        # in ddp make sure all processes stop when one is flagged
         if trainer.use_ddp or trainer.use_ddp2:
             stop = torch.tensor(int(trainer.should_stop), device=pl_module.device)
             dist.all_reduce(stop, op=dist.reduce_op.MAX)

From 9dcc73e584446e9558e83563f1eeee3a20db482f Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 2 Jul 2020 11:03:57 -0400
Subject: [PATCH 37/49] added early stop tpu test

---
 pytorch_lightning/callbacks/early_stopping.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 1ebfdaf5ab991..fdbe7f85a5c01 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -167,7 +167,7 @@ def _stop_distributed_training(self, trainer, pl_module):
             stop = torch.tensor(int(trainer.should_stop), device=pl_module.device)
             dist.all_reduce(stop, op=dist.reduce_op.MAX)
-            trainer.should_stop = int(stop.item())
+            trainer.should_stop = stop
             dist.barrier()

From 82df22e589df824b76481cadd7593856ef752d18 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 2 Jul 2020 11:04:59 -0400
Subject: [PATCH 38/49] added early stop tpu test

---
 pytorch_lightning/callbacks/early_stopping.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index fdbe7f85a5c01..3560e6a8c72e1 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -170,11 +170,11 @@ def _stop_distributed_training(self, trainer, pl_module):
             trainer.should_stop = stop
             dist.barrier()

-        # if trainer.use_tpu:
-        #     metric = metric.to(pl_module.device)
-        #     xm.all_reduce('sum', [metric])
-        #     metric = metric / trainer.world_size
+        if trainer.use_tpu:
+            stop = torch.tensor(int(trainer.should_stop), device=pl_module.device)
+            xm.all_reduce('max', [stop])
+            trainer.should_stop = stop
+            dist.barrier()

     def on_train_end(self, trainer, pl_module):

From fafe7af80f1f2e6f912956a007c5c6e049a70635 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 2 Jul 2020 11:05:19 -0400
Subject: [PATCH 39/49] added early stop tpu test

---
 pytorch_lightning/callbacks/early_stopping.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 3560e6a8c72e1..84dc862034589 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -158,7 +158,7 @@ def _run_early_stopping_check(self, trainer, pl_module):
                 self.stopped_epoch = trainer.current_epoch
                 trainer.should_stop = True

-        # stop every ddp process
+        # stop every ddp process if any world process decides to stop
         self._stop_distributed_training(trainer, pl_module)

From 5af7f69eab061948c6d1604588c5c63413c90559 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 2 Jul 2020 11:10:13 -0400
Subject: [PATCH 40/49] added early stop tpu test

---
 pytorch_lightning/callbacks/early_stopping.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 84dc862034589..8604a57bd2276 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -166,8 +166,8 @@ def _stop_distributed_training(self, trainer, pl_module):
         # in ddp make sure all processes stop when one is flagged
         if trainer.use_ddp or trainer.use_ddp2:
             stop = torch.tensor(int(trainer.should_stop), device=pl_module.device)
-            dist.all_reduce(stop, op=dist.reduce_op.MAX)
-            trainer.should_stop = stop
+            dist.all_reduce(stop, op=dist.reduce_op.SUM)
+            trainer.should_stop = stop == trainer.world_size
             dist.barrier()

         if trainer.use_tpu:

From fef08e2ab86df18195204f4a93d2eb5b289b23cb Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 2 Jul 2020 11:13:32 -0400
Subject: [PATCH 41/49] added early stop tpu test

---
 pytorch_lightning/callbacks/early_stopping.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 8604a57bd2276..da05ce08e20c2 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -172,8 +172,8 @@ def _stop_distributed_training(self, trainer, pl_module):
         if trainer.use_tpu:
             stop = torch.tensor(int(trainer.should_stop), device=pl_module.device)
-            xm.all_reduce('max', [stop])
-            trainer.should_stop = stop
+            xm.all_reduce('sum', [stop])
+            trainer.should_stop = stop == trainer.world_size
             dist.barrier()

From b4bbe1c69c3983a96d540e4643d3aa45a572232c Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 2 Jul 2020 11:15:11 -0400
Subject: [PATCH 42/49] added early stop tpu test

---
 pytorch_lightning/callbacks/early_stopping.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index da05ce08e20c2..69fcb57da216c 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -167,14 +167,14 @@ def _stop_distributed_training(self, trainer, pl_module):
         if trainer.use_ddp or trainer.use_ddp2:
             stop = torch.tensor(int(trainer.should_stop), device=pl_module.device)
             dist.all_reduce(stop, op=dist.reduce_op.SUM)
-            trainer.should_stop = stop == trainer.world_size
             dist.barrier()
+            trainer.should_stop = stop == trainer.world_size

         if trainer.use_tpu:
             stop = torch.tensor(int(trainer.should_stop), device=pl_module.device)
             xm.all_reduce('sum', [stop])
-            trainer.should_stop = stop == trainer.world_size
             dist.barrier()
+            trainer.should_stop = stop == trainer.world_size

From 7f711a4efbc2a40a999201a0e72b23269b508737 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 2 Jul 2020 19:11:45 -0400
Subject: [PATCH 43/49] added early stop tpu test

---
 pytorch_lightning/callbacks/early_stopping.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 69fcb57da216c..1e6624e28f6eb 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -145,7 +145,7 @@ def _run_early_stopping_check(self, trainer, pl_module):
         current = logs.get(self.monitor)

         if not isinstance(current, torch.Tensor):
-            current = torch.tensor(current)
+            current = torch.tensor(current, device=pl_module.device)

         if self.monitor_op(current - self.min_delta, self.best_score):

From 51d4740544c7b5c32fd878c25e3c6ba6d95a04e1 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 2 Jul 2020 22:00:06 -0400
Subject: [PATCH 44/49] added early stop tpu test

---
 pytorch_lightning/callbacks/early_stopping.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 1e6624e28f6eb..534521af8e2cc 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -147,7 +147,7 @@ def _run_early_stopping_check(self, trainer, pl_module):
         if not isinstance(current, torch.Tensor):
             current = torch.tensor(current, device=pl_module.device)

-        if self.monitor_op(current - self.min_delta, self.best_score):
+        if self.monitor_op(current - self.min_delta, self.best_score.to(pl_module.device)):
             self.best_score = current
             self.wait_count = 0

From 58b66bc00c9312fd591d8889776c1522df73c46a Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 2 Jul 2020 22:18:16 -0400
Subject: [PATCH 45/49] added early stop tpu test

---
 pytorch_lightning/callbacks/early_stopping.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 534521af8e2cc..1f6b1ce2ff6e4 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -18,12 +18,14 @@
 torch_inf = torch.tensor(np.Inf)

 try:
+    import torch_xla
     import torch_xla.core.xla_model as xm
 except ImportError:
     XLA_AVAILABLE = False
 else:
     XLA_AVAILABLE = True

+
 class EarlyStopping(Callback):
     r"""
@@ -173,7 +175,7 @@ def _stop_distributed_training(self, trainer, pl_module):
         if trainer.use_tpu:
             stop = torch.tensor(int(trainer.should_stop), device=pl_module.device)
             xm.all_reduce('sum', [stop])
-            dist.barrier()
+            torch_xla.core.xla_model.rendezvous("pl.EarlyStoppingCallback.stop_distributed_training_check")
             trainer.should_stop = stop == trainer.world_size

From 50b587437483fd181eda5b0e21180be88c63b619 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 2 Jul 2020 22:40:35 -0400
Subject: [PATCH 46/49] added early stop tpu test

---
 pytorch_lightning/callbacks/early_stopping.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 1f6b1ce2ff6e4..57a479aab45e7 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -176,7 +176,9 @@ def _stop_distributed_training(self, trainer, pl_module):
             stop = torch.tensor(int(trainer.should_stop), device=pl_module.device)
             xm.all_reduce('sum', [stop])
             torch_xla.core.xla_model.rendezvous("pl.EarlyStoppingCallback.stop_distributed_training_check")
+            print(stop)
             trainer.should_stop = stop == trainer.world_size
+            trainer.should_stop = bool(trainer.should_stop.item())

From c75e71b7f1c3132fd3dcbfbfd23fcfe4fbc7b3be Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 2 Jul 2020 23:08:39 -0400
Subject: [PATCH 47/49] added early stop tpu test

---
 pytorch_lightning/callbacks/early_stopping.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 57a479aab45e7..545b58b1c165b 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -176,9 +176,7 @@ def _stop_distributed_training(self, trainer, pl_module):
             stop = torch.tensor(int(trainer.should_stop), device=pl_module.device)
             xm.all_reduce('sum', [stop])
             torch_xla.core.xla_model.rendezvous("pl.EarlyStoppingCallback.stop_distributed_training_check")
-            print(stop)
-            trainer.should_stop = stop == trainer.world_size
-            trainer.should_stop = bool(trainer.should_stop.item())
+            trainer.should_stop = stop.item() == trainer.world_size

From 71ab1f6175910554a5ccd2c41d701f78c119d385 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 2 Jul 2020 23:08:51 -0400
Subject: [PATCH 48/49] added early stop tpu test

---
 pytorch_lightning/callbacks/early_stopping.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 545b58b1c165b..18a0da8f9dbae 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -175,6 +175,7 @@ def _stop_distributed_training(self, trainer, pl_module):
         if trainer.use_tpu:
             stop = torch.tensor(int(trainer.should_stop), device=pl_module.device)
             xm.all_reduce('sum', [stop])
+            print(type(stop))
             torch_xla.core.xla_model.rendezvous("pl.EarlyStoppingCallback.stop_distributed_training_check")
             trainer.should_stop = stop.item() == trainer.world_size

From 43fa46307fd8e8d447561d5a64511c1f3d69ee26 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Fri, 3 Jul 2020 00:15:14 -0400
Subject: [PATCH 49/49] added early stop tpu test

---
 pytorch_lightning/callbacks/early_stopping.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 18a0da8f9dbae..d9a396f7176fc 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -172,12 +172,12 @@ def _stop_distributed_training(self, trainer, pl_module):
             dist.barrier()
             trainer.should_stop = stop == trainer.world_size

-        if trainer.use_tpu:
-            stop = torch.tensor(int(trainer.should_stop), device=pl_module.device)
-            xm.all_reduce('sum', [stop])
-            print(type(stop))
-            torch_xla.core.xla_model.rendezvous("pl.EarlyStoppingCallback.stop_distributed_training_check")
-            trainer.should_stop = stop.item() == trainer.world_size
+        # if trainer.use_tpu:
+        #     stop = torch.tensor(int(trainer.should_stop), device=pl_module.device)
+        #     xm.all_reduce('sum', [stop])
+        #     print(type(stop))
+        #     torch_xla.core.xla_model.rendezvous("pl.EarlyStoppingCallback.stop_distributed_training_check")
+        #     trainer.should_stop = stop.item() == trainer.world_size

     def on_train_end(self, trainer, pl_module):
         if self.stopped_epoch > 0 and self.verbose > 0:
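Where the series lands: the monitored value and best_score are compared on the module's device (patches 43-44), the TPU branch is parked behind comments (patch 49), and the DDP branch uses a SUM reduction compared against the world size. A condensed sketch of that final DDP logic, summarizing the patches rather than quoting the file verbatim (the patches spell the op as dist.reduce_op.SUM, a then-current deprecated alias of dist.ReduceOp.SUM used below):

    import torch
    import torch.distributed as dist

    def stop_distributed_training(trainer, pl_module):
        # in ddp make sure all processes stop when one is flagged
        if trainer.use_ddp or trainer.use_ddp2:
            stop = torch.tensor(int(trainer.should_stop), device=pl_module.device)
            dist.all_reduce(stop, op=dist.ReduceOp.SUM)
            dist.barrier()
            # as in patch 49, the comparison result is stored directly:
            # SUM == world_size means every rank has independently flagged should_stop
            trainer.should_stop = stop == trainer.world_size

Note the semantic shift from the earlier MAX reduction: MAX stops all processes as soon as any single rank votes to stop, while SUM compared to world_size stops only when all ranks agree.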