Commit

[bug] All_gather support tensor on cpu (#6416)
* add test

* update changelog

* update

* rename function
tchaton authored Mar 10, 2021
1 parent c81b2a8 commit 7d4e74c
Showing 3 changed files with 13 additions and 0 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -122,6 +122,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed logger creating directory structure too early in DDP ([#6380](https://github.com/PyTorchLightning/pytorch-lightning/pull/6380))


- Fixed LightningModule `all_gather` on cpu tensors ([#6416](https://github.com/PyTorchLightning/pytorch-lightning/pull/6416))


## [1.2.3] - 2021-03-09

### Fixed
6 changes: 6 additions & 0 deletions pytorch_lightning/utilities/apply_func.py
@@ -164,6 +164,12 @@ def batch_to(data):
def convert_to_tensors(data, device: torch.device = None):
    if device is None:
        raise MisconfigurationException("device (torch.device) should be provided.")

    for src_dtype, conversion_func in CONVERSION_DTYPES:
        data = apply_to_collection(data, src_dtype, partial(conversion_func, device=device))

    def _move_to_device_and_make_contiguous(t: torch.Tensor, device: torch.device):
        return t.to(device).contiguous()

    data = apply_to_collection(data, torch.Tensor, partial(_move_to_device_and_make_contiguous, device=device))
    return data
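
For context, here is a minimal usage sketch of the fixed helper (illustrative only, not part of the diff; the inputs and device are made up). convert_to_tensors first maps Python bools, ints, floats, and numpy arrays to tensors via CONVERSION_DTYPES, then moves every tensor in the collection to the target device and makes it contiguous, which is what lets CPU inputs survive the later all_gather:

    import torch
    from pytorch_lightning.utilities.apply_func import convert_to_tensors

    # A mixed, CPU-resident collection: bools, a Python int, and a
    # non-contiguous tensor (the transpose of a 2x2 matrix).
    data = {"flags": [True, False], "step": 3, "loss": torch.rand(2, 2).t()}
    out = convert_to_tensors(data, device=torch.device("cpu"))

    assert out["loss"].is_contiguous()           # made contiguous by the new helper
    assert out["flags"][0].dtype == torch.uint8  # bools become uint8, not torch.bool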
4 changes: 4 additions & 0 deletions tests/utilities/test_all_gather_grad.py
@@ -58,13 +58,17 @@ def training_epoch_end(self, outputs) -> None:
    self.training_epoch_end_called = True
    losses = torch.stack([x["loss"] for x in outputs])
    gathered_loss = self.all_gather({
        "losses_tensor_int": torch.rand(2, 2).int().t(),
        "losses_tensor_float": torch.rand(2, 2).t(),
        "losses_np_ndarray": np.array([1, 2, 3]),
        "losses_bool": [True, False],
        "losses_float": [0., 1., 2.],
        "losses_int": [0, 1, 2],
        "losses": losses,
        "losses_list": [losses, losses]
    })
    assert gathered_loss["losses_tensor_int"][0].dtype == torch.int32
    assert gathered_loss["losses_tensor_float"][0].dtype == torch.float
    assert gathered_loss["losses_np_ndarray"][0].dtype == torch.int64
    # torch.bool can't be all_gathered
    assert gathered_loss["losses_bool"][0].dtype == torch.uint8
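
As a rough sketch of what the test exercises at the user level (hypothetical hook body; a DDP world size of 2 is assumed): with this fix, LightningModule.all_gather accepts nested collections whose leaves live on the CPU, and each leaf comes back with a leading world-size dimension:

    # Inside any LightningModule hook, running under DDP with 2 processes:
    gathered = self.all_gather({"acc": 0.91, "flags": [True, False]})
    # gathered["acc"] is a float tensor of shape (2,), one entry per process;
    # gathered["flags"] holds uint8 tensors, since torch.bool can't be all_gathered.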
