Commit

Merge 590e0df into 55dd3a4
tchaton authored Mar 9, 2021
2 parents 55dd3a4 + 590e0df commit f819212
Showing 3 changed files with 15 additions and 0 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -131,6 +131,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `Trainer` not resetting `lightning_optimizers` when calling `Trainer.fit()` multiple times ([#6372](https://github.com/PyTorchLightning/pytorch-lightning/pull/6372))


- Fixed LightningModule `all_gather` on cpu tensors ([#6416](https://github.com/PyTorchLightning/pytorch-lightning/pull/6416))


- Fixed an issue where the tuner would not tune the learning rate if also tuning the batch size ([#4688](https://github.com/PyTorchLightning/pytorch-lightning/pull/4688))


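To make the first changelog entry above concrete: after this fix, `LightningModule.all_gather` also handles tensors that live on the CPU, as well as collections containing numpy arrays, Python numbers, and bools. A minimal sketch; the `LitModel` class and its values are illustrative only, not part of this commit, and in a distributed run each returned entry gains a leading world-size dimension:

```python
import torch
from pytorch_lightning import LightningModule

class LitModel(LightningModule):
    def training_epoch_end(self, outputs) -> None:
        # CPU tensors and plain Python values can now be gathered too;
        # non-tensor leaves are converted to tensors before the collective.
        gathered = self.all_gather({
            "cpu_loss": torch.tensor([1.0, 2.0, 3.0]),  # a CPU tensor no longer errors
            "flags": [True, False],                      # gathered as uint8, see the test below
        })
```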
8 changes: 8 additions & 0 deletions pytorch_lightning/utilities/apply_func.py
@@ -164,6 +164,14 @@ def batch_to(data):
def convert_to_tensors(data, device: torch.device = None):
    if device is None:
        raise MisconfigurationException("device (torch.device) should be provided.")

    for src_dtype, conversion_func in CONVERSION_DTYPES:
        data = apply_to_collection(data, src_dtype, partial(conversion_func, device=device))

    def _move_to_device_and_make_contiguous(t: torch.Tensor, device: torch.device):
        if t.device != device:
            t = t.to(device)
        return t.contiguous()

    data = apply_to_collection(data, torch.Tensor, partial(_move_to_device_and_make_contiguous, device=device))
    return data
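A hedged usage sketch for `convert_to_tensors` as changed above: non-tensor leaves (numpy arrays, Python numbers, bools) are first mapped to tensors via `CONVERSION_DTYPES`, and the new inner helper then moves every tensor to the target device and makes it contiguous. The payload below is illustrative only:

```python
import numpy as np
import torch
from pytorch_lightning.utilities.apply_func import convert_to_tensors

payload = {
    "loss": 0.5,                 # Python float -> tensor
    "ids": np.array([1, 2, 3]),  # numpy array -> int64 tensor
    "flag": True,                # bool -> uint8 tensor
}
out = convert_to_tensors(payload, device=torch.device("cpu"))
assert all(isinstance(v, torch.Tensor) and v.is_contiguous() for v in out.values())
```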
4 changes: 4 additions & 0 deletions tests/utilities/test_all_gather_grad.py
@@ -58,13 +58,17 @@ def training_epoch_end(self, outputs) -> None:
    self.training_epoch_end_called = True
    losses = torch.stack([x["loss"] for x in outputs])
    gathered_loss = self.all_gather({
        "losses_tensor_int": torch.tensor([1, 2, 3]),
        "losses_tensor_float": torch.tensor([1., 2., 3.]),
        "losses_np_ndarray": np.array([1, 2, 3]),
        "losses_bool": [True, False],
        "losses_float": [0., 1., 2.],
        "losses_int": [0, 1, 2],
        "losses": losses,
        "losses_list": [losses, losses]
    })
    assert gathered_loss["losses_tensor_int"][0].dtype == torch.int64
    assert gathered_loss["losses_tensor_float"][0].dtype == torch.float
    assert gathered_loss["losses_np_ndarray"][0].dtype == torch.int64
    # torch.bool can't be all_gathered
    assert gathered_loss["losses_bool"][0].dtype == torch.uint8
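The final assertion depends on the conversion table mapping Python bools to `torch.uint8`, since, as the in-test comment notes, `torch.bool` can't be all_gathered. A standalone sketch of that conversion; `to_gatherable` is a hypothetical name, not part of the codebase:

```python
import torch

def to_gatherable(flags):
    # Hypothetical helper mirroring the conversion: bools become uint8
    # before the collective, because torch.bool can't be all_gathered.
    return torch.tensor(flags, dtype=torch.uint8)

gathered = to_gatherable([True, False])
assert gathered.dtype == torch.uint8
assert gathered.tolist() == [1, 0]
```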
