Skip to content

Commit

Permalink
resolve flake8
Browse files Browse the repository at this point in the history
  • Loading branch information
tchaton committed Jan 9, 2021
1 parent 46126a2 commit 0df1a98
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
7 changes: 5 additions & 2 deletions pytorch_lightning/utilities/apply_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,10 @@ def from_numpy(value, device: torch.device = None):
)
return torch.from_numpy(value).to(device)


CONVERSION_DTYPES = [
# bool -> int as torch.bool: RuntimeError: Unsupported data type for NCCL process group
(bool, partial(to_dtype_tensor, dtype=torch.int)),
# bool -> uint8 as bool -> torch.bool triggers RuntimeError: Unsupported data type for NCCL process group
(bool, partial(to_dtype_tensor, dtype=torch.uint8)),
(int, partial(to_dtype_tensor, dtype=torch.int)),
(float, partial(to_dtype_tensor, dtype=torch.float)),
(np.ndarray, from_numpy),
Expand Down Expand Up @@ -147,6 +148,8 @@ def batch_to(data):

dtype = (TransferableDataType, Batch) if _TORCHTEXT_AVAILABLE else TransferableDataType
return apply_to_collection(batch, dtype=dtype, function=batch_to)


def convert_to_tensors(data, device: torch.device = None):
if device is None:
raise MisconfigurationException(
Expand Down
6 changes: 3 additions & 3 deletions tests/utilities/test_all_gather_grad.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import os
import sys

import numpy as np
import pytest
import torch
import numpy as np

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.utilities import AllGatherGrad
from tests.base.boring_model import BoringModel

Expand Down Expand Up @@ -70,7 +70,7 @@ def training_epoch_end(self, outputs) -> None:
})
assert gathered_loss["losses_np_ndarray"][0].dtype == torch.int64
# torch.bool can't be all_gathered
assert gathered_loss["losses_bool"][0].dtype == torch.int32
assert gathered_loss["losses_bool"][0].dtype == torch.uint8
assert gathered_loss["losses_float"][0].dtype == torch.float
assert gathered_loss["losses_int"][0].dtype == torch.int
assert gathered_loss["losses_list"][0].numel() == 2 * len(losses)
Expand Down

0 comments on commit 0df1a98

Please sign in to comment.