diff --git a/CHANGELOG.md b/CHANGELOG.md
index d1c347c00a3f1f..20e4666842d19c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -131,6 +131,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed `Trainer` not resetting `lightning_optimizers` when calling `Trainer.fit()` multiple times ([#6372](https://github.com/PyTorchLightning/pytorch-lightning/pull/6372))
 
+- Fixed LightningModule `all_gather` on cpu tensors ([#6416](https://github.com/PyTorchLightning/pytorch-lightning/pull/6416))
+
+
 - Fixed an issue where the tuner would not tune the learning rate if also tuning the batch size ([#4688](https://github.com/PyTorchLightning/pytorch-lightning/pull/4688))
 
 
diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py
index 9fd42008b9d8d4..ddecb24a13f237 100644
--- a/pytorch_lightning/utilities/apply_func.py
+++ b/pytorch_lightning/utilities/apply_func.py
@@ -164,6 +164,14 @@ def batch_to(data):
 def convert_to_tensors(data, device: torch.device = None):
     if device is None:
         raise MisconfigurationException("device (torch.device) should be provided.")
+
     for src_dtype, conversion_func in CONVERSION_DTYPES:
         data = apply_to_collection(data, src_dtype, partial(conversion_func, device=device))
+
+    def _move_to_device_and_make_contiguous(t: torch.Tensor, device: torch.device):
+        if t.device != device:
+            t = t.to(device)
+        return t.contiguous()
+
+    data = apply_to_collection(data, torch.Tensor, partial(_move_to_device_and_make_contiguous, device=device))
     return data
 
diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py
index ae3addb4d6a667..55e746fb20013a 100644
--- a/tests/utilities/test_all_gather_grad.py
+++ b/tests/utilities/test_all_gather_grad.py
@@ -58,6 +58,8 @@ def training_epoch_end(self, outputs) -> None:
         self.training_epoch_end_called = True
         losses = torch.stack([x["loss"] for x in outputs])
         gathered_loss = self.all_gather({
+            "losses_tensor_int": torch.tensor([1, 2, 3]),
+            "losses_tensor_float": torch.tensor([1., 2., 3.]),
             "losses_np_ndarray": np.array([1, 2, 3]),
             "losses_bool": [True, False],
             "losses_float": [0., 1., 2.],
@@ -65,6 +67,8 @@ def training_epoch_end(self, outputs) -> None:
             "losses": losses,
             "losses_list": [losses, losses]
         })
+        assert gathered_loss["losses_tensor_int"][0].dtype == torch.int64
+        assert gathered_loss["losses_tensor_float"][0].dtype == torch.float
         assert gathered_loss["losses_np_ndarray"][0].dtype == torch.int64
         # torch.bool can't be all_gathered
         assert gathered_loss["losses_bool"][0].dtype == torch.uint8
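
Note (not part of the patch): a minimal sketch of the behavior the `apply_func.py` change is meant to guarantee, calling the `convert_to_tensors` helper shown in the diff above directly. The dictionary contents here are illustrative assumptions, not values from the PR:

import numpy as np
import torch

from pytorch_lightning.utilities.apply_func import convert_to_tensors

# A collection mixing a NumPy array (handled by the CONVERSION_DTYPES loop)
# with a non-contiguous CPU tensor (a transposed view), which before this
# patch was passed through untouched.
data = {
    "np_ints": np.array([1, 2, 3]),
    "tensor": torch.ones(2, 3).t(),
}

out = convert_to_tensors(data, device=torch.device("cpu"))

assert isinstance(out["np_ints"], torch.Tensor)     # ndarray became a tensor
assert out["tensor"].is_contiguous()                # view was made contiguous
assert out["tensor"].device == torch.device("cpu")  # tensor is on the target device

The `.contiguous()` call matters because `torch.distributed.all_gather` expects dense input buffers; gathering a transposed or sliced view can error or return incorrect data, which is why the new `_move_to_device_and_make_contiguous` pass runs over every tensor leaf before gathering.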