From a6b26a93b1404222694ee29880453630fe85c27c Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 8 Mar 2021 19:27:04 +0000 Subject: [PATCH 1/4] add test --- pytorch_lightning/utilities/apply_func.py | 7 +++++++ tests/utilities/test_all_gather_grad.py | 4 ++++ 2 files changed, 11 insertions(+) diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index 9fd42008b9d8d4..026b6604593af4 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -166,4 +166,11 @@ def convert_to_tensors(data, device: torch.device = None): raise MisconfigurationException("device (torch.device) should be provided.") for src_dtype, conversion_func in CONVERSION_DTYPES: data = apply_to_collection(data, src_dtype, partial(conversion_func, device=device)) + + def _move_to_device(value, device=None): + if device is not None and value.device != device: + value = value.to(device) + return value.contiguous() + + data = apply_to_collection(data, torch.Tensor, partial(_move_to_device, device=device)) return data diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py index ae3addb4d6a667..55e746fb20013a 100644 --- a/tests/utilities/test_all_gather_grad.py +++ b/tests/utilities/test_all_gather_grad.py @@ -58,6 +58,8 @@ def training_epoch_end(self, outputs) -> None: self.training_epoch_end_called = True losses = torch.stack([x["loss"] for x in outputs]) gathered_loss = self.all_gather({ + "losses_tensor_int": torch.tensor([1, 2, 3]), + "losses_tensor_float": torch.tensor([1., 2., 3.]), "losses_np_ndarray": np.array([1, 2, 3]), "losses_bool": [True, False], "losses_float": [0., 1., 2.], @@ -65,6 +67,8 @@ def training_epoch_end(self, outputs) -> None: "losses": losses, "losses_list": [losses, losses] }) + assert gathered_loss["losses_tensor_int"][0].dtype == torch.int64 + assert gathered_loss["losses_tensor_float"][0].dtype == torch.float assert gathered_loss["losses_np_ndarray"][0].dtype == torch.int64 # torch.bool can't be all_gathered assert gathered_loss["losses_bool"][0].dtype == torch.uint8 From 71dcb569ce74ee84b6f11f2b1c396c39c3af466f Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 8 Mar 2021 19:30:16 +0000 Subject: [PATCH 2/4] update changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index f78569c1b7a0b4..ea131ca363305e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -131,6 +131,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed `Trainer` not resetting `lightning_optimizers` when calling `Trainer.fit()` multiple times ([#6372](https://github.com/PyTorchLightning/pytorch-lightning/pull/6372)) +- Fixed LightningModule `all_gather` on cpu tensors ([#6416](https://github.com/PyTorchLightning/pytorch-lightning/pull/6416)) + + ## [1.2.2] - 2021-03-02 ### Added From fd9cf93f107d2fd8e2264613cbe2417d954dc4f7 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 9 Mar 2021 11:23:35 +0000 Subject: [PATCH 3/4] update --- pytorch_lightning/utilities/apply_func.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index 026b6604593af4..4f5d525e6baa63 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -164,13 +164,14 @@ def batch_to(data): def convert_to_tensors(data, device: torch.device = None): if device is None: raise MisconfigurationException("device (torch.device) should be provided.") + for src_dtype, conversion_func in CONVERSION_DTYPES: data = apply_to_collection(data, src_dtype, partial(conversion_func, device=device)) - def _move_to_device(value, device=None): - if device is not None and value.device != device: - value = value.to(device) - return value.contiguous() + def _move_to_device_and_contiguous(t: torch.Tensor, device: torch.device): + if t.device != device: + t = t.to(device) + return t.contiguous() - data = apply_to_collection(data, torch.Tensor, partial(_move_to_device, device=device)) + data = apply_to_collection(data, torch.Tensor, partial(_move_to_device_and_contiguous, device=device)) return data From eea26a0b1354c6f6779dc4c086d387abd147852c Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 9 Mar 2021 11:24:07 +0000 Subject: [PATCH 4/4] rename function --- pytorch_lightning/utilities/apply_func.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index 4f5d525e6baa63..ddecb24a13f237 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -168,10 +168,10 @@ def convert_to_tensors(data, device: torch.device = None): for src_dtype, conversion_func in CONVERSION_DTYPES: data = apply_to_collection(data, src_dtype, partial(conversion_func, device=device)) - def _move_to_device_and_contiguous(t: torch.Tensor, device: torch.device): + def _move_to_device_and_make_contiguous(t: torch.Tensor, device: torch.device): if t.device != device: t = t.to(device) return t.contiguous() - data = apply_to_collection(data, torch.Tensor, partial(_move_to_device_and_contiguous, device=device)) + data = apply_to_collection(data, torch.Tensor, partial(_move_to_device_and_make_contiguous, device=device)) return data