From 34f5f07b5cbed147064a25b0541e4dc38c831060 Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Fri, 18 Dec 2020 19:58:16 +0900 Subject: [PATCH 01/75] Adding types to datamodules --- .../datamodules/binary_mnist_datamodule.py | 4 +- pl_bolts/datamodules/cifar10_datamodule.py | 6 +-- pl_bolts/datamodules/cityscapes_datamodule.py | 16 ++++---- .../datamodules/fashion_mnist_datamodule.py | 4 +- pl_bolts/datamodules/imagenet_datamodule.py | 18 ++++----- pl_bolts/datamodules/kitti_datamodule.py | 14 +++---- pl_bolts/datamodules/mnist_datamodule.py | 4 +- pl_bolts/datamodules/sklearn_datamodule.py | 38 +++++++++++-------- .../datamodules/ssl_imagenet_datamodule.py | 22 +++++------ pl_bolts/datamodules/stl10_datamodule.py | 22 +++++------ pl_bolts/datamodules/vision_datamodule.py | 16 ++++++-- 11 files changed, 90 insertions(+), 74 deletions(-) diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py index 142b3d54ef..c713abe107 100644 --- a/pl_bolts/datamodules/binary_mnist_datamodule.py +++ b/pl_bolts/datamodules/binary_mnist_datamodule.py @@ -56,7 +56,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ) -> None: + ): """ Args: data_dir: Where to save/load the data @@ -98,7 +98,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self): + def default_transforms(self) -> transform_lib.Compose: if self.normalize: mnist_transforms = transform_lib.Compose( [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))] diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index afb2df8c9a..12aea1ec87 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -112,7 +112,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self): + def default_transforms(self) -> transform_lib.Compose: if self.normalize: cf10_transforms = transform_lib.Compose([transform_lib.ToTensor(), cifar10_normalization()]) else: @@ -146,14 +146,14 @@ class TinyCIFAR10DataModule(CIFAR10DataModule): def __init__( self, - data_dir: str, + data_dir: Optional[str], val_split: int = 50, num_workers: int = 16, num_samples: int = 100, labels: Optional[Sequence] = (1, 5, 8), *args: Any, **kwargs: Any, - ) -> None: + ): """ Args: data_dir: where to save/load the data diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py index b851617225..d27bfc3196 100644 --- a/pl_bolts/datamodules/cityscapes_datamodule.py +++ b/pl_bolts/datamodules/cityscapes_datamodule.py @@ -69,8 +69,8 @@ def __init__( shuffle: bool = False, pin_memory: bool = False, drop_last: bool = False, - *args, - **kwargs, + *args: Any, + **kwargs: Any, ): """ Args: @@ -109,14 +109,14 @@ def __init__( self.target_transforms = None @property - def num_classes(self): + def num_classes(self) -> int: """ Return: 30 """ return 30 - def train_dataloader(self): + def train_dataloader(self) -> DataLoader: """ Cityscapes train set """ @@ -141,7 +141,7 @@ def train_dataloader(self): ) return loader - def val_dataloader(self): + def val_dataloader(self) -> DataLoader: """ Cityscapes val set """ @@ -166,7 +166,7 @@ def val_dataloader(self): ) return loader - def test_dataloader(self): + def test_dataloader(self) -> DataLoader: """ Cityscapes test set """ @@ -190,7 +190,7 @@ def test_dataloader(self): ) return loader - def _default_transforms(self): + def _default_transforms(self) -> transform_lib.Compose: cityscapes_transforms = transform_lib.Compose([ transform_lib.ToTensor(), transform_lib.Normalize( @@ -200,7 +200,7 @@ def _default_transforms(self): ]) return cityscapes_transforms - def _default_target_transforms(self): + def _default_target_transforms(self) -> transform_lib.Compose: cityscapes_target_trasnforms = transform_lib.Compose([ transform_lib.ToTensor(), transform_lib.Lambda(lambda t: t.squeeze()) diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py index 833c4599a6..9a73e6a637 100644 --- a/pl_bolts/datamodules/fashion_mnist_datamodule.py +++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py @@ -57,7 +57,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ) -> None: + ): """ Args: data_dir: Where to save/load the data @@ -93,7 +93,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self): + def default_transforms(self) -> transform_lib.Compose: if self.normalize: mnist_transforms = transform_lib.Compose( [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))] diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py index 38546e29ee..a31b637ba9 100644 --- a/pl_bolts/datamodules/imagenet_datamodule.py +++ b/pl_bolts/datamodules/imagenet_datamodule.py @@ -58,8 +58,8 @@ def __init__( shuffle: bool = False, pin_memory: bool = False, drop_last: bool = False, - *args, - **kwargs, + *args: Any, + **kwargs: Any, ): """ Args: @@ -94,7 +94,7 @@ def __init__( self.num_samples = 1281167 - self.num_imgs_per_val_class * self.num_classes @property - def num_classes(self): + def num_classes(self) -> int: """ Return: @@ -103,7 +103,7 @@ def num_classes(self): """ return 1000 - def _verify_splits(self, data_dir, split): + def _verify_splits(self, data_dir: str, split: str): dirs = os.listdir(data_dir) if split not in dirs: @@ -138,7 +138,7 @@ def prepare_data(self): UnlabeledImagenet.generate_meta_bins(path) """) - def train_dataloader(self): + def train_dataloader(self) -> DataLoader: """ Uses the train split of imagenet2012 and puts away a portion of it for the validation split """ @@ -160,7 +160,7 @@ def train_dataloader(self): ) return loader - def val_dataloader(self): + def val_dataloader(self) -> DataLoader: """ Uses the part of the train split of imagenet2012 that was not used for training via `num_imgs_per_val_class` @@ -185,7 +185,7 @@ def val_dataloader(self): ) return loader - def test_dataloader(self): + def test_dataloader(self) -> DataLoader: """ Uses the validation split of imagenet2012 for testing """ @@ -206,7 +206,7 @@ def test_dataloader(self): ) return loader - def train_transform(self): + def train_transform(self) -> transform_lib.Compose: """ The standard imagenet transforms @@ -232,7 +232,7 @@ def train_transform(self): return preprocessing - def val_transform(self): + def val_transform(self) -> transform_lib.Compose: """ The standard imagenet transforms for validation diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py index 433e7fffed..a50528028b 100644 --- a/pl_bolts/datamodules/kitti_datamodule.py +++ b/pl_bolts/datamodules/kitti_datamodule.py @@ -21,7 +21,7 @@ class KittiDataModule(LightningDataModule): def __init__( self, - data_dir: str, + data_dir: Optional[str], val_split: float = 0.2, test_split: float = 0.1, num_workers: int = 16, @@ -30,8 +30,8 @@ def __init__( shuffle: bool = False, pin_memory: bool = False, drop_last: bool = False, - *args, - **kwargs, + *args: Any, + **kwargs: Any, ): """ Kitti train, validation and test dataloaders. @@ -100,7 +100,7 @@ def __init__( lengths=[train_len, val_len, test_len], generator=torch.Generator().manual_seed(self.seed)) - def train_dataloader(self): + def train_dataloader(self) -> DataLoader: loader = DataLoader( self.trainset, batch_size=self.batch_size, @@ -111,7 +111,7 @@ def train_dataloader(self): ) return loader - def val_dataloader(self): + def val_dataloader(self) -> DataLoader: loader = DataLoader( self.valset, batch_size=self.batch_size, @@ -122,7 +122,7 @@ def val_dataloader(self): ) return loader - def test_dataloader(self): + def test_dataloader(self) -> DataLoader: loader = DataLoader( self.testset, batch_size=self.batch_size, @@ -133,7 +133,7 @@ def test_dataloader(self): ) return loader - def _default_transforms(self): + def _default_transforms(self) -> transform_lib.Compose: kitti_transforms = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753], diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py index 1dd5e927b6..76a0438a0c 100644 --- a/pl_bolts/datamodules/mnist_datamodule.py +++ b/pl_bolts/datamodules/mnist_datamodule.py @@ -56,7 +56,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ) -> None: + ): """ Args: data_dir: Where to save/load the data @@ -92,7 +92,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self): + def default_transforms(self) -> transform_lib.Compose: if self.normalize: mnist_transforms = transform_lib.Compose( [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))] diff --git a/pl_bolts/datamodules/sklearn_datamodule.py b/pl_bolts/datamodules/sklearn_datamodule.py index ed262b10c8..dcfd559441 100644 --- a/pl_bolts/datamodules/sklearn_datamodule.py +++ b/pl_bolts/datamodules/sklearn_datamodule.py @@ -42,10 +42,10 @@ def __init__(self, X: np.ndarray, y: np.ndarray, X_transform: Any = None, y_tran self.X_transform = X_transform self.y_transform = y_transform - def __len__(self): + def __len__(self) -> int: return len(self.X) - def __getitem__(self, idx): + def __getitem__(self, idx) -> Tuple[np.ndarray, np.ndarray]: x = self.X[idx].astype(np.float32) y = self.Y[idx] @@ -89,10 +89,10 @@ def __init__(self, X: torch.Tensor, y: torch.Tensor, X_transform: Any = None, y_ self.X_transform = X_transform self.y_transform = y_transform - def __len__(self): + def __len__(self) -> int: return len(self.X) - def __getitem__(self, idx): + def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor]: x = self.X[idx].float() y = self.Y[idx] @@ -145,14 +145,14 @@ def __init__( x_val=None, y_val=None, x_test=None, y_test=None, val_split=0.2, test_split=0.1, - num_workers=2, - random_state=1234, - shuffle=True, + num_workers:int = 2, + random_state: int = 1234, + shuffle: bool = True, batch_size: int = 16, - pin_memory=False, - drop_last=False, - *args, - **kwargs, + pin_memory: bool = False, + drop_last: bool = False, + *args: Any, + **kwargs: Any, ): super().__init__(*args, **kwargs) @@ -193,12 +193,20 @@ def __init__( self._init_datasets(X, y, x_val, y_val, x_test, y_test) - def _init_datasets(self, X, y, x_val, y_val, x_test, y_test): + def _init_datasets( + self, + X: np.ndarray, + y: np.ndarray, + x_val: np.ndarray, + y_val: np.ndarray, + x_test: np.ndarray, + y_test: np.ndarray + ): self.train_dataset = SklearnDataset(X, y) self.val_dataset = SklearnDataset(x_val, y_val) self.test_dataset = SklearnDataset(x_test, y_test) - def train_dataloader(self): + def train_dataloader(self) -> DataLoader: loader = DataLoader( self.train_dataset, batch_size=self.batch_size, @@ -209,7 +217,7 @@ def train_dataloader(self): ) return loader - def val_dataloader(self): + def val_dataloader(self) -> DataLoader: loader = DataLoader( self.val_dataset, batch_size=self.batch_size, @@ -220,7 +228,7 @@ def val_dataloader(self): ) return loader - def test_dataloader(self): + def test_dataloader(self) -> DataLoader: loader = DataLoader( self.test_dataset, batch_size=self.batch_size, diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py index 03e459fd5e..50315245af 100644 --- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py +++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py @@ -20,15 +20,15 @@ class SSLImagenetDataModule(LightningDataModule): # pragma: no cover def __init__( self, - data_dir, - meta_dir=None, - num_workers=16, + data_dir: str, + meta_dir: Optional[str] = None, + num_workers: int = 16, batch_size: int = 32, shuffle: bool = False, pin_memory: bool = False, drop_last: bool = False, - *args, - **kwargs, + *args: Any, + **kwargs: Any, ): super().__init__(*args, **kwargs) @@ -46,10 +46,10 @@ def __init__( self.drop_last = drop_last @property - def num_classes(self): + def num_classes(self) -> int: return 1000 - def _verify_splits(self, data_dir, split): + def _verify_splits(self, data_dir: str, split: str): dirs = os.listdir(data_dir) if split not in dirs: @@ -79,7 +79,7 @@ def prepare_data(self): UnlabeledImagenet.generate_meta_bins(path) """) - def train_dataloader(self, num_images_per_class=-1, add_normalize=False): + def train_dataloader(self, num_images_per_class: int = -1, add_normalize: bool = False) -> DataLoader: transforms = self._default_transforms() if self.train_transforms is None else self.train_transforms dataset = UnlabeledImagenet(self.data_dir, @@ -97,7 +97,7 @@ def train_dataloader(self, num_images_per_class=-1, add_normalize=False): ) return loader - def val_dataloader(self, num_images_per_class=50, add_normalize=False): + def val_dataloader(self, num_images_per_class: int = 50, add_normalize: bool = False) -> DataLoader: transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms dataset = UnlabeledImagenet(self.data_dir, @@ -115,7 +115,7 @@ def val_dataloader(self, num_images_per_class=50, add_normalize=False): ) return loader - def test_dataloader(self, num_images_per_class, add_normalize=False): + def test_dataloader(self, num_images_per_class: int, add_normalize: bool = False) -> DataLoader: transforms = self._default_transforms() if self.test_transforms is None else self.test_transforms dataset = UnlabeledImagenet(self.data_dir, @@ -133,7 +133,7 @@ def test_dataloader(self, num_images_per_class, add_normalize=False): ) return loader - def _default_transforms(self): + def _default_transforms(self) -> transform_lib.Compose: mnist_transforms = transform_lib.Compose([ transform_lib.ToTensor(), imagenet_normalization() diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py index c666db9b9b..3b29995a1b 100644 --- a/pl_bolts/datamodules/stl10_datamodule.py +++ b/pl_bolts/datamodules/stl10_datamodule.py @@ -63,8 +63,8 @@ def __init__( shuffle: bool = False, pin_memory: bool = False, drop_last: bool = False, - *args, - **kwargs, + *args: Any, + **kwargs: Any, ): """ Args: @@ -99,7 +99,7 @@ def __init__( self.num_unlabeled_samples = 100000 - unlabeled_val_split @property - def num_classes(self): + def num_classes(self) -> int: return 10 def prepare_data(self): @@ -110,7 +110,7 @@ def prepare_data(self): STL10(self.data_dir, split='train', download=True, transform=transform_lib.ToTensor()) STL10(self.data_dir, split='test', download=True, transform=transform_lib.ToTensor()) - def train_dataloader(self): + def train_dataloader(self) -> DataLoader: """ Loads the 'unlabeled' split minus a portion set aside for validation via `unlabeled_val_split`. """ @@ -131,7 +131,7 @@ def train_dataloader(self): ) return loader - def train_dataloader_mixed(self): + def train_dataloader_mixed(self) -> DataLoader: """ Loads a portion of the 'unlabeled' training data and 'train' (labeled) data. both portions have a subset removed for validation via `unlabeled_val_split` and `train_val_split` @@ -169,7 +169,7 @@ def train_dataloader_mixed(self): ) return loader - def val_dataloader(self): + def val_dataloader(self) -> DataLoader: """ Loads a portion of the 'unlabeled' training data set aside for validation The val dataset = (unlabeled - train_val_split) @@ -196,7 +196,7 @@ def val_dataloader(self): ) return loader - def val_dataloader_mixed(self): + def val_dataloader_mixed(self) -> DataLoader: """ Loads a portion of the 'unlabeled' training data set aside for validation along with the portion of the 'train' dataset to be used for validation @@ -239,7 +239,7 @@ def val_dataloader_mixed(self): ) return loader - def test_dataloader(self): + def test_dataloader(self) -> DataLoader: """ Loads the test split of STL10 @@ -260,7 +260,7 @@ def test_dataloader(self): ) return loader - def train_dataloader_labeled(self): + def train_dataloader_labeled(self) -> DataLoader: transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms dataset = STL10(self.data_dir, split='train', download=False, transform=transforms) @@ -278,7 +278,7 @@ def train_dataloader_labeled(self): ) return loader - def val_dataloader_labeled(self): + def val_dataloader_labeled(self) -> DataLoader: transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms dataset = STL10(self.data_dir, split='train', @@ -299,7 +299,7 @@ def val_dataloader_labeled(self): ) return loader - def _default_transforms(self): + def _default_transforms(self) -> transform_lib.Compose: data_transforms = transform_lib.Compose([ transform_lib.ToTensor(), stl10_normalization() diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py index 2144f0f509..5b8c508904 100644 --- a/pl_bolts/datamodules/vision_datamodule.py +++ b/pl_bolts/datamodules/vision_datamodule.py @@ -6,6 +6,14 @@ from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader, Dataset, random_split +from pl_bolts.utils import _TORCHVISION_AVAILABLE +from pl_bolts.utils.warnings import warn_missing_pkg + +if _TORCHVISION_AVAILABLE: + from torchvision import transforms as transform_lib +else: + warn_missing_pkg('torchvision') # pragma: no-cover + class VisionDataModule(LightningDataModule): @@ -29,7 +37,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ) -> None: + ): """ Args: data_dir: Where to save/load the data @@ -56,14 +64,14 @@ def __init__( self.pin_memory = pin_memory self.drop_last = drop_last - def prepare_data(self) -> None: + def prepare_data(self): """ Saves files to data_dir """ self.dataset_cls(self.data_dir, train=True, download=True) self.dataset_cls(self.data_dir, train=False, download=True) - def setup(self, stage: Optional[str] = None) -> None: + def setup(self, stage: Optional[str] = None): """ Creates train, val, and test dataset """ @@ -115,7 +123,7 @@ def _get_splits(self, len_dataset: int) -> List[int]: return splits @abstractmethod - def default_transforms(self): + def default_transforms(self) -> transform_lib.Compose: """ Default transform for the dataset """ def train_dataloader(self) -> DataLoader: From 2b55b328900ad96ce504ca2ff3f244cbe97c0597 Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Fri, 18 Dec 2020 20:20:20 +0900 Subject: [PATCH 02/75] Fixing typing imports --- pl_bolts/datamodules/async_dataloader.py | 14 +++++++++++--- pl_bolts/datamodules/cityscapes_datamodule.py | 2 ++ pl_bolts/datamodules/imagenet_datamodule.py | 2 +- pl_bolts/datamodules/kitti_datamodule.py | 3 ++- pl_bolts/datamodules/sklearn_datamodule.py | 4 ++-- .../datamodules/ssl_imagenet_datamodule.py | 1 + pl_bolts/datamodules/stl10_datamodule.py | 2 +- pl_bolts/datamodules/vision_datamodule.py | 7 +------ .../datamodules/vocdetection_datamodule.py | 18 ++++++++++-------- 9 files changed, 31 insertions(+), 22 deletions(-) diff --git a/pl_bolts/datamodules/async_dataloader.py b/pl_bolts/datamodules/async_dataloader.py index 7ded9d9ef1..38a0b9bb58 100644 --- a/pl_bolts/datamodules/async_dataloader.py +++ b/pl_bolts/datamodules/async_dataloader.py @@ -1,10 +1,11 @@ import re from queue import Queue from threading import Thread +from typing import Any, Optional, Union import torch from torch._six import container_abcs, string_classes -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, Dataset class AsynchronousLoader(object): @@ -26,7 +27,14 @@ class AsynchronousLoader(object): constructing one here """ - def __init__(self, data, device=torch.device('cuda', 0), q_size=10, num_batches=None, **kwargs): + def __init__( + self, + data: Union[DataLoader, Dataset], + device: torch.device = torch.device('cuda', 0), + q_size: int = 10, + num_batches: Optional[int] = None, + **kwargs: Any + ): if isinstance(data, torch.utils.data.DataLoader): self.dataloader = data else: @@ -105,5 +113,5 @@ def __next__(self): self.idx += 1 return out - def __len__(self): + def __len__(self) -> int: return self.num_batches diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py index d27bfc3196..17812a0ac5 100644 --- a/pl_bolts/datamodules/cityscapes_datamodule.py +++ b/pl_bolts/datamodules/cityscapes_datamodule.py @@ -1,3 +1,5 @@ +from typing import Any + from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py index a31b637ba9..829c485aed 100644 --- a/pl_bolts/datamodules/imagenet_datamodule.py +++ b/pl_bolts/datamodules/imagenet_datamodule.py @@ -1,5 +1,5 @@ import os -from typing import Optional +from typing import Any, Optional from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py index a50528028b..9a82f0b7ec 100644 --- a/pl_bolts/datamodules/kitti_datamodule.py +++ b/pl_bolts/datamodules/kitti_datamodule.py @@ -1,4 +1,5 @@ import os +from typing import Any, Optional import torch from pytorch_lightning import LightningDataModule @@ -133,7 +134,7 @@ def test_dataloader(self) -> DataLoader: ) return loader - def _default_transforms(self) -> transform_lib.Compose: + def _default_transforms(self) -> transforms.Compose: kitti_transforms = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753], diff --git a/pl_bolts/datamodules/sklearn_datamodule.py b/pl_bolts/datamodules/sklearn_datamodule.py index dcfd559441..e80e1dfc9a 100644 --- a/pl_bolts/datamodules/sklearn_datamodule.py +++ b/pl_bolts/datamodules/sklearn_datamodule.py @@ -1,5 +1,5 @@ import math -from typing import Any +from typing import Any, Tuple import numpy as np import torch @@ -145,7 +145,7 @@ def __init__( x_val=None, y_val=None, x_test=None, y_test=None, val_split=0.2, test_split=0.1, - num_workers:int = 2, + num_workers: int = 2, random_state: int = 1234, shuffle: bool = True, batch_size: int = 16, diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py index 50315245af..1584583101 100644 --- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py +++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py @@ -1,4 +1,5 @@ import os +from typing import Any, Optional from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py index 3b29995a1b..8d46cfd7bf 100644 --- a/pl_bolts/datamodules/stl10_datamodule.py +++ b/pl_bolts/datamodules/stl10_datamodule.py @@ -1,5 +1,5 @@ import os -from typing import Optional +from typing import Any, Optional import torch from pytorch_lightning import LightningDataModule diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py index 5b8c508904..cdcefcb2eb 100644 --- a/pl_bolts/datamodules/vision_datamodule.py +++ b/pl_bolts/datamodules/vision_datamodule.py @@ -9,11 +9,6 @@ from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg -if _TORCHVISION_AVAILABLE: - from torchvision import transforms as transform_lib -else: - warn_missing_pkg('torchvision') # pragma: no-cover - class VisionDataModule(LightningDataModule): @@ -123,7 +118,7 @@ def _get_splits(self, len_dataset: int) -> List[int]: return splits @abstractmethod - def default_transforms(self) -> transform_lib.Compose: + def default_transforms(self): """ Default transform for the dataset """ def train_dataloader(self) -> DataLoader: diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py index 34dd86811e..b9071f17be 100644 --- a/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/pl_bolts/datamodules/vocdetection_datamodule.py @@ -1,3 +1,5 @@ +from typing import Any, Dict + import torch from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader @@ -17,7 +19,7 @@ class Compose(object): Like `torchvision.transforms.compose` but works for (image, target) """ - def __init__(self, transforms): + def __init__(self, transforms: T.Compose): self.transforms = transforms def __call__(self, image, target): @@ -55,7 +57,7 @@ def _collate_fn(batch): ) -def _prepare_voc_instance(image, target): +def _prepare_voc_instance(image, target: Dict[str, Any]): """ Prepares VOC dataset into appropriate target for fasterrcnn @@ -114,8 +116,8 @@ def __init__( shuffle: bool = False, pin_memory: bool = False, drop_last: bool = False, - *args, - **kwargs, + *args: Any, + **kwargs: Any, ): if not _TORCHVISION_AVAILABLE: raise ModuleNotFoundError( # pragma: no-cover @@ -133,7 +135,7 @@ def __init__( self.drop_last = drop_last @property - def num_classes(self): + def num_classes(self) -> int: """ Return: 21 @@ -147,7 +149,7 @@ def prepare_data(self): VOCDetection(self.data_dir, year=self.year, image_set="train", download=True) VOCDetection(self.data_dir, year=self.year, image_set="val", download=True) - def train_dataloader(self, batch_size=1, transforms=None): + def train_dataloader(self, batch_size: int = 1, transforms=None) -> DataLoader: """ VOCDetection train set uses the `train` subset @@ -175,7 +177,7 @@ def train_dataloader(self, batch_size=1, transforms=None): ) return loader - def val_dataloader(self, batch_size=1, transforms=None): + def val_dataloader(self, batch_size: int = 1, transforms=None) -> DataLoader: """ VOCDetection val set uses the `val` subset @@ -202,7 +204,7 @@ def val_dataloader(self, batch_size=1, transforms=None): ) return loader - def _default_transforms(self): + def _default_transforms(self) -> T.Compose: if self.normalize: return ( lambda image, target: ( From ac3377dea28a0bce8051683388d2924ebe15f24c Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Sun, 20 Dec 2020 03:03:00 +0900 Subject: [PATCH 03/75] Removing torchvision.transforms from return typing --- pl_bolts/datamodules/binary_mnist_datamodule.py | 2 +- pl_bolts/datamodules/cifar10_datamodule.py | 2 +- pl_bolts/datamodules/cityscapes_datamodule.py | 4 ++-- pl_bolts/datamodules/fashion_mnist_datamodule.py | 2 +- pl_bolts/datamodules/imagenet_datamodule.py | 4 ++-- pl_bolts/datamodules/kitti_datamodule.py | 2 +- pl_bolts/datamodules/mnist_datamodule.py | 2 +- pl_bolts/datamodules/ssl_imagenet_datamodule.py | 2 +- pl_bolts/datamodules/stl10_datamodule.py | 2 +- pl_bolts/datamodules/vocdetection_datamodule.py | 2 +- 10 files changed, 12 insertions(+), 12 deletions(-) diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py index c713abe107..8dc02ec95e 100644 --- a/pl_bolts/datamodules/binary_mnist_datamodule.py +++ b/pl_bolts/datamodules/binary_mnist_datamodule.py @@ -98,7 +98,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self) -> transform_lib.Compose: + def default_transforms(self): if self.normalize: mnist_transforms = transform_lib.Compose( [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))] diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index 12aea1ec87..b208172ed0 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -112,7 +112,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self) -> transform_lib.Compose: + def default_transforms(self): if self.normalize: cf10_transforms = transform_lib.Compose([transform_lib.ToTensor(), cifar10_normalization()]) else: diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py index 17812a0ac5..dc8b866fba 100644 --- a/pl_bolts/datamodules/cityscapes_datamodule.py +++ b/pl_bolts/datamodules/cityscapes_datamodule.py @@ -192,7 +192,7 @@ def test_dataloader(self) -> DataLoader: ) return loader - def _default_transforms(self) -> transform_lib.Compose: + def _default_transforms(self): cityscapes_transforms = transform_lib.Compose([ transform_lib.ToTensor(), transform_lib.Normalize( @@ -202,7 +202,7 @@ def _default_transforms(self) -> transform_lib.Compose: ]) return cityscapes_transforms - def _default_target_transforms(self) -> transform_lib.Compose: + def _default_target_transforms(self): cityscapes_target_trasnforms = transform_lib.Compose([ transform_lib.ToTensor(), transform_lib.Lambda(lambda t: t.squeeze()) diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py index 9a73e6a637..a128ddfaab 100644 --- a/pl_bolts/datamodules/fashion_mnist_datamodule.py +++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py @@ -93,7 +93,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self) -> transform_lib.Compose: + def default_transforms(self): if self.normalize: mnist_transforms = transform_lib.Compose( [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))] diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py index 829c485aed..f8d8262108 100644 --- a/pl_bolts/datamodules/imagenet_datamodule.py +++ b/pl_bolts/datamodules/imagenet_datamodule.py @@ -206,7 +206,7 @@ def test_dataloader(self) -> DataLoader: ) return loader - def train_transform(self) -> transform_lib.Compose: + def train_transform(self): """ The standard imagenet transforms @@ -232,7 +232,7 @@ def train_transform(self) -> transform_lib.Compose: return preprocessing - def val_transform(self) -> transform_lib.Compose: + def val_transform(self): """ The standard imagenet transforms for validation diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py index 9a82f0b7ec..06778fdbfc 100644 --- a/pl_bolts/datamodules/kitti_datamodule.py +++ b/pl_bolts/datamodules/kitti_datamodule.py @@ -134,7 +134,7 @@ def test_dataloader(self) -> DataLoader: ) return loader - def _default_transforms(self) -> transforms.Compose: + def _default_transforms(self): kitti_transforms = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753], diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py index 76a0438a0c..c57fe8ca82 100644 --- a/pl_bolts/datamodules/mnist_datamodule.py +++ b/pl_bolts/datamodules/mnist_datamodule.py @@ -92,7 +92,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self) -> transform_lib.Compose: + def default_transforms(self): if self.normalize: mnist_transforms = transform_lib.Compose( [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))] diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py index 1584583101..96041fd4d9 100644 --- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py +++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py @@ -134,7 +134,7 @@ def test_dataloader(self, num_images_per_class: int, add_normalize: bool = False ) return loader - def _default_transforms(self) -> transform_lib.Compose: + def _default_transforms(self): mnist_transforms = transform_lib.Compose([ transform_lib.ToTensor(), imagenet_normalization() diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py index 8d46cfd7bf..5cd680a535 100644 --- a/pl_bolts/datamodules/stl10_datamodule.py +++ b/pl_bolts/datamodules/stl10_datamodule.py @@ -299,7 +299,7 @@ def val_dataloader_labeled(self) -> DataLoader: ) return loader - def _default_transforms(self) -> transform_lib.Compose: + def _default_transforms(self): data_transforms = transform_lib.Compose([ transform_lib.ToTensor(), stl10_normalization() diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py index b9071f17be..3bea4ec2d4 100644 --- a/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/pl_bolts/datamodules/vocdetection_datamodule.py @@ -204,7 +204,7 @@ def val_dataloader(self, batch_size: int = 1, transforms=None) -> DataLoader: ) return loader - def _default_transforms(self) -> T.Compose: + def _default_transforms(self): if self.normalize: return ( lambda image, target: ( From a4c39c787fc83f998f42dfdab038d4a82a079ca2 Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Sun, 20 Dec 2020 03:09:03 +0900 Subject: [PATCH 04/75] Remove more torchvision.transforms typing --- pl_bolts/datamodules/vocdetection_datamodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py index 3bea4ec2d4..448df864b6 100644 --- a/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/pl_bolts/datamodules/vocdetection_datamodule.py @@ -19,7 +19,7 @@ class Compose(object): Like `torchvision.transforms.compose` but works for (image, target) """ - def __init__(self, transforms: T.Compose): + def __init__(self, transforms): self.transforms = transforms def __call__(self, image, target): From ffa0cb9fe6e3b2cd6f730520d55a494280f9d7f5 Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Mon, 21 Dec 2020 10:13:11 +0900 Subject: [PATCH 05/75] Removing return typing --- pl_bolts/datamodules/cifar10_datamodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index b208172ed0..85ba4de6e7 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -71,7 +71,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ) -> None: + ): """ Args: data_dir: Where to save/load the data From 5e6c5d406c0151bf4629fce60bf28dcf268fd0b9 Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Mon, 21 Dec 2020 16:00:55 +0900 Subject: [PATCH 06/75] Add `None` for optional arguments --- pl_bolts/datamodules/cifar10_datamodule.py | 2 +- pl_bolts/datamodules/kitti_datamodule.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index 85ba4de6e7..534774684f 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -146,7 +146,7 @@ class TinyCIFAR10DataModule(CIFAR10DataModule): def __init__( self, - data_dir: Optional[str], + data_dir: Optional[str] = None, val_split: int = 50, num_workers: int = 16, num_samples: int = 100, diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py index 06778fdbfc..0067a1e53d 100644 --- a/pl_bolts/datamodules/kitti_datamodule.py +++ b/pl_bolts/datamodules/kitti_datamodule.py @@ -22,7 +22,7 @@ class KittiDataModule(LightningDataModule): def __init__( self, - data_dir: Optional[str], + data_dir: Optional[str] = None, val_split: float = 0.2, test_split: float = 0.1, num_workers: int = 16, From 3a5a0ab24859431edb150df1934ff74b6b2e3b9f Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Mon, 21 Dec 2020 16:04:36 +0900 Subject: [PATCH 07/75] Remove unnecessary import --- pl_bolts/datamodules/vision_datamodule.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py index cdcefcb2eb..06ddc7ab18 100644 --- a/pl_bolts/datamodules/vision_datamodule.py +++ b/pl_bolts/datamodules/vision_datamodule.py @@ -6,7 +6,6 @@ from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader, Dataset, random_split -from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg From 30579ed0a42703a7207c8b9cc0afc8143891ec32 Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Mon, 21 Dec 2020 16:04:36 +0900 Subject: [PATCH 08/75] Remove unnecessary import --- pl_bolts/datamodules/vision_datamodule.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py index cdcefcb2eb..42252c4edf 100644 --- a/pl_bolts/datamodules/vision_datamodule.py +++ b/pl_bolts/datamodules/vision_datamodule.py @@ -6,9 +6,6 @@ from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader, Dataset, random_split -from pl_bolts.utils import _TORCHVISION_AVAILABLE -from pl_bolts.utils.warnings import warn_missing_pkg - class VisionDataModule(LightningDataModule): From c6759311a1a24b0cc6cb8d350d96cdc582f26317 Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Mon, 21 Dec 2020 16:51:36 +0900 Subject: [PATCH 09/75] Add `None` return type --- pl_bolts/datamodules/async_dataloader.py | 4 ++-- pl_bolts/datamodules/binary_mnist_datamodule.py | 2 +- pl_bolts/datamodules/cifar10_datamodule.py | 4 ++-- pl_bolts/datamodules/cityscapes_datamodule.py | 2 +- pl_bolts/datamodules/experience_source.py | 4 ++-- pl_bolts/datamodules/fashion_mnist_datamodule.py | 2 +- pl_bolts/datamodules/imagenet_datamodule.py | 6 +++--- pl_bolts/datamodules/kitti_datamodule.py | 2 +- pl_bolts/datamodules/mnist_datamodule.py | 2 +- pl_bolts/datamodules/sklearn_datamodule.py | 12 ++++++++---- pl_bolts/datamodules/ssl_imagenet_datamodule.py | 6 +++--- pl_bolts/datamodules/stl10_datamodule.py | 4 ++-- pl_bolts/datamodules/vision_datamodule.py | 6 +++--- pl_bolts/datamodules/vocdetection_datamodule.py | 6 +++--- 14 files changed, 33 insertions(+), 29 deletions(-) diff --git a/pl_bolts/datamodules/async_dataloader.py b/pl_bolts/datamodules/async_dataloader.py index 38a0b9bb58..224f34d5ee 100644 --- a/pl_bolts/datamodules/async_dataloader.py +++ b/pl_bolts/datamodules/async_dataloader.py @@ -34,7 +34,7 @@ def __init__( q_size: int = 10, num_batches: Optional[int] = None, **kwargs: Any - ): + ) -> None: if isinstance(data, torch.utils.data.DataLoader): self.dataloader = data else: @@ -57,7 +57,7 @@ def __init__( self.np_str_obj_array_pattern = re.compile(r'[SaUO]') - def load_loop(self): # The loop that will load into the queue in the background + def load_loop(self) -> None: # The loop that will load into the queue in the background for i, sample in enumerate(self.dataloader): self.queue.put(self.load_instance(sample)) if i == len(self): diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py index 8dc02ec95e..142b3d54ef 100644 --- a/pl_bolts/datamodules/binary_mnist_datamodule.py +++ b/pl_bolts/datamodules/binary_mnist_datamodule.py @@ -56,7 +56,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: Where to save/load the data diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index 534774684f..2cb894d749 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -71,7 +71,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: Where to save/load the data @@ -153,7 +153,7 @@ def __init__( labels: Optional[Sequence] = (1, 5, 8), *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: where to save/load the data diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py index dc8b866fba..61c1ae2bef 100644 --- a/pl_bolts/datamodules/cityscapes_datamodule.py +++ b/pl_bolts/datamodules/cityscapes_datamodule.py @@ -73,7 +73,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: where to load the data from path, i.e. where directory leftImg8bit and gtFine or gtCoarse diff --git a/pl_bolts/datamodules/experience_source.py b/pl_bolts/datamodules/experience_source.py index 5fe1332dfd..fac4e82f25 100644 --- a/pl_bolts/datamodules/experience_source.py +++ b/pl_bolts/datamodules/experience_source.py @@ -30,7 +30,7 @@ class ExperienceSourceDataset(IterableDataset): The logic for the experience source and how the batch is generated is defined the Lightning model itself """ - def __init__(self, generate_batch: Callable): + def __init__(self, generate_batch: Callable) -> None: self.generate_batch = generate_batch def __iter__(self) -> Iterable: @@ -243,7 +243,7 @@ def pop_rewards_steps(self): class DiscountedExperienceSource(ExperienceSource): """Outputs experiences with a discounted reward over N steps""" - def __init__(self, env: Env, agent, n_steps: int = 1, gamma: float = 0.99): + def __init__(self, env: Env, agent, n_steps: int = 1, gamma: float = 0.99) -> None: super().__init__(env, agent, (n_steps + 1)) self.gamma = gamma self.steps = n_steps diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py index a128ddfaab..833c4599a6 100644 --- a/pl_bolts/datamodules/fashion_mnist_datamodule.py +++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py @@ -57,7 +57,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: Where to save/load the data diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py index f8d8262108..6f06913f9f 100644 --- a/pl_bolts/datamodules/imagenet_datamodule.py +++ b/pl_bolts/datamodules/imagenet_datamodule.py @@ -60,7 +60,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: path to the imagenet dataset file @@ -103,14 +103,14 @@ def num_classes(self) -> int: """ return 1000 - def _verify_splits(self, data_dir: str, split: str): + def _verify_splits(self, data_dir: str, split: str) -> None: dirs = os.listdir(data_dir) if split not in dirs: raise FileNotFoundError(f'a {split} Imagenet split was not found in {data_dir},' f' make sure the folder contains a subfolder named {split}') - def prepare_data(self): + def prepare_data(self) -> None: """ This method already assumes you have imagenet2012 downloaded. It validates the data using the meta.bin. diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py index 0067a1e53d..3cf26dc762 100644 --- a/pl_bolts/datamodules/kitti_datamodule.py +++ b/pl_bolts/datamodules/kitti_datamodule.py @@ -33,7 +33,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Kitti train, validation and test dataloaders. diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py index c57fe8ca82..1dd5e927b6 100644 --- a/pl_bolts/datamodules/mnist_datamodule.py +++ b/pl_bolts/datamodules/mnist_datamodule.py @@ -56,7 +56,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: Where to save/load the data diff --git a/pl_bolts/datamodules/sklearn_datamodule.py b/pl_bolts/datamodules/sklearn_datamodule.py index e80e1dfc9a..d9477acc0b 100644 --- a/pl_bolts/datamodules/sklearn_datamodule.py +++ b/pl_bolts/datamodules/sklearn_datamodule.py @@ -28,7 +28,9 @@ class SklearnDataset(Dataset): >>> len(dataset) 506 """ - def __init__(self, X: np.ndarray, y: np.ndarray, X_transform: Any = None, y_transform: Any = None): + def __init__( + self, X: np.ndarray, y: np.ndarray, X_transform: Any = None, y_transform: Any = None + ) -> None: """ Args: X: Numpy ndarray @@ -75,7 +77,9 @@ class TensorDataset(Dataset): >>> len(dataset) 10 """ - def __init__(self, X: torch.Tensor, y: torch.Tensor, X_transform: Any = None, y_transform: Any = None): + def __init__( + self, X: torch.Tensor, y: torch.Tensor, X_transform: Any = None, y_transform: Any = None + ) -> None: """ Args: X: PyTorch tensor @@ -153,7 +157,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: super().__init__(*args, **kwargs) self.num_workers = num_workers @@ -201,7 +205,7 @@ def _init_datasets( y_val: np.ndarray, x_test: np.ndarray, y_test: np.ndarray - ): + ) -> None: self.train_dataset = SklearnDataset(X, y) self.val_dataset = SklearnDataset(x_val, y_val) self.test_dataset = SklearnDataset(x_test, y_test) diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py index 96041fd4d9..3dbda03527 100644 --- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py +++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py @@ -30,7 +30,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: super().__init__(*args, **kwargs) if not _TORCHVISION_AVAILABLE: @@ -50,14 +50,14 @@ def __init__( def num_classes(self) -> int: return 1000 - def _verify_splits(self, data_dir: str, split: str): + def _verify_splits(self, data_dir: str, split: str) -> None: dirs = os.listdir(data_dir) if split not in dirs: raise FileNotFoundError(f'a {split} Imagenet split was not found in {data_dir}, make sure the' f' folder contains a subfolder named {split}') - def prepare_data(self): + def prepare_data(self) -> None: # imagenet cannot be downloaded... must provide path to folder with the train/val splits self._verify_splits(self.data_dir, 'train') self._verify_splits(self.data_dir, 'val') diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py index 5cd680a535..79420be149 100644 --- a/pl_bolts/datamodules/stl10_datamodule.py +++ b/pl_bolts/datamodules/stl10_datamodule.py @@ -65,7 +65,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: where to save/load the data @@ -102,7 +102,7 @@ def __init__( def num_classes(self) -> int: return 10 - def prepare_data(self): + def prepare_data(self) -> None: """ Downloads the unlabeled, train and test split """ diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py index 42252c4edf..2144f0f509 100644 --- a/pl_bolts/datamodules/vision_datamodule.py +++ b/pl_bolts/datamodules/vision_datamodule.py @@ -29,7 +29,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: Where to save/load the data @@ -56,14 +56,14 @@ def __init__( self.pin_memory = pin_memory self.drop_last = drop_last - def prepare_data(self): + def prepare_data(self) -> None: """ Saves files to data_dir """ self.dataset_cls(self.data_dir, train=True, download=True) self.dataset_cls(self.data_dir, train=False, download=True) - def setup(self, stage: Optional[str] = None): + def setup(self, stage: Optional[str] = None) -> None: """ Creates train, val, and test dataset """ diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py index 448df864b6..e7fc989330 100644 --- a/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/pl_bolts/datamodules/vocdetection_datamodule.py @@ -19,7 +19,7 @@ class Compose(object): Like `torchvision.transforms.compose` but works for (image, target) """ - def __init__(self, transforms): + def __init__(self, transforms) -> None: self.transforms = transforms def __call__(self, image, target): @@ -118,7 +118,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: if not _TORCHVISION_AVAILABLE: raise ModuleNotFoundError( # pragma: no-cover 'You want to use VOC dataset loaded from `torchvision` which is not installed yet.' @@ -142,7 +142,7 @@ def num_classes(self) -> int: """ return 21 - def prepare_data(self): + def prepare_data(self) -> None: """ Saves VOCDetection files to data_dir """ From 267649c125273644deeb31746e54d5c5d95d8357 Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Tue, 5 Jan 2021 20:30:43 +0900 Subject: [PATCH 10/75] Add type for torchvision transforms --- pl_bolts/datamodules/binary_mnist_datamodule.py | 4 +++- pl_bolts/datamodules/cifar10_datamodule.py | 4 +++- pl_bolts/datamodules/cityscapes_datamodule.py | 6 ++++-- pl_bolts/datamodules/fashion_mnist_datamodule.py | 4 +++- pl_bolts/datamodules/imagenet_datamodule.py | 6 ++++-- pl_bolts/datamodules/kitti_datamodule.py | 4 +++- pl_bolts/datamodules/mnist_datamodule.py | 4 +++- pl_bolts/datamodules/ssl_imagenet_datamodule.py | 4 +++- pl_bolts/datamodules/stl10_datamodule.py | 4 +++- pl_bolts/datamodules/vision_datamodule.py | 9 ++++++++- 10 files changed, 37 insertions(+), 12 deletions(-) diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py index 142b3d54ef..ad17360d08 100644 --- a/pl_bolts/datamodules/binary_mnist_datamodule.py +++ b/pl_bolts/datamodules/binary_mnist_datamodule.py @@ -7,8 +7,10 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib + from torchvision.transforms import Compose else: # pragma: no-cover warn_missing_pkg('torchvision') + Compose = object class BinaryMNISTDataModule(VisionDataModule): @@ -98,7 +100,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self): + def default_transforms(self) -> Compose: if self.normalize: mnist_transforms = transform_lib.Compose( [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))] diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index 2cb894d749..9dbf10b670 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -9,9 +9,11 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib from torchvision.datasets import CIFAR10 + from torchvision.transforms import Compose else: warn_missing_pkg('torchvision') # pragma: no-cover CIFAR10 = None + Compose = object class CIFAR10DataModule(VisionDataModule): @@ -112,7 +114,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self): + def default_transforms(self) -> Compose: if self.normalize: cf10_transforms = transform_lib.Compose([transform_lib.ToTensor(), cifar10_normalization()]) else: diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py index 61c1ae2bef..130e333976 100644 --- a/pl_bolts/datamodules/cityscapes_datamodule.py +++ b/pl_bolts/datamodules/cityscapes_datamodule.py @@ -9,8 +9,10 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib from torchvision.datasets import Cityscapes + from torchvision.transforms import Compose else: warn_missing_pkg('torchvision') # pragma: no-cover + Compose = object class CityscapesDataModule(LightningDataModule): @@ -192,7 +194,7 @@ def test_dataloader(self) -> DataLoader: ) return loader - def _default_transforms(self): + def _default_transforms(self) -> Compose: cityscapes_transforms = transform_lib.Compose([ transform_lib.ToTensor(), transform_lib.Normalize( @@ -202,7 +204,7 @@ def _default_transforms(self): ]) return cityscapes_transforms - def _default_target_transforms(self): + def _default_target_transforms(self) -> Compose: cityscapes_target_trasnforms = transform_lib.Compose([ transform_lib.ToTensor(), transform_lib.Lambda(lambda t: t.squeeze()) diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py index 833c4599a6..b37221bc74 100644 --- a/pl_bolts/datamodules/fashion_mnist_datamodule.py +++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py @@ -7,9 +7,11 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib from torchvision.datasets import FashionMNIST + from torchvision.transforms import Compose else: warn_missing_pkg('torchvision') # pragma: no-cover FashionMNIST = None + Compose = object class FashionMNISTDataModule(VisionDataModule): @@ -93,7 +95,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self): + def default_transforms(self) -> Compose: if self.normalize: mnist_transforms = transform_lib.Compose( [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))] diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py index 6f06913f9f..db2fc68c0b 100644 --- a/pl_bolts/datamodules/imagenet_datamodule.py +++ b/pl_bolts/datamodules/imagenet_datamodule.py @@ -11,8 +11,10 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib + from torchvision.transforms import Compose else: warn_missing_pkg('torchvision') # pragma: no-cover + Compose = object class ImagenetDataModule(LightningDataModule): @@ -206,7 +208,7 @@ def test_dataloader(self) -> DataLoader: ) return loader - def train_transform(self): + def train_transform(self) -> Compose: """ The standard imagenet transforms @@ -232,7 +234,7 @@ def train_transform(self): return preprocessing - def val_transform(self): + def val_transform(self) -> Compose: """ The standard imagenet transforms for validation diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py index 3cf26dc762..b63040f3bf 100644 --- a/pl_bolts/datamodules/kitti_datamodule.py +++ b/pl_bolts/datamodules/kitti_datamodule.py @@ -12,8 +12,10 @@ if _TORCHVISION_AVAILABLE: import torchvision.transforms as transforms + from torchvision.transforms import Compose else: # pragma: no cover warn_missing_pkg('torchvision') + Compose = object class KittiDataModule(LightningDataModule): @@ -134,7 +136,7 @@ def test_dataloader(self) -> DataLoader: ) return loader - def _default_transforms(self): + def _default_transforms(self) -> Compose: kitti_transforms = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753], diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py index 1dd5e927b6..711460023c 100644 --- a/pl_bolts/datamodules/mnist_datamodule.py +++ b/pl_bolts/datamodules/mnist_datamodule.py @@ -7,9 +7,11 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib from torchvision.datasets import MNIST + from torchvision.transforms import Compose else: warn_missing_pkg('torchvision') # pragma: no-cover MNIST = None + Compose = object class MNISTDataModule(VisionDataModule): @@ -92,7 +94,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self): + def default_transforms(self) -> Compose: if self.normalize: mnist_transforms = transform_lib.Compose( [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))] diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py index 3dbda03527..d575eb2d01 100644 --- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py +++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py @@ -11,8 +11,10 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib + from torchvision.transforms import Compose else: warn_missing_pkg('torchvision') # pragma: no-cover + Compose = object class SSLImagenetDataModule(LightningDataModule): # pragma: no cover @@ -134,7 +136,7 @@ def test_dataloader(self, num_images_per_class: int, add_normalize: bool = False ) return loader - def _default_transforms(self): + def _default_transforms(self) -> Compose: mnist_transforms = transform_lib.Compose([ transform_lib.ToTensor(), imagenet_normalization() diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py index 79420be149..f9ac77e140 100644 --- a/pl_bolts/datamodules/stl10_datamodule.py +++ b/pl_bolts/datamodules/stl10_datamodule.py @@ -13,8 +13,10 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib from torchvision.datasets import STL10 + from torchvision.transforms import Compose else: warn_missing_pkg('torchvision') # pragma: no-cover + Compose = object class STL10DataModule(LightningDataModule): # pragma: no cover @@ -299,7 +301,7 @@ def val_dataloader_labeled(self) -> DataLoader: ) return loader - def _default_transforms(self): + def _default_transforms(self) -> Compose: data_transforms = transform_lib.Compose([ transform_lib.ToTensor(), stl10_normalization() diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py index 2144f0f509..92e6723968 100644 --- a/pl_bolts/datamodules/vision_datamodule.py +++ b/pl_bolts/datamodules/vision_datamodule.py @@ -6,6 +6,13 @@ from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader, Dataset, random_split +from pl_bolts.utils import _TORCHVISION_AVAILABLE + +if _TORCHVISION_AVAILABLE: + from torchvision.transforms import Compose +else: + Compose = object + class VisionDataModule(LightningDataModule): @@ -115,7 +122,7 @@ def _get_splits(self, len_dataset: int) -> List[int]: return splits @abstractmethod - def default_transforms(self): + def default_transforms(self) -> Compose: """ Default transform for the dataset """ def train_dataloader(self) -> DataLoader: From 3938fff80783b67ed26245a73dce63e7c52f42fb Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Fri, 18 Dec 2020 19:58:16 +0900 Subject: [PATCH 11/75] Adding types to datamodules --- .../datamodules/binary_mnist_datamodule.py | 4 +- pl_bolts/datamodules/cifar10_datamodule.py | 6 +-- pl_bolts/datamodules/cityscapes_datamodule.py | 16 ++++---- .../datamodules/fashion_mnist_datamodule.py | 4 +- pl_bolts/datamodules/imagenet_datamodule.py | 18 ++++----- pl_bolts/datamodules/kitti_datamodule.py | 14 +++---- pl_bolts/datamodules/mnist_datamodule.py | 4 +- pl_bolts/datamodules/sklearn_datamodule.py | 38 +++++++++++-------- .../datamodules/ssl_imagenet_datamodule.py | 22 +++++------ pl_bolts/datamodules/stl10_datamodule.py | 22 +++++------ pl_bolts/datamodules/vision_datamodule.py | 16 ++++++-- 11 files changed, 90 insertions(+), 74 deletions(-) diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py index 142b3d54ef..c713abe107 100644 --- a/pl_bolts/datamodules/binary_mnist_datamodule.py +++ b/pl_bolts/datamodules/binary_mnist_datamodule.py @@ -56,7 +56,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ) -> None: + ): """ Args: data_dir: Where to save/load the data @@ -98,7 +98,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self): + def default_transforms(self) -> transform_lib.Compose: if self.normalize: mnist_transforms = transform_lib.Compose( [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))] diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index afb2df8c9a..12aea1ec87 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -112,7 +112,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self): + def default_transforms(self) -> transform_lib.Compose: if self.normalize: cf10_transforms = transform_lib.Compose([transform_lib.ToTensor(), cifar10_normalization()]) else: @@ -146,14 +146,14 @@ class TinyCIFAR10DataModule(CIFAR10DataModule): def __init__( self, - data_dir: str, + data_dir: Optional[str], val_split: int = 50, num_workers: int = 16, num_samples: int = 100, labels: Optional[Sequence] = (1, 5, 8), *args: Any, **kwargs: Any, - ) -> None: + ): """ Args: data_dir: where to save/load the data diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py index b851617225..d27bfc3196 100644 --- a/pl_bolts/datamodules/cityscapes_datamodule.py +++ b/pl_bolts/datamodules/cityscapes_datamodule.py @@ -69,8 +69,8 @@ def __init__( shuffle: bool = False, pin_memory: bool = False, drop_last: bool = False, - *args, - **kwargs, + *args: Any, + **kwargs: Any, ): """ Args: @@ -109,14 +109,14 @@ def __init__( self.target_transforms = None @property - def num_classes(self): + def num_classes(self) -> int: """ Return: 30 """ return 30 - def train_dataloader(self): + def train_dataloader(self) -> DataLoader: """ Cityscapes train set """ @@ -141,7 +141,7 @@ def train_dataloader(self): ) return loader - def val_dataloader(self): + def val_dataloader(self) -> DataLoader: """ Cityscapes val set """ @@ -166,7 +166,7 @@ def val_dataloader(self): ) return loader - def test_dataloader(self): + def test_dataloader(self) -> DataLoader: """ Cityscapes test set """ @@ -190,7 +190,7 @@ def test_dataloader(self): ) return loader - def _default_transforms(self): + def _default_transforms(self) -> transform_lib.Compose: cityscapes_transforms = transform_lib.Compose([ transform_lib.ToTensor(), transform_lib.Normalize( @@ -200,7 +200,7 @@ def _default_transforms(self): ]) return cityscapes_transforms - def _default_target_transforms(self): + def _default_target_transforms(self) -> transform_lib.Compose: cityscapes_target_trasnforms = transform_lib.Compose([ transform_lib.ToTensor(), transform_lib.Lambda(lambda t: t.squeeze()) diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py index 833c4599a6..9a73e6a637 100644 --- a/pl_bolts/datamodules/fashion_mnist_datamodule.py +++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py @@ -57,7 +57,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ) -> None: + ): """ Args: data_dir: Where to save/load the data @@ -93,7 +93,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self): + def default_transforms(self) -> transform_lib.Compose: if self.normalize: mnist_transforms = transform_lib.Compose( [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))] diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py index 38546e29ee..a31b637ba9 100644 --- a/pl_bolts/datamodules/imagenet_datamodule.py +++ b/pl_bolts/datamodules/imagenet_datamodule.py @@ -58,8 +58,8 @@ def __init__( shuffle: bool = False, pin_memory: bool = False, drop_last: bool = False, - *args, - **kwargs, + *args: Any, + **kwargs: Any, ): """ Args: @@ -94,7 +94,7 @@ def __init__( self.num_samples = 1281167 - self.num_imgs_per_val_class * self.num_classes @property - def num_classes(self): + def num_classes(self) -> int: """ Return: @@ -103,7 +103,7 @@ def num_classes(self): """ return 1000 - def _verify_splits(self, data_dir, split): + def _verify_splits(self, data_dir: str, split: str): dirs = os.listdir(data_dir) if split not in dirs: @@ -138,7 +138,7 @@ def prepare_data(self): UnlabeledImagenet.generate_meta_bins(path) """) - def train_dataloader(self): + def train_dataloader(self) -> DataLoader: """ Uses the train split of imagenet2012 and puts away a portion of it for the validation split """ @@ -160,7 +160,7 @@ def train_dataloader(self): ) return loader - def val_dataloader(self): + def val_dataloader(self) -> DataLoader: """ Uses the part of the train split of imagenet2012 that was not used for training via `num_imgs_per_val_class` @@ -185,7 +185,7 @@ def val_dataloader(self): ) return loader - def test_dataloader(self): + def test_dataloader(self) -> DataLoader: """ Uses the validation split of imagenet2012 for testing """ @@ -206,7 +206,7 @@ def test_dataloader(self): ) return loader - def train_transform(self): + def train_transform(self) -> transform_lib.Compose: """ The standard imagenet transforms @@ -232,7 +232,7 @@ def train_transform(self): return preprocessing - def val_transform(self): + def val_transform(self) -> transform_lib.Compose: """ The standard imagenet transforms for validation diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py index 433e7fffed..a50528028b 100644 --- a/pl_bolts/datamodules/kitti_datamodule.py +++ b/pl_bolts/datamodules/kitti_datamodule.py @@ -21,7 +21,7 @@ class KittiDataModule(LightningDataModule): def __init__( self, - data_dir: str, + data_dir: Optional[str], val_split: float = 0.2, test_split: float = 0.1, num_workers: int = 16, @@ -30,8 +30,8 @@ def __init__( shuffle: bool = False, pin_memory: bool = False, drop_last: bool = False, - *args, - **kwargs, + *args: Any, + **kwargs: Any, ): """ Kitti train, validation and test dataloaders. @@ -100,7 +100,7 @@ def __init__( lengths=[train_len, val_len, test_len], generator=torch.Generator().manual_seed(self.seed)) - def train_dataloader(self): + def train_dataloader(self) -> DataLoader: loader = DataLoader( self.trainset, batch_size=self.batch_size, @@ -111,7 +111,7 @@ def train_dataloader(self): ) return loader - def val_dataloader(self): + def val_dataloader(self) -> DataLoader: loader = DataLoader( self.valset, batch_size=self.batch_size, @@ -122,7 +122,7 @@ def val_dataloader(self): ) return loader - def test_dataloader(self): + def test_dataloader(self) -> DataLoader: loader = DataLoader( self.testset, batch_size=self.batch_size, @@ -133,7 +133,7 @@ def test_dataloader(self): ) return loader - def _default_transforms(self): + def _default_transforms(self) -> transform_lib.Compose: kitti_transforms = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753], diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py index 1dd5e927b6..76a0438a0c 100644 --- a/pl_bolts/datamodules/mnist_datamodule.py +++ b/pl_bolts/datamodules/mnist_datamodule.py @@ -56,7 +56,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ) -> None: + ): """ Args: data_dir: Where to save/load the data @@ -92,7 +92,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self): + def default_transforms(self) -> transform_lib.Compose: if self.normalize: mnist_transforms = transform_lib.Compose( [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))] diff --git a/pl_bolts/datamodules/sklearn_datamodule.py b/pl_bolts/datamodules/sklearn_datamodule.py index ed262b10c8..dcfd559441 100644 --- a/pl_bolts/datamodules/sklearn_datamodule.py +++ b/pl_bolts/datamodules/sklearn_datamodule.py @@ -42,10 +42,10 @@ def __init__(self, X: np.ndarray, y: np.ndarray, X_transform: Any = None, y_tran self.X_transform = X_transform self.y_transform = y_transform - def __len__(self): + def __len__(self) -> int: return len(self.X) - def __getitem__(self, idx): + def __getitem__(self, idx) -> Tuple[np.ndarray, np.ndarray]: x = self.X[idx].astype(np.float32) y = self.Y[idx] @@ -89,10 +89,10 @@ def __init__(self, X: torch.Tensor, y: torch.Tensor, X_transform: Any = None, y_ self.X_transform = X_transform self.y_transform = y_transform - def __len__(self): + def __len__(self) -> int: return len(self.X) - def __getitem__(self, idx): + def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor]: x = self.X[idx].float() y = self.Y[idx] @@ -145,14 +145,14 @@ def __init__( x_val=None, y_val=None, x_test=None, y_test=None, val_split=0.2, test_split=0.1, - num_workers=2, - random_state=1234, - shuffle=True, + num_workers:int = 2, + random_state: int = 1234, + shuffle: bool = True, batch_size: int = 16, - pin_memory=False, - drop_last=False, - *args, - **kwargs, + pin_memory: bool = False, + drop_last: bool = False, + *args: Any, + **kwargs: Any, ): super().__init__(*args, **kwargs) @@ -193,12 +193,20 @@ def __init__( self._init_datasets(X, y, x_val, y_val, x_test, y_test) - def _init_datasets(self, X, y, x_val, y_val, x_test, y_test): + def _init_datasets( + self, + X: np.ndarray, + y: np.ndarray, + x_val: np.ndarray, + y_val: np.ndarray, + x_test: np.ndarray, + y_test: np.ndarray + ): self.train_dataset = SklearnDataset(X, y) self.val_dataset = SklearnDataset(x_val, y_val) self.test_dataset = SklearnDataset(x_test, y_test) - def train_dataloader(self): + def train_dataloader(self) -> DataLoader: loader = DataLoader( self.train_dataset, batch_size=self.batch_size, @@ -209,7 +217,7 @@ def train_dataloader(self): ) return loader - def val_dataloader(self): + def val_dataloader(self) -> DataLoader: loader = DataLoader( self.val_dataset, batch_size=self.batch_size, @@ -220,7 +228,7 @@ def val_dataloader(self): ) return loader - def test_dataloader(self): + def test_dataloader(self) -> DataLoader: loader = DataLoader( self.test_dataset, batch_size=self.batch_size, diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py index 03e459fd5e..50315245af 100644 --- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py +++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py @@ -20,15 +20,15 @@ class SSLImagenetDataModule(LightningDataModule): # pragma: no cover def __init__( self, - data_dir, - meta_dir=None, - num_workers=16, + data_dir: str, + meta_dir: Optional[str] = None, + num_workers: int = 16, batch_size: int = 32, shuffle: bool = False, pin_memory: bool = False, drop_last: bool = False, - *args, - **kwargs, + *args: Any, + **kwargs: Any, ): super().__init__(*args, **kwargs) @@ -46,10 +46,10 @@ def __init__( self.drop_last = drop_last @property - def num_classes(self): + def num_classes(self) -> int: return 1000 - def _verify_splits(self, data_dir, split): + def _verify_splits(self, data_dir: str, split: str): dirs = os.listdir(data_dir) if split not in dirs: @@ -79,7 +79,7 @@ def prepare_data(self): UnlabeledImagenet.generate_meta_bins(path) """) - def train_dataloader(self, num_images_per_class=-1, add_normalize=False): + def train_dataloader(self, num_images_per_class: int = -1, add_normalize: bool = False) -> DataLoader: transforms = self._default_transforms() if self.train_transforms is None else self.train_transforms dataset = UnlabeledImagenet(self.data_dir, @@ -97,7 +97,7 @@ def train_dataloader(self, num_images_per_class=-1, add_normalize=False): ) return loader - def val_dataloader(self, num_images_per_class=50, add_normalize=False): + def val_dataloader(self, num_images_per_class: int = 50, add_normalize: bool = False) -> DataLoader: transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms dataset = UnlabeledImagenet(self.data_dir, @@ -115,7 +115,7 @@ def val_dataloader(self, num_images_per_class=50, add_normalize=False): ) return loader - def test_dataloader(self, num_images_per_class, add_normalize=False): + def test_dataloader(self, num_images_per_class: int, add_normalize: bool = False) -> DataLoader: transforms = self._default_transforms() if self.test_transforms is None else self.test_transforms dataset = UnlabeledImagenet(self.data_dir, @@ -133,7 +133,7 @@ def test_dataloader(self, num_images_per_class, add_normalize=False): ) return loader - def _default_transforms(self): + def _default_transforms(self) -> transform_lib.Compose: mnist_transforms = transform_lib.Compose([ transform_lib.ToTensor(), imagenet_normalization() diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py index c666db9b9b..3b29995a1b 100644 --- a/pl_bolts/datamodules/stl10_datamodule.py +++ b/pl_bolts/datamodules/stl10_datamodule.py @@ -63,8 +63,8 @@ def __init__( shuffle: bool = False, pin_memory: bool = False, drop_last: bool = False, - *args, - **kwargs, + *args: Any, + **kwargs: Any, ): """ Args: @@ -99,7 +99,7 @@ def __init__( self.num_unlabeled_samples = 100000 - unlabeled_val_split @property - def num_classes(self): + def num_classes(self) -> int: return 10 def prepare_data(self): @@ -110,7 +110,7 @@ def prepare_data(self): STL10(self.data_dir, split='train', download=True, transform=transform_lib.ToTensor()) STL10(self.data_dir, split='test', download=True, transform=transform_lib.ToTensor()) - def train_dataloader(self): + def train_dataloader(self) -> DataLoader: """ Loads the 'unlabeled' split minus a portion set aside for validation via `unlabeled_val_split`. """ @@ -131,7 +131,7 @@ def train_dataloader(self): ) return loader - def train_dataloader_mixed(self): + def train_dataloader_mixed(self) -> DataLoader: """ Loads a portion of the 'unlabeled' training data and 'train' (labeled) data. both portions have a subset removed for validation via `unlabeled_val_split` and `train_val_split` @@ -169,7 +169,7 @@ def train_dataloader_mixed(self): ) return loader - def val_dataloader(self): + def val_dataloader(self) -> DataLoader: """ Loads a portion of the 'unlabeled' training data set aside for validation The val dataset = (unlabeled - train_val_split) @@ -196,7 +196,7 @@ def val_dataloader(self): ) return loader - def val_dataloader_mixed(self): + def val_dataloader_mixed(self) -> DataLoader: """ Loads a portion of the 'unlabeled' training data set aside for validation along with the portion of the 'train' dataset to be used for validation @@ -239,7 +239,7 @@ def val_dataloader_mixed(self): ) return loader - def test_dataloader(self): + def test_dataloader(self) -> DataLoader: """ Loads the test split of STL10 @@ -260,7 +260,7 @@ def test_dataloader(self): ) return loader - def train_dataloader_labeled(self): + def train_dataloader_labeled(self) -> DataLoader: transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms dataset = STL10(self.data_dir, split='train', download=False, transform=transforms) @@ -278,7 +278,7 @@ def train_dataloader_labeled(self): ) return loader - def val_dataloader_labeled(self): + def val_dataloader_labeled(self) -> DataLoader: transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms dataset = STL10(self.data_dir, split='train', @@ -299,7 +299,7 @@ def val_dataloader_labeled(self): ) return loader - def _default_transforms(self): + def _default_transforms(self) -> transform_lib.Compose: data_transforms = transform_lib.Compose([ transform_lib.ToTensor(), stl10_normalization() diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py index 2144f0f509..5b8c508904 100644 --- a/pl_bolts/datamodules/vision_datamodule.py +++ b/pl_bolts/datamodules/vision_datamodule.py @@ -6,6 +6,14 @@ from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader, Dataset, random_split +from pl_bolts.utils import _TORCHVISION_AVAILABLE +from pl_bolts.utils.warnings import warn_missing_pkg + +if _TORCHVISION_AVAILABLE: + from torchvision import transforms as transform_lib +else: + warn_missing_pkg('torchvision') # pragma: no-cover + class VisionDataModule(LightningDataModule): @@ -29,7 +37,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ) -> None: + ): """ Args: data_dir: Where to save/load the data @@ -56,14 +64,14 @@ def __init__( self.pin_memory = pin_memory self.drop_last = drop_last - def prepare_data(self) -> None: + def prepare_data(self): """ Saves files to data_dir """ self.dataset_cls(self.data_dir, train=True, download=True) self.dataset_cls(self.data_dir, train=False, download=True) - def setup(self, stage: Optional[str] = None) -> None: + def setup(self, stage: Optional[str] = None): """ Creates train, val, and test dataset """ @@ -115,7 +123,7 @@ def _get_splits(self, len_dataset: int) -> List[int]: return splits @abstractmethod - def default_transforms(self): + def default_transforms(self) -> transform_lib.Compose: """ Default transform for the dataset """ def train_dataloader(self) -> DataLoader: From cae3f461afe23ef67f8e11b80ef19cce72905e93 Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Fri, 18 Dec 2020 20:20:20 +0900 Subject: [PATCH 12/75] Fixing typing imports --- pl_bolts/datamodules/async_dataloader.py | 14 +++++++++++--- pl_bolts/datamodules/cityscapes_datamodule.py | 2 ++ pl_bolts/datamodules/imagenet_datamodule.py | 2 +- pl_bolts/datamodules/kitti_datamodule.py | 3 ++- pl_bolts/datamodules/sklearn_datamodule.py | 4 ++-- .../datamodules/ssl_imagenet_datamodule.py | 1 + pl_bolts/datamodules/stl10_datamodule.py | 2 +- pl_bolts/datamodules/vision_datamodule.py | 7 +------ .../datamodules/vocdetection_datamodule.py | 18 ++++++++++-------- 9 files changed, 31 insertions(+), 22 deletions(-) diff --git a/pl_bolts/datamodules/async_dataloader.py b/pl_bolts/datamodules/async_dataloader.py index 7ded9d9ef1..38a0b9bb58 100644 --- a/pl_bolts/datamodules/async_dataloader.py +++ b/pl_bolts/datamodules/async_dataloader.py @@ -1,10 +1,11 @@ import re from queue import Queue from threading import Thread +from typing import Any, Optional, Union import torch from torch._six import container_abcs, string_classes -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, Dataset class AsynchronousLoader(object): @@ -26,7 +27,14 @@ class AsynchronousLoader(object): constructing one here """ - def __init__(self, data, device=torch.device('cuda', 0), q_size=10, num_batches=None, **kwargs): + def __init__( + self, + data: Union[DataLoader, Dataset], + device: torch.device = torch.device('cuda', 0), + q_size: int = 10, + num_batches: Optional[int] = None, + **kwargs: Any + ): if isinstance(data, torch.utils.data.DataLoader): self.dataloader = data else: @@ -105,5 +113,5 @@ def __next__(self): self.idx += 1 return out - def __len__(self): + def __len__(self) -> int: return self.num_batches diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py index d27bfc3196..17812a0ac5 100644 --- a/pl_bolts/datamodules/cityscapes_datamodule.py +++ b/pl_bolts/datamodules/cityscapes_datamodule.py @@ -1,3 +1,5 @@ +from typing import Any + from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py index a31b637ba9..829c485aed 100644 --- a/pl_bolts/datamodules/imagenet_datamodule.py +++ b/pl_bolts/datamodules/imagenet_datamodule.py @@ -1,5 +1,5 @@ import os -from typing import Optional +from typing import Any, Optional from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py index a50528028b..9a82f0b7ec 100644 --- a/pl_bolts/datamodules/kitti_datamodule.py +++ b/pl_bolts/datamodules/kitti_datamodule.py @@ -1,4 +1,5 @@ import os +from typing import Any, Optional import torch from pytorch_lightning import LightningDataModule @@ -133,7 +134,7 @@ def test_dataloader(self) -> DataLoader: ) return loader - def _default_transforms(self) -> transform_lib.Compose: + def _default_transforms(self) -> transforms.Compose: kitti_transforms = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753], diff --git a/pl_bolts/datamodules/sklearn_datamodule.py b/pl_bolts/datamodules/sklearn_datamodule.py index dcfd559441..e80e1dfc9a 100644 --- a/pl_bolts/datamodules/sklearn_datamodule.py +++ b/pl_bolts/datamodules/sklearn_datamodule.py @@ -1,5 +1,5 @@ import math -from typing import Any +from typing import Any, Tuple import numpy as np import torch @@ -145,7 +145,7 @@ def __init__( x_val=None, y_val=None, x_test=None, y_test=None, val_split=0.2, test_split=0.1, - num_workers:int = 2, + num_workers: int = 2, random_state: int = 1234, shuffle: bool = True, batch_size: int = 16, diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py index 50315245af..1584583101 100644 --- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py +++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py @@ -1,4 +1,5 @@ import os +from typing import Any, Optional from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py index 3b29995a1b..8d46cfd7bf 100644 --- a/pl_bolts/datamodules/stl10_datamodule.py +++ b/pl_bolts/datamodules/stl10_datamodule.py @@ -1,5 +1,5 @@ import os -from typing import Optional +from typing import Any, Optional import torch from pytorch_lightning import LightningDataModule diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py index 5b8c508904..cdcefcb2eb 100644 --- a/pl_bolts/datamodules/vision_datamodule.py +++ b/pl_bolts/datamodules/vision_datamodule.py @@ -9,11 +9,6 @@ from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg -if _TORCHVISION_AVAILABLE: - from torchvision import transforms as transform_lib -else: - warn_missing_pkg('torchvision') # pragma: no-cover - class VisionDataModule(LightningDataModule): @@ -123,7 +118,7 @@ def _get_splits(self, len_dataset: int) -> List[int]: return splits @abstractmethod - def default_transforms(self) -> transform_lib.Compose: + def default_transforms(self): """ Default transform for the dataset """ def train_dataloader(self) -> DataLoader: diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py index a2087f9448..3f3767b4f8 100644 --- a/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/pl_bolts/datamodules/vocdetection_datamodule.py @@ -1,3 +1,5 @@ +from typing import Any, Dict + import torch from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader @@ -17,7 +19,7 @@ class Compose(object): Like `torchvision.transforms.compose` but works for (image, target) """ - def __init__(self, transforms): + def __init__(self, transforms: T.Compose): self.transforms = transforms def __call__(self, image, target): @@ -55,7 +57,7 @@ def _collate_fn(batch): ) -def _prepare_voc_instance(image, target): +def _prepare_voc_instance(image, target: Dict[str, Any]): """ Prepares VOC dataset into appropriate target for fasterrcnn @@ -113,8 +115,8 @@ def __init__( shuffle: bool = False, pin_memory: bool = False, drop_last: bool = False, - *args, - **kwargs, + *args: Any, + **kwargs: Any, ): if not _TORCHVISION_AVAILABLE: raise ModuleNotFoundError( # pragma: no-cover @@ -132,7 +134,7 @@ def __init__( self.drop_last = drop_last @property - def num_classes(self): + def num_classes(self) -> int: """ Return: 21 @@ -146,7 +148,7 @@ def prepare_data(self): VOCDetection(self.data_dir, year=self.year, image_set="train", download=True) VOCDetection(self.data_dir, year=self.year, image_set="val", download=True) - def train_dataloader(self, batch_size=1, transforms=None): + def train_dataloader(self, batch_size: int = 1, transforms=None) -> DataLoader: """ VOCDetection train set uses the `train` subset @@ -174,7 +176,7 @@ def train_dataloader(self, batch_size=1, transforms=None): ) return loader - def val_dataloader(self, batch_size=1, transforms=None): + def val_dataloader(self, batch_size: int = 1, transforms=None) -> DataLoader: """ VOCDetection val set uses the `val` subset @@ -201,7 +203,7 @@ def val_dataloader(self, batch_size=1, transforms=None): ) return loader - def _default_transforms(self): + def _default_transforms(self) -> T.Compose: if self.normalize: return ( lambda image, target: ( From 7af027e43bf96fe8d171c1a2c34da174acbbd7ae Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Sun, 20 Dec 2020 03:03:00 +0900 Subject: [PATCH 13/75] Removing torchvision.transforms from return typing --- pl_bolts/datamodules/binary_mnist_datamodule.py | 2 +- pl_bolts/datamodules/cifar10_datamodule.py | 2 +- pl_bolts/datamodules/cityscapes_datamodule.py | 4 ++-- pl_bolts/datamodules/fashion_mnist_datamodule.py | 2 +- pl_bolts/datamodules/imagenet_datamodule.py | 4 ++-- pl_bolts/datamodules/kitti_datamodule.py | 2 +- pl_bolts/datamodules/mnist_datamodule.py | 2 +- pl_bolts/datamodules/ssl_imagenet_datamodule.py | 2 +- pl_bolts/datamodules/stl10_datamodule.py | 2 +- pl_bolts/datamodules/vocdetection_datamodule.py | 2 +- 10 files changed, 12 insertions(+), 12 deletions(-) diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py index c713abe107..8dc02ec95e 100644 --- a/pl_bolts/datamodules/binary_mnist_datamodule.py +++ b/pl_bolts/datamodules/binary_mnist_datamodule.py @@ -98,7 +98,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self) -> transform_lib.Compose: + def default_transforms(self): if self.normalize: mnist_transforms = transform_lib.Compose( [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))] diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index 12aea1ec87..b208172ed0 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -112,7 +112,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self) -> transform_lib.Compose: + def default_transforms(self): if self.normalize: cf10_transforms = transform_lib.Compose([transform_lib.ToTensor(), cifar10_normalization()]) else: diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py index 17812a0ac5..dc8b866fba 100644 --- a/pl_bolts/datamodules/cityscapes_datamodule.py +++ b/pl_bolts/datamodules/cityscapes_datamodule.py @@ -192,7 +192,7 @@ def test_dataloader(self) -> DataLoader: ) return loader - def _default_transforms(self) -> transform_lib.Compose: + def _default_transforms(self): cityscapes_transforms = transform_lib.Compose([ transform_lib.ToTensor(), transform_lib.Normalize( @@ -202,7 +202,7 @@ def _default_transforms(self) -> transform_lib.Compose: ]) return cityscapes_transforms - def _default_target_transforms(self) -> transform_lib.Compose: + def _default_target_transforms(self): cityscapes_target_trasnforms = transform_lib.Compose([ transform_lib.ToTensor(), transform_lib.Lambda(lambda t: t.squeeze()) diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py index 9a73e6a637..a128ddfaab 100644 --- a/pl_bolts/datamodules/fashion_mnist_datamodule.py +++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py @@ -93,7 +93,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self) -> transform_lib.Compose: + def default_transforms(self): if self.normalize: mnist_transforms = transform_lib.Compose( [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))] diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py index 829c485aed..f8d8262108 100644 --- a/pl_bolts/datamodules/imagenet_datamodule.py +++ b/pl_bolts/datamodules/imagenet_datamodule.py @@ -206,7 +206,7 @@ def test_dataloader(self) -> DataLoader: ) return loader - def train_transform(self) -> transform_lib.Compose: + def train_transform(self): """ The standard imagenet transforms @@ -232,7 +232,7 @@ def train_transform(self) -> transform_lib.Compose: return preprocessing - def val_transform(self) -> transform_lib.Compose: + def val_transform(self): """ The standard imagenet transforms for validation diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py index 9a82f0b7ec..06778fdbfc 100644 --- a/pl_bolts/datamodules/kitti_datamodule.py +++ b/pl_bolts/datamodules/kitti_datamodule.py @@ -134,7 +134,7 @@ def test_dataloader(self) -> DataLoader: ) return loader - def _default_transforms(self) -> transforms.Compose: + def _default_transforms(self): kitti_transforms = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753], diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py index 76a0438a0c..c57fe8ca82 100644 --- a/pl_bolts/datamodules/mnist_datamodule.py +++ b/pl_bolts/datamodules/mnist_datamodule.py @@ -92,7 +92,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self) -> transform_lib.Compose: + def default_transforms(self): if self.normalize: mnist_transforms = transform_lib.Compose( [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))] diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py index 1584583101..96041fd4d9 100644 --- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py +++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py @@ -134,7 +134,7 @@ def test_dataloader(self, num_images_per_class: int, add_normalize: bool = False ) return loader - def _default_transforms(self) -> transform_lib.Compose: + def _default_transforms(self): mnist_transforms = transform_lib.Compose([ transform_lib.ToTensor(), imagenet_normalization() diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py index 8d46cfd7bf..5cd680a535 100644 --- a/pl_bolts/datamodules/stl10_datamodule.py +++ b/pl_bolts/datamodules/stl10_datamodule.py @@ -299,7 +299,7 @@ def val_dataloader_labeled(self) -> DataLoader: ) return loader - def _default_transforms(self) -> transform_lib.Compose: + def _default_transforms(self): data_transforms = transform_lib.Compose([ transform_lib.ToTensor(), stl10_normalization() diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py index 3f3767b4f8..e0768b42d5 100644 --- a/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/pl_bolts/datamodules/vocdetection_datamodule.py @@ -203,7 +203,7 @@ def val_dataloader(self, batch_size: int = 1, transforms=None) -> DataLoader: ) return loader - def _default_transforms(self) -> T.Compose: + def _default_transforms(self): if self.normalize: return ( lambda image, target: ( From 7bec6059691a3b025a37e6e5d37d8d8a218832bf Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Sun, 20 Dec 2020 03:09:03 +0900 Subject: [PATCH 14/75] Remove more torchvision.transforms typing --- pl_bolts/datamodules/vocdetection_datamodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py index e0768b42d5..3753e727ad 100644 --- a/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/pl_bolts/datamodules/vocdetection_datamodule.py @@ -19,7 +19,7 @@ class Compose(object): Like `torchvision.transforms.compose` but works for (image, target) """ - def __init__(self, transforms: T.Compose): + def __init__(self, transforms): self.transforms = transforms def __call__(self, image, target): From afbc918f9e53d93aeba5ea7ff0072b073e395846 Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Mon, 21 Dec 2020 10:13:11 +0900 Subject: [PATCH 15/75] Removing return typing --- pl_bolts/datamodules/cifar10_datamodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index b208172ed0..85ba4de6e7 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -71,7 +71,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ) -> None: + ): """ Args: data_dir: Where to save/load the data From 17ce335c0deb8b6ac2252d03f405a91eff9d3425 Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Mon, 21 Dec 2020 16:00:55 +0900 Subject: [PATCH 16/75] Add `None` for optional arguments --- pl_bolts/datamodules/cifar10_datamodule.py | 2 +- pl_bolts/datamodules/kitti_datamodule.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index 85ba4de6e7..534774684f 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -146,7 +146,7 @@ class TinyCIFAR10DataModule(CIFAR10DataModule): def __init__( self, - data_dir: Optional[str], + data_dir: Optional[str] = None, val_split: int = 50, num_workers: int = 16, num_samples: int = 100, diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py index 06778fdbfc..0067a1e53d 100644 --- a/pl_bolts/datamodules/kitti_datamodule.py +++ b/pl_bolts/datamodules/kitti_datamodule.py @@ -22,7 +22,7 @@ class KittiDataModule(LightningDataModule): def __init__( self, - data_dir: Optional[str], + data_dir: Optional[str] = None, val_split: float = 0.2, test_split: float = 0.1, num_workers: int = 16, From d09f98d708f4cb75d1d8476bed0fe201f3bea9b9 Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Mon, 21 Dec 2020 16:04:36 +0900 Subject: [PATCH 17/75] Remove unnecessary import --- pl_bolts/datamodules/vision_datamodule.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py index cdcefcb2eb..42252c4edf 100644 --- a/pl_bolts/datamodules/vision_datamodule.py +++ b/pl_bolts/datamodules/vision_datamodule.py @@ -6,9 +6,6 @@ from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader, Dataset, random_split -from pl_bolts.utils import _TORCHVISION_AVAILABLE -from pl_bolts.utils.warnings import warn_missing_pkg - class VisionDataModule(LightningDataModule): From b61fdc07f107fc74d1cc0432f29561cc494f7b65 Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Mon, 21 Dec 2020 16:51:36 +0900 Subject: [PATCH 18/75] Add `None` return type --- pl_bolts/datamodules/async_dataloader.py | 4 ++-- pl_bolts/datamodules/binary_mnist_datamodule.py | 2 +- pl_bolts/datamodules/cifar10_datamodule.py | 4 ++-- pl_bolts/datamodules/cityscapes_datamodule.py | 2 +- pl_bolts/datamodules/experience_source.py | 4 ++-- pl_bolts/datamodules/fashion_mnist_datamodule.py | 2 +- pl_bolts/datamodules/imagenet_datamodule.py | 6 +++--- pl_bolts/datamodules/kitti_datamodule.py | 2 +- pl_bolts/datamodules/mnist_datamodule.py | 2 +- pl_bolts/datamodules/sklearn_datamodule.py | 12 ++++++++---- pl_bolts/datamodules/ssl_imagenet_datamodule.py | 6 +++--- pl_bolts/datamodules/stl10_datamodule.py | 4 ++-- pl_bolts/datamodules/vision_datamodule.py | 6 +++--- pl_bolts/datamodules/vocdetection_datamodule.py | 6 +++--- 14 files changed, 33 insertions(+), 29 deletions(-) diff --git a/pl_bolts/datamodules/async_dataloader.py b/pl_bolts/datamodules/async_dataloader.py index 38a0b9bb58..224f34d5ee 100644 --- a/pl_bolts/datamodules/async_dataloader.py +++ b/pl_bolts/datamodules/async_dataloader.py @@ -34,7 +34,7 @@ def __init__( q_size: int = 10, num_batches: Optional[int] = None, **kwargs: Any - ): + ) -> None: if isinstance(data, torch.utils.data.DataLoader): self.dataloader = data else: @@ -57,7 +57,7 @@ def __init__( self.np_str_obj_array_pattern = re.compile(r'[SaUO]') - def load_loop(self): # The loop that will load into the queue in the background + def load_loop(self) -> None: # The loop that will load into the queue in the background for i, sample in enumerate(self.dataloader): self.queue.put(self.load_instance(sample)) if i == len(self): diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py index 8dc02ec95e..142b3d54ef 100644 --- a/pl_bolts/datamodules/binary_mnist_datamodule.py +++ b/pl_bolts/datamodules/binary_mnist_datamodule.py @@ -56,7 +56,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: Where to save/load the data diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index 534774684f..2cb894d749 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -71,7 +71,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: Where to save/load the data @@ -153,7 +153,7 @@ def __init__( labels: Optional[Sequence] = (1, 5, 8), *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: where to save/load the data diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py index dc8b866fba..61c1ae2bef 100644 --- a/pl_bolts/datamodules/cityscapes_datamodule.py +++ b/pl_bolts/datamodules/cityscapes_datamodule.py @@ -73,7 +73,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: where to load the data from path, i.e. where directory leftImg8bit and gtFine or gtCoarse diff --git a/pl_bolts/datamodules/experience_source.py b/pl_bolts/datamodules/experience_source.py index 5fe1332dfd..fac4e82f25 100644 --- a/pl_bolts/datamodules/experience_source.py +++ b/pl_bolts/datamodules/experience_source.py @@ -30,7 +30,7 @@ class ExperienceSourceDataset(IterableDataset): The logic for the experience source and how the batch is generated is defined the Lightning model itself """ - def __init__(self, generate_batch: Callable): + def __init__(self, generate_batch: Callable) -> None: self.generate_batch = generate_batch def __iter__(self) -> Iterable: @@ -243,7 +243,7 @@ def pop_rewards_steps(self): class DiscountedExperienceSource(ExperienceSource): """Outputs experiences with a discounted reward over N steps""" - def __init__(self, env: Env, agent, n_steps: int = 1, gamma: float = 0.99): + def __init__(self, env: Env, agent, n_steps: int = 1, gamma: float = 0.99) -> None: super().__init__(env, agent, (n_steps + 1)) self.gamma = gamma self.steps = n_steps diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py index a128ddfaab..833c4599a6 100644 --- a/pl_bolts/datamodules/fashion_mnist_datamodule.py +++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py @@ -57,7 +57,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: Where to save/load the data diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py index f8d8262108..6f06913f9f 100644 --- a/pl_bolts/datamodules/imagenet_datamodule.py +++ b/pl_bolts/datamodules/imagenet_datamodule.py @@ -60,7 +60,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: path to the imagenet dataset file @@ -103,14 +103,14 @@ def num_classes(self) -> int: """ return 1000 - def _verify_splits(self, data_dir: str, split: str): + def _verify_splits(self, data_dir: str, split: str) -> None: dirs = os.listdir(data_dir) if split not in dirs: raise FileNotFoundError(f'a {split} Imagenet split was not found in {data_dir},' f' make sure the folder contains a subfolder named {split}') - def prepare_data(self): + def prepare_data(self) -> None: """ This method already assumes you have imagenet2012 downloaded. It validates the data using the meta.bin. diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py index 0067a1e53d..3cf26dc762 100644 --- a/pl_bolts/datamodules/kitti_datamodule.py +++ b/pl_bolts/datamodules/kitti_datamodule.py @@ -33,7 +33,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Kitti train, validation and test dataloaders. diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py index c57fe8ca82..1dd5e927b6 100644 --- a/pl_bolts/datamodules/mnist_datamodule.py +++ b/pl_bolts/datamodules/mnist_datamodule.py @@ -56,7 +56,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: Where to save/load the data diff --git a/pl_bolts/datamodules/sklearn_datamodule.py b/pl_bolts/datamodules/sklearn_datamodule.py index e80e1dfc9a..d9477acc0b 100644 --- a/pl_bolts/datamodules/sklearn_datamodule.py +++ b/pl_bolts/datamodules/sklearn_datamodule.py @@ -28,7 +28,9 @@ class SklearnDataset(Dataset): >>> len(dataset) 506 """ - def __init__(self, X: np.ndarray, y: np.ndarray, X_transform: Any = None, y_transform: Any = None): + def __init__( + self, X: np.ndarray, y: np.ndarray, X_transform: Any = None, y_transform: Any = None + ) -> None: """ Args: X: Numpy ndarray @@ -75,7 +77,9 @@ class TensorDataset(Dataset): >>> len(dataset) 10 """ - def __init__(self, X: torch.Tensor, y: torch.Tensor, X_transform: Any = None, y_transform: Any = None): + def __init__( + self, X: torch.Tensor, y: torch.Tensor, X_transform: Any = None, y_transform: Any = None + ) -> None: """ Args: X: PyTorch tensor @@ -153,7 +157,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: super().__init__(*args, **kwargs) self.num_workers = num_workers @@ -201,7 +205,7 @@ def _init_datasets( y_val: np.ndarray, x_test: np.ndarray, y_test: np.ndarray - ): + ) -> None: self.train_dataset = SklearnDataset(X, y) self.val_dataset = SklearnDataset(x_val, y_val) self.test_dataset = SklearnDataset(x_test, y_test) diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py index 96041fd4d9..3dbda03527 100644 --- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py +++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py @@ -30,7 +30,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: super().__init__(*args, **kwargs) if not _TORCHVISION_AVAILABLE: @@ -50,14 +50,14 @@ def __init__( def num_classes(self) -> int: return 1000 - def _verify_splits(self, data_dir: str, split: str): + def _verify_splits(self, data_dir: str, split: str) -> None: dirs = os.listdir(data_dir) if split not in dirs: raise FileNotFoundError(f'a {split} Imagenet split was not found in {data_dir}, make sure the' f' folder contains a subfolder named {split}') - def prepare_data(self): + def prepare_data(self) -> None: # imagenet cannot be downloaded... must provide path to folder with the train/val splits self._verify_splits(self.data_dir, 'train') self._verify_splits(self.data_dir, 'val') diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py index 5cd680a535..79420be149 100644 --- a/pl_bolts/datamodules/stl10_datamodule.py +++ b/pl_bolts/datamodules/stl10_datamodule.py @@ -65,7 +65,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: where to save/load the data @@ -102,7 +102,7 @@ def __init__( def num_classes(self) -> int: return 10 - def prepare_data(self): + def prepare_data(self) -> None: """ Downloads the unlabeled, train and test split """ diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py index 42252c4edf..2144f0f509 100644 --- a/pl_bolts/datamodules/vision_datamodule.py +++ b/pl_bolts/datamodules/vision_datamodule.py @@ -29,7 +29,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: Where to save/load the data @@ -56,14 +56,14 @@ def __init__( self.pin_memory = pin_memory self.drop_last = drop_last - def prepare_data(self): + def prepare_data(self) -> None: """ Saves files to data_dir """ self.dataset_cls(self.data_dir, train=True, download=True) self.dataset_cls(self.data_dir, train=False, download=True) - def setup(self, stage: Optional[str] = None): + def setup(self, stage: Optional[str] = None) -> None: """ Creates train, val, and test dataset """ diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py index 3753e727ad..52d5065d97 100644 --- a/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/pl_bolts/datamodules/vocdetection_datamodule.py @@ -19,7 +19,7 @@ class Compose(object): Like `torchvision.transforms.compose` but works for (image, target) """ - def __init__(self, transforms): + def __init__(self, transforms) -> None: self.transforms = transforms def __call__(self, image, target): @@ -117,7 +117,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: if not _TORCHVISION_AVAILABLE: raise ModuleNotFoundError( # pragma: no-cover 'You want to use VOC dataset loaded from `torchvision` which is not installed yet.' @@ -141,7 +141,7 @@ def num_classes(self) -> int: """ return 21 - def prepare_data(self): + def prepare_data(self) -> None: """ Saves VOCDetection files to data_dir """ From f2f4305d9b67f18840be494d966bea2870ee0d4b Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Tue, 5 Jan 2021 20:30:43 +0900 Subject: [PATCH 19/75] Add type for torchvision transforms --- pl_bolts/datamodules/binary_mnist_datamodule.py | 4 +++- pl_bolts/datamodules/cifar10_datamodule.py | 4 +++- pl_bolts/datamodules/cityscapes_datamodule.py | 6 ++++-- pl_bolts/datamodules/fashion_mnist_datamodule.py | 4 +++- pl_bolts/datamodules/imagenet_datamodule.py | 6 ++++-- pl_bolts/datamodules/kitti_datamodule.py | 4 +++- pl_bolts/datamodules/mnist_datamodule.py | 4 +++- pl_bolts/datamodules/ssl_imagenet_datamodule.py | 4 +++- pl_bolts/datamodules/stl10_datamodule.py | 4 +++- pl_bolts/datamodules/vision_datamodule.py | 9 ++++++++- 10 files changed, 37 insertions(+), 12 deletions(-) diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py index 142b3d54ef..ad17360d08 100644 --- a/pl_bolts/datamodules/binary_mnist_datamodule.py +++ b/pl_bolts/datamodules/binary_mnist_datamodule.py @@ -7,8 +7,10 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib + from torchvision.transforms import Compose else: # pragma: no-cover warn_missing_pkg('torchvision') + Compose = object class BinaryMNISTDataModule(VisionDataModule): @@ -98,7 +100,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self): + def default_transforms(self) -> Compose: if self.normalize: mnist_transforms = transform_lib.Compose( [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))] diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index 2cb894d749..9dbf10b670 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -9,9 +9,11 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib from torchvision.datasets import CIFAR10 + from torchvision.transforms import Compose else: warn_missing_pkg('torchvision') # pragma: no-cover CIFAR10 = None + Compose = object class CIFAR10DataModule(VisionDataModule): @@ -112,7 +114,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self): + def default_transforms(self) -> Compose: if self.normalize: cf10_transforms = transform_lib.Compose([transform_lib.ToTensor(), cifar10_normalization()]) else: diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py index 61c1ae2bef..130e333976 100644 --- a/pl_bolts/datamodules/cityscapes_datamodule.py +++ b/pl_bolts/datamodules/cityscapes_datamodule.py @@ -9,8 +9,10 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib from torchvision.datasets import Cityscapes + from torchvision.transforms import Compose else: warn_missing_pkg('torchvision') # pragma: no-cover + Compose = object class CityscapesDataModule(LightningDataModule): @@ -192,7 +194,7 @@ def test_dataloader(self) -> DataLoader: ) return loader - def _default_transforms(self): + def _default_transforms(self) -> Compose: cityscapes_transforms = transform_lib.Compose([ transform_lib.ToTensor(), transform_lib.Normalize( @@ -202,7 +204,7 @@ def _default_transforms(self): ]) return cityscapes_transforms - def _default_target_transforms(self): + def _default_target_transforms(self) -> Compose: cityscapes_target_trasnforms = transform_lib.Compose([ transform_lib.ToTensor(), transform_lib.Lambda(lambda t: t.squeeze()) diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py index 833c4599a6..b37221bc74 100644 --- a/pl_bolts/datamodules/fashion_mnist_datamodule.py +++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py @@ -7,9 +7,11 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib from torchvision.datasets import FashionMNIST + from torchvision.transforms import Compose else: warn_missing_pkg('torchvision') # pragma: no-cover FashionMNIST = None + Compose = object class FashionMNISTDataModule(VisionDataModule): @@ -93,7 +95,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self): + def default_transforms(self) -> Compose: if self.normalize: mnist_transforms = transform_lib.Compose( [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))] diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py index 6f06913f9f..db2fc68c0b 100644 --- a/pl_bolts/datamodules/imagenet_datamodule.py +++ b/pl_bolts/datamodules/imagenet_datamodule.py @@ -11,8 +11,10 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib + from torchvision.transforms import Compose else: warn_missing_pkg('torchvision') # pragma: no-cover + Compose = object class ImagenetDataModule(LightningDataModule): @@ -206,7 +208,7 @@ def test_dataloader(self) -> DataLoader: ) return loader - def train_transform(self): + def train_transform(self) -> Compose: """ The standard imagenet transforms @@ -232,7 +234,7 @@ def train_transform(self): return preprocessing - def val_transform(self): + def val_transform(self) -> Compose: """ The standard imagenet transforms for validation diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py index 3cf26dc762..b63040f3bf 100644 --- a/pl_bolts/datamodules/kitti_datamodule.py +++ b/pl_bolts/datamodules/kitti_datamodule.py @@ -12,8 +12,10 @@ if _TORCHVISION_AVAILABLE: import torchvision.transforms as transforms + from torchvision.transforms import Compose else: # pragma: no cover warn_missing_pkg('torchvision') + Compose = object class KittiDataModule(LightningDataModule): @@ -134,7 +136,7 @@ def test_dataloader(self) -> DataLoader: ) return loader - def _default_transforms(self): + def _default_transforms(self) -> Compose: kitti_transforms = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753], diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py index 1dd5e927b6..711460023c 100644 --- a/pl_bolts/datamodules/mnist_datamodule.py +++ b/pl_bolts/datamodules/mnist_datamodule.py @@ -7,9 +7,11 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib from torchvision.datasets import MNIST + from torchvision.transforms import Compose else: warn_missing_pkg('torchvision') # pragma: no-cover MNIST = None + Compose = object class MNISTDataModule(VisionDataModule): @@ -92,7 +94,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self): + def default_transforms(self) -> Compose: if self.normalize: mnist_transforms = transform_lib.Compose( [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))] diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py index 3dbda03527..d575eb2d01 100644 --- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py +++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py @@ -11,8 +11,10 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib + from torchvision.transforms import Compose else: warn_missing_pkg('torchvision') # pragma: no-cover + Compose = object class SSLImagenetDataModule(LightningDataModule): # pragma: no cover @@ -134,7 +136,7 @@ def test_dataloader(self, num_images_per_class: int, add_normalize: bool = False ) return loader - def _default_transforms(self): + def _default_transforms(self) -> Compose: mnist_transforms = transform_lib.Compose([ transform_lib.ToTensor(), imagenet_normalization() diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py index 79420be149..f9ac77e140 100644 --- a/pl_bolts/datamodules/stl10_datamodule.py +++ b/pl_bolts/datamodules/stl10_datamodule.py @@ -13,8 +13,10 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib from torchvision.datasets import STL10 + from torchvision.transforms import Compose else: warn_missing_pkg('torchvision') # pragma: no-cover + Compose = object class STL10DataModule(LightningDataModule): # pragma: no cover @@ -299,7 +301,7 @@ def val_dataloader_labeled(self) -> DataLoader: ) return loader - def _default_transforms(self): + def _default_transforms(self) -> Compose: data_transforms = transform_lib.Compose([ transform_lib.ToTensor(), stl10_normalization() diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py index 2144f0f509..92e6723968 100644 --- a/pl_bolts/datamodules/vision_datamodule.py +++ b/pl_bolts/datamodules/vision_datamodule.py @@ -6,6 +6,13 @@ from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader, Dataset, random_split +from pl_bolts.utils import _TORCHVISION_AVAILABLE + +if _TORCHVISION_AVAILABLE: + from torchvision.transforms import Compose +else: + Compose = object + class VisionDataModule(LightningDataModule): @@ -115,7 +122,7 @@ def _get_splits(self, len_dataset: int) -> List[int]: return splits @abstractmethod - def default_transforms(self): + def default_transforms(self) -> Compose: """ Default transform for the dataset """ def train_dataloader(self) -> DataLoader: From cd09554cc8b4b2e5d75d0cf3db2c237e2d10e1d0 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <jirka.borovec@seznam.cz> Date: Tue, 5 Jan 2021 14:04:40 +0100 Subject: [PATCH 20/75] enable check --- setup.cfg | 3 --- 1 file changed, 3 deletions(-) diff --git a/setup.cfg b/setup.cfg index dcd35979f9..bda41d20f4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -61,9 +61,6 @@ show_error_codes = True disallow_untyped_defs = True ignore_missing_imports = True -[mypy-pl_bolts.datamodules.*] -ignore_errors = True - [mypy-pl_bolts.datasets.*] ignore_errors = True From 0fcd1862c8c1e8d0fcb33d215cc47d8d58dfb432 Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Fri, 18 Dec 2020 19:58:16 +0900 Subject: [PATCH 21/75] Adding types to datamodules --- .../datamodules/binary_mnist_datamodule.py | 4 +- pl_bolts/datamodules/cifar10_datamodule.py | 6 +-- pl_bolts/datamodules/cityscapes_datamodule.py | 16 ++++---- .../datamodules/fashion_mnist_datamodule.py | 4 +- pl_bolts/datamodules/imagenet_datamodule.py | 18 ++++----- pl_bolts/datamodules/kitti_datamodule.py | 14 +++---- pl_bolts/datamodules/mnist_datamodule.py | 4 +- pl_bolts/datamodules/sklearn_datamodule.py | 38 +++++++++++-------- .../datamodules/ssl_imagenet_datamodule.py | 22 +++++------ pl_bolts/datamodules/stl10_datamodule.py | 22 +++++------ pl_bolts/datamodules/vision_datamodule.py | 16 ++++++-- 11 files changed, 90 insertions(+), 74 deletions(-) diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py index 142b3d54ef..c713abe107 100644 --- a/pl_bolts/datamodules/binary_mnist_datamodule.py +++ b/pl_bolts/datamodules/binary_mnist_datamodule.py @@ -56,7 +56,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ) -> None: + ): """ Args: data_dir: Where to save/load the data @@ -98,7 +98,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self): + def default_transforms(self) -> transform_lib.Compose: if self.normalize: mnist_transforms = transform_lib.Compose( [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))] diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index afb2df8c9a..12aea1ec87 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -112,7 +112,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self): + def default_transforms(self) -> transform_lib.Compose: if self.normalize: cf10_transforms = transform_lib.Compose([transform_lib.ToTensor(), cifar10_normalization()]) else: @@ -146,14 +146,14 @@ class TinyCIFAR10DataModule(CIFAR10DataModule): def __init__( self, - data_dir: str, + data_dir: Optional[str], val_split: int = 50, num_workers: int = 16, num_samples: int = 100, labels: Optional[Sequence] = (1, 5, 8), *args: Any, **kwargs: Any, - ) -> None: + ): """ Args: data_dir: where to save/load the data diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py index b851617225..d27bfc3196 100644 --- a/pl_bolts/datamodules/cityscapes_datamodule.py +++ b/pl_bolts/datamodules/cityscapes_datamodule.py @@ -69,8 +69,8 @@ def __init__( shuffle: bool = False, pin_memory: bool = False, drop_last: bool = False, - *args, - **kwargs, + *args: Any, + **kwargs: Any, ): """ Args: @@ -109,14 +109,14 @@ def __init__( self.target_transforms = None @property - def num_classes(self): + def num_classes(self) -> int: """ Return: 30 """ return 30 - def train_dataloader(self): + def train_dataloader(self) -> DataLoader: """ Cityscapes train set """ @@ -141,7 +141,7 @@ def train_dataloader(self): ) return loader - def val_dataloader(self): + def val_dataloader(self) -> DataLoader: """ Cityscapes val set """ @@ -166,7 +166,7 @@ def val_dataloader(self): ) return loader - def test_dataloader(self): + def test_dataloader(self) -> DataLoader: """ Cityscapes test set """ @@ -190,7 +190,7 @@ def test_dataloader(self): ) return loader - def _default_transforms(self): + def _default_transforms(self) -> transform_lib.Compose: cityscapes_transforms = transform_lib.Compose([ transform_lib.ToTensor(), transform_lib.Normalize( @@ -200,7 +200,7 @@ def _default_transforms(self): ]) return cityscapes_transforms - def _default_target_transforms(self): + def _default_target_transforms(self) -> transform_lib.Compose: cityscapes_target_trasnforms = transform_lib.Compose([ transform_lib.ToTensor(), transform_lib.Lambda(lambda t: t.squeeze()) diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py index 833c4599a6..9a73e6a637 100644 --- a/pl_bolts/datamodules/fashion_mnist_datamodule.py +++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py @@ -57,7 +57,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ) -> None: + ): """ Args: data_dir: Where to save/load the data @@ -93,7 +93,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self): + def default_transforms(self) -> transform_lib.Compose: if self.normalize: mnist_transforms = transform_lib.Compose( [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))] diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py index 38546e29ee..a31b637ba9 100644 --- a/pl_bolts/datamodules/imagenet_datamodule.py +++ b/pl_bolts/datamodules/imagenet_datamodule.py @@ -58,8 +58,8 @@ def __init__( shuffle: bool = False, pin_memory: bool = False, drop_last: bool = False, - *args, - **kwargs, + *args: Any, + **kwargs: Any, ): """ Args: @@ -94,7 +94,7 @@ def __init__( self.num_samples = 1281167 - self.num_imgs_per_val_class * self.num_classes @property - def num_classes(self): + def num_classes(self) -> int: """ Return: @@ -103,7 +103,7 @@ def num_classes(self): """ return 1000 - def _verify_splits(self, data_dir, split): + def _verify_splits(self, data_dir: str, split: str): dirs = os.listdir(data_dir) if split not in dirs: @@ -138,7 +138,7 @@ def prepare_data(self): UnlabeledImagenet.generate_meta_bins(path) """) - def train_dataloader(self): + def train_dataloader(self) -> DataLoader: """ Uses the train split of imagenet2012 and puts away a portion of it for the validation split """ @@ -160,7 +160,7 @@ def train_dataloader(self): ) return loader - def val_dataloader(self): + def val_dataloader(self) -> DataLoader: """ Uses the part of the train split of imagenet2012 that was not used for training via `num_imgs_per_val_class` @@ -185,7 +185,7 @@ def val_dataloader(self): ) return loader - def test_dataloader(self): + def test_dataloader(self) -> DataLoader: """ Uses the validation split of imagenet2012 for testing """ @@ -206,7 +206,7 @@ def test_dataloader(self): ) return loader - def train_transform(self): + def train_transform(self) -> transform_lib.Compose: """ The standard imagenet transforms @@ -232,7 +232,7 @@ def train_transform(self): return preprocessing - def val_transform(self): + def val_transform(self) -> transform_lib.Compose: """ The standard imagenet transforms for validation diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py index 433e7fffed..a50528028b 100644 --- a/pl_bolts/datamodules/kitti_datamodule.py +++ b/pl_bolts/datamodules/kitti_datamodule.py @@ -21,7 +21,7 @@ class KittiDataModule(LightningDataModule): def __init__( self, - data_dir: str, + data_dir: Optional[str], val_split: float = 0.2, test_split: float = 0.1, num_workers: int = 16, @@ -30,8 +30,8 @@ def __init__( shuffle: bool = False, pin_memory: bool = False, drop_last: bool = False, - *args, - **kwargs, + *args: Any, + **kwargs: Any, ): """ Kitti train, validation and test dataloaders. @@ -100,7 +100,7 @@ def __init__( lengths=[train_len, val_len, test_len], generator=torch.Generator().manual_seed(self.seed)) - def train_dataloader(self): + def train_dataloader(self) -> DataLoader: loader = DataLoader( self.trainset, batch_size=self.batch_size, @@ -111,7 +111,7 @@ def train_dataloader(self): ) return loader - def val_dataloader(self): + def val_dataloader(self) -> DataLoader: loader = DataLoader( self.valset, batch_size=self.batch_size, @@ -122,7 +122,7 @@ def val_dataloader(self): ) return loader - def test_dataloader(self): + def test_dataloader(self) -> DataLoader: loader = DataLoader( self.testset, batch_size=self.batch_size, @@ -133,7 +133,7 @@ def test_dataloader(self): ) return loader - def _default_transforms(self): + def _default_transforms(self) -> transform_lib.Compose: kitti_transforms = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753], diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py index 1dd5e927b6..76a0438a0c 100644 --- a/pl_bolts/datamodules/mnist_datamodule.py +++ b/pl_bolts/datamodules/mnist_datamodule.py @@ -56,7 +56,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ) -> None: + ): """ Args: data_dir: Where to save/load the data @@ -92,7 +92,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self): + def default_transforms(self) -> transform_lib.Compose: if self.normalize: mnist_transforms = transform_lib.Compose( [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))] diff --git a/pl_bolts/datamodules/sklearn_datamodule.py b/pl_bolts/datamodules/sklearn_datamodule.py index ed262b10c8..dcfd559441 100644 --- a/pl_bolts/datamodules/sklearn_datamodule.py +++ b/pl_bolts/datamodules/sklearn_datamodule.py @@ -42,10 +42,10 @@ def __init__(self, X: np.ndarray, y: np.ndarray, X_transform: Any = None, y_tran self.X_transform = X_transform self.y_transform = y_transform - def __len__(self): + def __len__(self) -> int: return len(self.X) - def __getitem__(self, idx): + def __getitem__(self, idx) -> Tuple[np.ndarray, np.ndarray]: x = self.X[idx].astype(np.float32) y = self.Y[idx] @@ -89,10 +89,10 @@ def __init__(self, X: torch.Tensor, y: torch.Tensor, X_transform: Any = None, y_ self.X_transform = X_transform self.y_transform = y_transform - def __len__(self): + def __len__(self) -> int: return len(self.X) - def __getitem__(self, idx): + def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor]: x = self.X[idx].float() y = self.Y[idx] @@ -145,14 +145,14 @@ def __init__( x_val=None, y_val=None, x_test=None, y_test=None, val_split=0.2, test_split=0.1, - num_workers=2, - random_state=1234, - shuffle=True, + num_workers:int = 2, + random_state: int = 1234, + shuffle: bool = True, batch_size: int = 16, - pin_memory=False, - drop_last=False, - *args, - **kwargs, + pin_memory: bool = False, + drop_last: bool = False, + *args: Any, + **kwargs: Any, ): super().__init__(*args, **kwargs) @@ -193,12 +193,20 @@ def __init__( self._init_datasets(X, y, x_val, y_val, x_test, y_test) - def _init_datasets(self, X, y, x_val, y_val, x_test, y_test): + def _init_datasets( + self, + X: np.ndarray, + y: np.ndarray, + x_val: np.ndarray, + y_val: np.ndarray, + x_test: np.ndarray, + y_test: np.ndarray + ): self.train_dataset = SklearnDataset(X, y) self.val_dataset = SklearnDataset(x_val, y_val) self.test_dataset = SklearnDataset(x_test, y_test) - def train_dataloader(self): + def train_dataloader(self) -> DataLoader: loader = DataLoader( self.train_dataset, batch_size=self.batch_size, @@ -209,7 +217,7 @@ def train_dataloader(self): ) return loader - def val_dataloader(self): + def val_dataloader(self) -> DataLoader: loader = DataLoader( self.val_dataset, batch_size=self.batch_size, @@ -220,7 +228,7 @@ def val_dataloader(self): ) return loader - def test_dataloader(self): + def test_dataloader(self) -> DataLoader: loader = DataLoader( self.test_dataset, batch_size=self.batch_size, diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py index 03e459fd5e..50315245af 100644 --- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py +++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py @@ -20,15 +20,15 @@ class SSLImagenetDataModule(LightningDataModule): # pragma: no cover def __init__( self, - data_dir, - meta_dir=None, - num_workers=16, + data_dir: str, + meta_dir: Optional[str] = None, + num_workers: int = 16, batch_size: int = 32, shuffle: bool = False, pin_memory: bool = False, drop_last: bool = False, - *args, - **kwargs, + *args: Any, + **kwargs: Any, ): super().__init__(*args, **kwargs) @@ -46,10 +46,10 @@ def __init__( self.drop_last = drop_last @property - def num_classes(self): + def num_classes(self) -> int: return 1000 - def _verify_splits(self, data_dir, split): + def _verify_splits(self, data_dir: str, split: str): dirs = os.listdir(data_dir) if split not in dirs: @@ -79,7 +79,7 @@ def prepare_data(self): UnlabeledImagenet.generate_meta_bins(path) """) - def train_dataloader(self, num_images_per_class=-1, add_normalize=False): + def train_dataloader(self, num_images_per_class: int = -1, add_normalize: bool = False) -> DataLoader: transforms = self._default_transforms() if self.train_transforms is None else self.train_transforms dataset = UnlabeledImagenet(self.data_dir, @@ -97,7 +97,7 @@ def train_dataloader(self, num_images_per_class=-1, add_normalize=False): ) return loader - def val_dataloader(self, num_images_per_class=50, add_normalize=False): + def val_dataloader(self, num_images_per_class: int = 50, add_normalize: bool = False) -> DataLoader: transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms dataset = UnlabeledImagenet(self.data_dir, @@ -115,7 +115,7 @@ def val_dataloader(self, num_images_per_class=50, add_normalize=False): ) return loader - def test_dataloader(self, num_images_per_class, add_normalize=False): + def test_dataloader(self, num_images_per_class: int, add_normalize: bool = False) -> DataLoader: transforms = self._default_transforms() if self.test_transforms is None else self.test_transforms dataset = UnlabeledImagenet(self.data_dir, @@ -133,7 +133,7 @@ def test_dataloader(self, num_images_per_class, add_normalize=False): ) return loader - def _default_transforms(self): + def _default_transforms(self) -> transform_lib.Compose: mnist_transforms = transform_lib.Compose([ transform_lib.ToTensor(), imagenet_normalization() diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py index c666db9b9b..3b29995a1b 100644 --- a/pl_bolts/datamodules/stl10_datamodule.py +++ b/pl_bolts/datamodules/stl10_datamodule.py @@ -63,8 +63,8 @@ def __init__( shuffle: bool = False, pin_memory: bool = False, drop_last: bool = False, - *args, - **kwargs, + *args: Any, + **kwargs: Any, ): """ Args: @@ -99,7 +99,7 @@ def __init__( self.num_unlabeled_samples = 100000 - unlabeled_val_split @property - def num_classes(self): + def num_classes(self) -> int: return 10 def prepare_data(self): @@ -110,7 +110,7 @@ def prepare_data(self): STL10(self.data_dir, split='train', download=True, transform=transform_lib.ToTensor()) STL10(self.data_dir, split='test', download=True, transform=transform_lib.ToTensor()) - def train_dataloader(self): + def train_dataloader(self) -> DataLoader: """ Loads the 'unlabeled' split minus a portion set aside for validation via `unlabeled_val_split`. """ @@ -131,7 +131,7 @@ def train_dataloader(self): ) return loader - def train_dataloader_mixed(self): + def train_dataloader_mixed(self) -> DataLoader: """ Loads a portion of the 'unlabeled' training data and 'train' (labeled) data. both portions have a subset removed for validation via `unlabeled_val_split` and `train_val_split` @@ -169,7 +169,7 @@ def train_dataloader_mixed(self): ) return loader - def val_dataloader(self): + def val_dataloader(self) -> DataLoader: """ Loads a portion of the 'unlabeled' training data set aside for validation The val dataset = (unlabeled - train_val_split) @@ -196,7 +196,7 @@ def val_dataloader(self): ) return loader - def val_dataloader_mixed(self): + def val_dataloader_mixed(self) -> DataLoader: """ Loads a portion of the 'unlabeled' training data set aside for validation along with the portion of the 'train' dataset to be used for validation @@ -239,7 +239,7 @@ def val_dataloader_mixed(self): ) return loader - def test_dataloader(self): + def test_dataloader(self) -> DataLoader: """ Loads the test split of STL10 @@ -260,7 +260,7 @@ def test_dataloader(self): ) return loader - def train_dataloader_labeled(self): + def train_dataloader_labeled(self) -> DataLoader: transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms dataset = STL10(self.data_dir, split='train', download=False, transform=transforms) @@ -278,7 +278,7 @@ def train_dataloader_labeled(self): ) return loader - def val_dataloader_labeled(self): + def val_dataloader_labeled(self) -> DataLoader: transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms dataset = STL10(self.data_dir, split='train', @@ -299,7 +299,7 @@ def val_dataloader_labeled(self): ) return loader - def _default_transforms(self): + def _default_transforms(self) -> transform_lib.Compose: data_transforms = transform_lib.Compose([ transform_lib.ToTensor(), stl10_normalization() diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py index 2144f0f509..5b8c508904 100644 --- a/pl_bolts/datamodules/vision_datamodule.py +++ b/pl_bolts/datamodules/vision_datamodule.py @@ -6,6 +6,14 @@ from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader, Dataset, random_split +from pl_bolts.utils import _TORCHVISION_AVAILABLE +from pl_bolts.utils.warnings import warn_missing_pkg + +if _TORCHVISION_AVAILABLE: + from torchvision import transforms as transform_lib +else: + warn_missing_pkg('torchvision') # pragma: no-cover + class VisionDataModule(LightningDataModule): @@ -29,7 +37,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ) -> None: + ): """ Args: data_dir: Where to save/load the data @@ -56,14 +64,14 @@ def __init__( self.pin_memory = pin_memory self.drop_last = drop_last - def prepare_data(self) -> None: + def prepare_data(self): """ Saves files to data_dir """ self.dataset_cls(self.data_dir, train=True, download=True) self.dataset_cls(self.data_dir, train=False, download=True) - def setup(self, stage: Optional[str] = None) -> None: + def setup(self, stage: Optional[str] = None): """ Creates train, val, and test dataset """ @@ -115,7 +123,7 @@ def _get_splits(self, len_dataset: int) -> List[int]: return splits @abstractmethod - def default_transforms(self): + def default_transforms(self) -> transform_lib.Compose: """ Default transform for the dataset """ def train_dataloader(self) -> DataLoader: From a4306969a9e3fbdf517a2da7889107a599c922e6 Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Fri, 18 Dec 2020 20:20:20 +0900 Subject: [PATCH 22/75] Fixing typing imports --- pl_bolts/datamodules/async_dataloader.py | 14 +++++++++++--- pl_bolts/datamodules/cityscapes_datamodule.py | 2 ++ pl_bolts/datamodules/imagenet_datamodule.py | 2 +- pl_bolts/datamodules/kitti_datamodule.py | 3 ++- pl_bolts/datamodules/sklearn_datamodule.py | 4 ++-- .../datamodules/ssl_imagenet_datamodule.py | 1 + pl_bolts/datamodules/stl10_datamodule.py | 2 +- pl_bolts/datamodules/vision_datamodule.py | 7 +------ .../datamodules/vocdetection_datamodule.py | 18 ++++++++++-------- 9 files changed, 31 insertions(+), 22 deletions(-) diff --git a/pl_bolts/datamodules/async_dataloader.py b/pl_bolts/datamodules/async_dataloader.py index 7ded9d9ef1..38a0b9bb58 100644 --- a/pl_bolts/datamodules/async_dataloader.py +++ b/pl_bolts/datamodules/async_dataloader.py @@ -1,10 +1,11 @@ import re from queue import Queue from threading import Thread +from typing import Any, Optional, Union import torch from torch._six import container_abcs, string_classes -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, Dataset class AsynchronousLoader(object): @@ -26,7 +27,14 @@ class AsynchronousLoader(object): constructing one here """ - def __init__(self, data, device=torch.device('cuda', 0), q_size=10, num_batches=None, **kwargs): + def __init__( + self, + data: Union[DataLoader, Dataset], + device: torch.device = torch.device('cuda', 0), + q_size: int = 10, + num_batches: Optional[int] = None, + **kwargs: Any + ): if isinstance(data, torch.utils.data.DataLoader): self.dataloader = data else: @@ -105,5 +113,5 @@ def __next__(self): self.idx += 1 return out - def __len__(self): + def __len__(self) -> int: return self.num_batches diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py index d27bfc3196..17812a0ac5 100644 --- a/pl_bolts/datamodules/cityscapes_datamodule.py +++ b/pl_bolts/datamodules/cityscapes_datamodule.py @@ -1,3 +1,5 @@ +from typing import Any + from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py index a31b637ba9..829c485aed 100644 --- a/pl_bolts/datamodules/imagenet_datamodule.py +++ b/pl_bolts/datamodules/imagenet_datamodule.py @@ -1,5 +1,5 @@ import os -from typing import Optional +from typing import Any, Optional from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py index a50528028b..9a82f0b7ec 100644 --- a/pl_bolts/datamodules/kitti_datamodule.py +++ b/pl_bolts/datamodules/kitti_datamodule.py @@ -1,4 +1,5 @@ import os +from typing import Any, Optional import torch from pytorch_lightning import LightningDataModule @@ -133,7 +134,7 @@ def test_dataloader(self) -> DataLoader: ) return loader - def _default_transforms(self) -> transform_lib.Compose: + def _default_transforms(self) -> transforms.Compose: kitti_transforms = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753], diff --git a/pl_bolts/datamodules/sklearn_datamodule.py b/pl_bolts/datamodules/sklearn_datamodule.py index dcfd559441..e80e1dfc9a 100644 --- a/pl_bolts/datamodules/sklearn_datamodule.py +++ b/pl_bolts/datamodules/sklearn_datamodule.py @@ -1,5 +1,5 @@ import math -from typing import Any +from typing import Any, Tuple import numpy as np import torch @@ -145,7 +145,7 @@ def __init__( x_val=None, y_val=None, x_test=None, y_test=None, val_split=0.2, test_split=0.1, - num_workers:int = 2, + num_workers: int = 2, random_state: int = 1234, shuffle: bool = True, batch_size: int = 16, diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py index 50315245af..1584583101 100644 --- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py +++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py @@ -1,4 +1,5 @@ import os +from typing import Any, Optional from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py index 3b29995a1b..8d46cfd7bf 100644 --- a/pl_bolts/datamodules/stl10_datamodule.py +++ b/pl_bolts/datamodules/stl10_datamodule.py @@ -1,5 +1,5 @@ import os -from typing import Optional +from typing import Any, Optional import torch from pytorch_lightning import LightningDataModule diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py index 5b8c508904..cdcefcb2eb 100644 --- a/pl_bolts/datamodules/vision_datamodule.py +++ b/pl_bolts/datamodules/vision_datamodule.py @@ -9,11 +9,6 @@ from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg -if _TORCHVISION_AVAILABLE: - from torchvision import transforms as transform_lib -else: - warn_missing_pkg('torchvision') # pragma: no-cover - class VisionDataModule(LightningDataModule): @@ -123,7 +118,7 @@ def _get_splits(self, len_dataset: int) -> List[int]: return splits @abstractmethod - def default_transforms(self) -> transform_lib.Compose: + def default_transforms(self): """ Default transform for the dataset """ def train_dataloader(self) -> DataLoader: diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py index a2087f9448..3f3767b4f8 100644 --- a/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/pl_bolts/datamodules/vocdetection_datamodule.py @@ -1,3 +1,5 @@ +from typing import Any, Dict + import torch from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader @@ -17,7 +19,7 @@ class Compose(object): Like `torchvision.transforms.compose` but works for (image, target) """ - def __init__(self, transforms): + def __init__(self, transforms: T.Compose): self.transforms = transforms def __call__(self, image, target): @@ -55,7 +57,7 @@ def _collate_fn(batch): ) -def _prepare_voc_instance(image, target): +def _prepare_voc_instance(image, target: Dict[str, Any]): """ Prepares VOC dataset into appropriate target for fasterrcnn @@ -113,8 +115,8 @@ def __init__( shuffle: bool = False, pin_memory: bool = False, drop_last: bool = False, - *args, - **kwargs, + *args: Any, + **kwargs: Any, ): if not _TORCHVISION_AVAILABLE: raise ModuleNotFoundError( # pragma: no-cover @@ -132,7 +134,7 @@ def __init__( self.drop_last = drop_last @property - def num_classes(self): + def num_classes(self) -> int: """ Return: 21 @@ -146,7 +148,7 @@ def prepare_data(self): VOCDetection(self.data_dir, year=self.year, image_set="train", download=True) VOCDetection(self.data_dir, year=self.year, image_set="val", download=True) - def train_dataloader(self, batch_size=1, transforms=None): + def train_dataloader(self, batch_size: int = 1, transforms=None) -> DataLoader: """ VOCDetection train set uses the `train` subset @@ -174,7 +176,7 @@ def train_dataloader(self, batch_size=1, transforms=None): ) return loader - def val_dataloader(self, batch_size=1, transforms=None): + def val_dataloader(self, batch_size: int = 1, transforms=None) -> DataLoader: """ VOCDetection val set uses the `val` subset @@ -201,7 +203,7 @@ def val_dataloader(self, batch_size=1, transforms=None): ) return loader - def _default_transforms(self): + def _default_transforms(self) -> T.Compose: if self.normalize: return ( lambda image, target: ( From a84551e0c69853b31c3eb41c49b2483c3a012133 Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Sun, 20 Dec 2020 03:03:00 +0900 Subject: [PATCH 23/75] Removing torchvision.transforms from return typing --- pl_bolts/datamodules/binary_mnist_datamodule.py | 2 +- pl_bolts/datamodules/cifar10_datamodule.py | 2 +- pl_bolts/datamodules/cityscapes_datamodule.py | 4 ++-- pl_bolts/datamodules/fashion_mnist_datamodule.py | 2 +- pl_bolts/datamodules/imagenet_datamodule.py | 4 ++-- pl_bolts/datamodules/kitti_datamodule.py | 2 +- pl_bolts/datamodules/mnist_datamodule.py | 2 +- pl_bolts/datamodules/ssl_imagenet_datamodule.py | 2 +- pl_bolts/datamodules/stl10_datamodule.py | 2 +- pl_bolts/datamodules/vocdetection_datamodule.py | 2 +- 10 files changed, 12 insertions(+), 12 deletions(-) diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py index c713abe107..8dc02ec95e 100644 --- a/pl_bolts/datamodules/binary_mnist_datamodule.py +++ b/pl_bolts/datamodules/binary_mnist_datamodule.py @@ -98,7 +98,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self) -> transform_lib.Compose: + def default_transforms(self): if self.normalize: mnist_transforms = transform_lib.Compose( [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))] diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index 12aea1ec87..b208172ed0 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -112,7 +112,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self) -> transform_lib.Compose: + def default_transforms(self): if self.normalize: cf10_transforms = transform_lib.Compose([transform_lib.ToTensor(), cifar10_normalization()]) else: diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py index 17812a0ac5..dc8b866fba 100644 --- a/pl_bolts/datamodules/cityscapes_datamodule.py +++ b/pl_bolts/datamodules/cityscapes_datamodule.py @@ -192,7 +192,7 @@ def test_dataloader(self) -> DataLoader: ) return loader - def _default_transforms(self) -> transform_lib.Compose: + def _default_transforms(self): cityscapes_transforms = transform_lib.Compose([ transform_lib.ToTensor(), transform_lib.Normalize( @@ -202,7 +202,7 @@ def _default_transforms(self) -> transform_lib.Compose: ]) return cityscapes_transforms - def _default_target_transforms(self) -> transform_lib.Compose: + def _default_target_transforms(self): cityscapes_target_trasnforms = transform_lib.Compose([ transform_lib.ToTensor(), transform_lib.Lambda(lambda t: t.squeeze()) diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py index 9a73e6a637..a128ddfaab 100644 --- a/pl_bolts/datamodules/fashion_mnist_datamodule.py +++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py @@ -93,7 +93,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self) -> transform_lib.Compose: + def default_transforms(self): if self.normalize: mnist_transforms = transform_lib.Compose( [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))] diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py index 829c485aed..f8d8262108 100644 --- a/pl_bolts/datamodules/imagenet_datamodule.py +++ b/pl_bolts/datamodules/imagenet_datamodule.py @@ -206,7 +206,7 @@ def test_dataloader(self) -> DataLoader: ) return loader - def train_transform(self) -> transform_lib.Compose: + def train_transform(self): """ The standard imagenet transforms @@ -232,7 +232,7 @@ def train_transform(self) -> transform_lib.Compose: return preprocessing - def val_transform(self) -> transform_lib.Compose: + def val_transform(self): """ The standard imagenet transforms for validation diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py index 9a82f0b7ec..06778fdbfc 100644 --- a/pl_bolts/datamodules/kitti_datamodule.py +++ b/pl_bolts/datamodules/kitti_datamodule.py @@ -134,7 +134,7 @@ def test_dataloader(self) -> DataLoader: ) return loader - def _default_transforms(self) -> transforms.Compose: + def _default_transforms(self): kitti_transforms = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753], diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py index 76a0438a0c..c57fe8ca82 100644 --- a/pl_bolts/datamodules/mnist_datamodule.py +++ b/pl_bolts/datamodules/mnist_datamodule.py @@ -92,7 +92,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self) -> transform_lib.Compose: + def default_transforms(self): if self.normalize: mnist_transforms = transform_lib.Compose( [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))] diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py index 1584583101..96041fd4d9 100644 --- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py +++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py @@ -134,7 +134,7 @@ def test_dataloader(self, num_images_per_class: int, add_normalize: bool = False ) return loader - def _default_transforms(self) -> transform_lib.Compose: + def _default_transforms(self): mnist_transforms = transform_lib.Compose([ transform_lib.ToTensor(), imagenet_normalization() diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py index 8d46cfd7bf..5cd680a535 100644 --- a/pl_bolts/datamodules/stl10_datamodule.py +++ b/pl_bolts/datamodules/stl10_datamodule.py @@ -299,7 +299,7 @@ def val_dataloader_labeled(self) -> DataLoader: ) return loader - def _default_transforms(self) -> transform_lib.Compose: + def _default_transforms(self): data_transforms = transform_lib.Compose([ transform_lib.ToTensor(), stl10_normalization() diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py index 3f3767b4f8..e0768b42d5 100644 --- a/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/pl_bolts/datamodules/vocdetection_datamodule.py @@ -203,7 +203,7 @@ def val_dataloader(self, batch_size: int = 1, transforms=None) -> DataLoader: ) return loader - def _default_transforms(self) -> T.Compose: + def _default_transforms(self): if self.normalize: return ( lambda image, target: ( From 685162c7226a7cfe7bba6e948436dd8d40bc8c27 Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Sun, 20 Dec 2020 03:09:03 +0900 Subject: [PATCH 24/75] Remove more torchvision.transforms typing --- pl_bolts/datamodules/vocdetection_datamodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py index e0768b42d5..3753e727ad 100644 --- a/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/pl_bolts/datamodules/vocdetection_datamodule.py @@ -19,7 +19,7 @@ class Compose(object): Like `torchvision.transforms.compose` but works for (image, target) """ - def __init__(self, transforms: T.Compose): + def __init__(self, transforms): self.transforms = transforms def __call__(self, image, target): From 01408372c75cb8c22fbff7c4062509e66ed64335 Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Mon, 21 Dec 2020 10:13:11 +0900 Subject: [PATCH 25/75] Removing return typing --- pl_bolts/datamodules/cifar10_datamodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index b208172ed0..85ba4de6e7 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -71,7 +71,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ) -> None: + ): """ Args: data_dir: Where to save/load the data From a6b8d4af95c2710ad47e81d25b5f866dd79eca37 Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Mon, 21 Dec 2020 16:00:55 +0900 Subject: [PATCH 26/75] Add `None` for optional arguments --- pl_bolts/datamodules/cifar10_datamodule.py | 2 +- pl_bolts/datamodules/kitti_datamodule.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index 85ba4de6e7..534774684f 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -146,7 +146,7 @@ class TinyCIFAR10DataModule(CIFAR10DataModule): def __init__( self, - data_dir: Optional[str], + data_dir: Optional[str] = None, val_split: int = 50, num_workers: int = 16, num_samples: int = 100, diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py index 06778fdbfc..0067a1e53d 100644 --- a/pl_bolts/datamodules/kitti_datamodule.py +++ b/pl_bolts/datamodules/kitti_datamodule.py @@ -22,7 +22,7 @@ class KittiDataModule(LightningDataModule): def __init__( self, - data_dir: Optional[str], + data_dir: Optional[str] = None, val_split: float = 0.2, test_split: float = 0.1, num_workers: int = 16, From de35a5514fc39ccbe74564f72b245d676f4154ac Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Mon, 21 Dec 2020 16:04:36 +0900 Subject: [PATCH 27/75] Remove unnecessary import --- pl_bolts/datamodules/vision_datamodule.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py index cdcefcb2eb..42252c4edf 100644 --- a/pl_bolts/datamodules/vision_datamodule.py +++ b/pl_bolts/datamodules/vision_datamodule.py @@ -6,9 +6,6 @@ from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader, Dataset, random_split -from pl_bolts.utils import _TORCHVISION_AVAILABLE -from pl_bolts.utils.warnings import warn_missing_pkg - class VisionDataModule(LightningDataModule): From f521b793cd6d488531179e457b022c25e1ab9e8f Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Mon, 21 Dec 2020 16:51:36 +0900 Subject: [PATCH 28/75] Add `None` return type --- pl_bolts/datamodules/async_dataloader.py | 4 ++-- pl_bolts/datamodules/binary_mnist_datamodule.py | 2 +- pl_bolts/datamodules/cifar10_datamodule.py | 4 ++-- pl_bolts/datamodules/cityscapes_datamodule.py | 2 +- pl_bolts/datamodules/experience_source.py | 4 ++-- pl_bolts/datamodules/fashion_mnist_datamodule.py | 2 +- pl_bolts/datamodules/imagenet_datamodule.py | 6 +++--- pl_bolts/datamodules/kitti_datamodule.py | 2 +- pl_bolts/datamodules/mnist_datamodule.py | 2 +- pl_bolts/datamodules/sklearn_datamodule.py | 12 ++++++++---- pl_bolts/datamodules/ssl_imagenet_datamodule.py | 6 +++--- pl_bolts/datamodules/stl10_datamodule.py | 4 ++-- pl_bolts/datamodules/vision_datamodule.py | 6 +++--- pl_bolts/datamodules/vocdetection_datamodule.py | 6 +++--- 14 files changed, 33 insertions(+), 29 deletions(-) diff --git a/pl_bolts/datamodules/async_dataloader.py b/pl_bolts/datamodules/async_dataloader.py index 38a0b9bb58..224f34d5ee 100644 --- a/pl_bolts/datamodules/async_dataloader.py +++ b/pl_bolts/datamodules/async_dataloader.py @@ -34,7 +34,7 @@ def __init__( q_size: int = 10, num_batches: Optional[int] = None, **kwargs: Any - ): + ) -> None: if isinstance(data, torch.utils.data.DataLoader): self.dataloader = data else: @@ -57,7 +57,7 @@ def __init__( self.np_str_obj_array_pattern = re.compile(r'[SaUO]') - def load_loop(self): # The loop that will load into the queue in the background + def load_loop(self) -> None: # The loop that will load into the queue in the background for i, sample in enumerate(self.dataloader): self.queue.put(self.load_instance(sample)) if i == len(self): diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py index 8dc02ec95e..142b3d54ef 100644 --- a/pl_bolts/datamodules/binary_mnist_datamodule.py +++ b/pl_bolts/datamodules/binary_mnist_datamodule.py @@ -56,7 +56,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: Where to save/load the data diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index 534774684f..2cb894d749 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -71,7 +71,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: Where to save/load the data @@ -153,7 +153,7 @@ def __init__( labels: Optional[Sequence] = (1, 5, 8), *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: where to save/load the data diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py index dc8b866fba..61c1ae2bef 100644 --- a/pl_bolts/datamodules/cityscapes_datamodule.py +++ b/pl_bolts/datamodules/cityscapes_datamodule.py @@ -73,7 +73,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: where to load the data from path, i.e. where directory leftImg8bit and gtFine or gtCoarse diff --git a/pl_bolts/datamodules/experience_source.py b/pl_bolts/datamodules/experience_source.py index 5fe1332dfd..fac4e82f25 100644 --- a/pl_bolts/datamodules/experience_source.py +++ b/pl_bolts/datamodules/experience_source.py @@ -30,7 +30,7 @@ class ExperienceSourceDataset(IterableDataset): The logic for the experience source and how the batch is generated is defined the Lightning model itself """ - def __init__(self, generate_batch: Callable): + def __init__(self, generate_batch: Callable) -> None: self.generate_batch = generate_batch def __iter__(self) -> Iterable: @@ -243,7 +243,7 @@ def pop_rewards_steps(self): class DiscountedExperienceSource(ExperienceSource): """Outputs experiences with a discounted reward over N steps""" - def __init__(self, env: Env, agent, n_steps: int = 1, gamma: float = 0.99): + def __init__(self, env: Env, agent, n_steps: int = 1, gamma: float = 0.99) -> None: super().__init__(env, agent, (n_steps + 1)) self.gamma = gamma self.steps = n_steps diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py index a128ddfaab..833c4599a6 100644 --- a/pl_bolts/datamodules/fashion_mnist_datamodule.py +++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py @@ -57,7 +57,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: Where to save/load the data diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py index f8d8262108..6f06913f9f 100644 --- a/pl_bolts/datamodules/imagenet_datamodule.py +++ b/pl_bolts/datamodules/imagenet_datamodule.py @@ -60,7 +60,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: path to the imagenet dataset file @@ -103,14 +103,14 @@ def num_classes(self) -> int: """ return 1000 - def _verify_splits(self, data_dir: str, split: str): + def _verify_splits(self, data_dir: str, split: str) -> None: dirs = os.listdir(data_dir) if split not in dirs: raise FileNotFoundError(f'a {split} Imagenet split was not found in {data_dir},' f' make sure the folder contains a subfolder named {split}') - def prepare_data(self): + def prepare_data(self) -> None: """ This method already assumes you have imagenet2012 downloaded. It validates the data using the meta.bin. diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py index 0067a1e53d..3cf26dc762 100644 --- a/pl_bolts/datamodules/kitti_datamodule.py +++ b/pl_bolts/datamodules/kitti_datamodule.py @@ -33,7 +33,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Kitti train, validation and test dataloaders. diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py index c57fe8ca82..1dd5e927b6 100644 --- a/pl_bolts/datamodules/mnist_datamodule.py +++ b/pl_bolts/datamodules/mnist_datamodule.py @@ -56,7 +56,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: Where to save/load the data diff --git a/pl_bolts/datamodules/sklearn_datamodule.py b/pl_bolts/datamodules/sklearn_datamodule.py index e80e1dfc9a..d9477acc0b 100644 --- a/pl_bolts/datamodules/sklearn_datamodule.py +++ b/pl_bolts/datamodules/sklearn_datamodule.py @@ -28,7 +28,9 @@ class SklearnDataset(Dataset): >>> len(dataset) 506 """ - def __init__(self, X: np.ndarray, y: np.ndarray, X_transform: Any = None, y_transform: Any = None): + def __init__( + self, X: np.ndarray, y: np.ndarray, X_transform: Any = None, y_transform: Any = None + ) -> None: """ Args: X: Numpy ndarray @@ -75,7 +77,9 @@ class TensorDataset(Dataset): >>> len(dataset) 10 """ - def __init__(self, X: torch.Tensor, y: torch.Tensor, X_transform: Any = None, y_transform: Any = None): + def __init__( + self, X: torch.Tensor, y: torch.Tensor, X_transform: Any = None, y_transform: Any = None + ) -> None: """ Args: X: PyTorch tensor @@ -153,7 +157,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: super().__init__(*args, **kwargs) self.num_workers = num_workers @@ -201,7 +205,7 @@ def _init_datasets( y_val: np.ndarray, x_test: np.ndarray, y_test: np.ndarray - ): + ) -> None: self.train_dataset = SklearnDataset(X, y) self.val_dataset = SklearnDataset(x_val, y_val) self.test_dataset = SklearnDataset(x_test, y_test) diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py index 96041fd4d9..3dbda03527 100644 --- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py +++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py @@ -30,7 +30,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: super().__init__(*args, **kwargs) if not _TORCHVISION_AVAILABLE: @@ -50,14 +50,14 @@ def __init__( def num_classes(self) -> int: return 1000 - def _verify_splits(self, data_dir: str, split: str): + def _verify_splits(self, data_dir: str, split: str) -> None: dirs = os.listdir(data_dir) if split not in dirs: raise FileNotFoundError(f'a {split} Imagenet split was not found in {data_dir}, make sure the' f' folder contains a subfolder named {split}') - def prepare_data(self): + def prepare_data(self) -> None: # imagenet cannot be downloaded... must provide path to folder with the train/val splits self._verify_splits(self.data_dir, 'train') self._verify_splits(self.data_dir, 'val') diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py index 5cd680a535..79420be149 100644 --- a/pl_bolts/datamodules/stl10_datamodule.py +++ b/pl_bolts/datamodules/stl10_datamodule.py @@ -65,7 +65,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: where to save/load the data @@ -102,7 +102,7 @@ def __init__( def num_classes(self) -> int: return 10 - def prepare_data(self): + def prepare_data(self) -> None: """ Downloads the unlabeled, train and test split """ diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py index 42252c4edf..2144f0f509 100644 --- a/pl_bolts/datamodules/vision_datamodule.py +++ b/pl_bolts/datamodules/vision_datamodule.py @@ -29,7 +29,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: Where to save/load the data @@ -56,14 +56,14 @@ def __init__( self.pin_memory = pin_memory self.drop_last = drop_last - def prepare_data(self): + def prepare_data(self) -> None: """ Saves files to data_dir """ self.dataset_cls(self.data_dir, train=True, download=True) self.dataset_cls(self.data_dir, train=False, download=True) - def setup(self, stage: Optional[str] = None): + def setup(self, stage: Optional[str] = None) -> None: """ Creates train, val, and test dataset """ diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py index 3753e727ad..52d5065d97 100644 --- a/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/pl_bolts/datamodules/vocdetection_datamodule.py @@ -19,7 +19,7 @@ class Compose(object): Like `torchvision.transforms.compose` but works for (image, target) """ - def __init__(self, transforms): + def __init__(self, transforms) -> None: self.transforms = transforms def __call__(self, image, target): @@ -117,7 +117,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: if not _TORCHVISION_AVAILABLE: raise ModuleNotFoundError( # pragma: no-cover 'You want to use VOC dataset loaded from `torchvision` which is not installed yet.' @@ -141,7 +141,7 @@ def num_classes(self) -> int: """ return 21 - def prepare_data(self): + def prepare_data(self) -> None: """ Saves VOCDetection files to data_dir """ From fa0d271011399d5006541dc6b8fc9cc1efaf180c Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Tue, 5 Jan 2021 20:30:43 +0900 Subject: [PATCH 29/75] Add type for torchvision transforms --- pl_bolts/datamodules/binary_mnist_datamodule.py | 4 +++- pl_bolts/datamodules/cifar10_datamodule.py | 4 +++- pl_bolts/datamodules/cityscapes_datamodule.py | 6 ++++-- pl_bolts/datamodules/fashion_mnist_datamodule.py | 4 +++- pl_bolts/datamodules/imagenet_datamodule.py | 6 ++++-- pl_bolts/datamodules/kitti_datamodule.py | 4 +++- pl_bolts/datamodules/mnist_datamodule.py | 4 +++- pl_bolts/datamodules/ssl_imagenet_datamodule.py | 4 +++- pl_bolts/datamodules/stl10_datamodule.py | 4 +++- pl_bolts/datamodules/vision_datamodule.py | 9 ++++++++- 10 files changed, 37 insertions(+), 12 deletions(-) diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py index 142b3d54ef..ad17360d08 100644 --- a/pl_bolts/datamodules/binary_mnist_datamodule.py +++ b/pl_bolts/datamodules/binary_mnist_datamodule.py @@ -7,8 +7,10 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib + from torchvision.transforms import Compose else: # pragma: no-cover warn_missing_pkg('torchvision') + Compose = object class BinaryMNISTDataModule(VisionDataModule): @@ -98,7 +100,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self): + def default_transforms(self) -> Compose: if self.normalize: mnist_transforms = transform_lib.Compose( [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))] diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index 2cb894d749..9dbf10b670 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -9,9 +9,11 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib from torchvision.datasets import CIFAR10 + from torchvision.transforms import Compose else: warn_missing_pkg('torchvision') # pragma: no-cover CIFAR10 = None + Compose = object class CIFAR10DataModule(VisionDataModule): @@ -112,7 +114,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self): + def default_transforms(self) -> Compose: if self.normalize: cf10_transforms = transform_lib.Compose([transform_lib.ToTensor(), cifar10_normalization()]) else: diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py index 61c1ae2bef..130e333976 100644 --- a/pl_bolts/datamodules/cityscapes_datamodule.py +++ b/pl_bolts/datamodules/cityscapes_datamodule.py @@ -9,8 +9,10 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib from torchvision.datasets import Cityscapes + from torchvision.transforms import Compose else: warn_missing_pkg('torchvision') # pragma: no-cover + Compose = object class CityscapesDataModule(LightningDataModule): @@ -192,7 +194,7 @@ def test_dataloader(self) -> DataLoader: ) return loader - def _default_transforms(self): + def _default_transforms(self) -> Compose: cityscapes_transforms = transform_lib.Compose([ transform_lib.ToTensor(), transform_lib.Normalize( @@ -202,7 +204,7 @@ def _default_transforms(self): ]) return cityscapes_transforms - def _default_target_transforms(self): + def _default_target_transforms(self) -> Compose: cityscapes_target_trasnforms = transform_lib.Compose([ transform_lib.ToTensor(), transform_lib.Lambda(lambda t: t.squeeze()) diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py index 833c4599a6..b37221bc74 100644 --- a/pl_bolts/datamodules/fashion_mnist_datamodule.py +++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py @@ -7,9 +7,11 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib from torchvision.datasets import FashionMNIST + from torchvision.transforms import Compose else: warn_missing_pkg('torchvision') # pragma: no-cover FashionMNIST = None + Compose = object class FashionMNISTDataModule(VisionDataModule): @@ -93,7 +95,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self): + def default_transforms(self) -> Compose: if self.normalize: mnist_transforms = transform_lib.Compose( [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))] diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py index 6f06913f9f..db2fc68c0b 100644 --- a/pl_bolts/datamodules/imagenet_datamodule.py +++ b/pl_bolts/datamodules/imagenet_datamodule.py @@ -11,8 +11,10 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib + from torchvision.transforms import Compose else: warn_missing_pkg('torchvision') # pragma: no-cover + Compose = object class ImagenetDataModule(LightningDataModule): @@ -206,7 +208,7 @@ def test_dataloader(self) -> DataLoader: ) return loader - def train_transform(self): + def train_transform(self) -> Compose: """ The standard imagenet transforms @@ -232,7 +234,7 @@ def train_transform(self): return preprocessing - def val_transform(self): + def val_transform(self) -> Compose: """ The standard imagenet transforms for validation diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py index 3cf26dc762..b63040f3bf 100644 --- a/pl_bolts/datamodules/kitti_datamodule.py +++ b/pl_bolts/datamodules/kitti_datamodule.py @@ -12,8 +12,10 @@ if _TORCHVISION_AVAILABLE: import torchvision.transforms as transforms + from torchvision.transforms import Compose else: # pragma: no cover warn_missing_pkg('torchvision') + Compose = object class KittiDataModule(LightningDataModule): @@ -134,7 +136,7 @@ def test_dataloader(self) -> DataLoader: ) return loader - def _default_transforms(self): + def _default_transforms(self) -> Compose: kitti_transforms = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753], diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py index 1dd5e927b6..711460023c 100644 --- a/pl_bolts/datamodules/mnist_datamodule.py +++ b/pl_bolts/datamodules/mnist_datamodule.py @@ -7,9 +7,11 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib from torchvision.datasets import MNIST + from torchvision.transforms import Compose else: warn_missing_pkg('torchvision') # pragma: no-cover MNIST = None + Compose = object class MNISTDataModule(VisionDataModule): @@ -92,7 +94,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self): + def default_transforms(self) -> Compose: if self.normalize: mnist_transforms = transform_lib.Compose( [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))] diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py index 3dbda03527..d575eb2d01 100644 --- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py +++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py @@ -11,8 +11,10 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib + from torchvision.transforms import Compose else: warn_missing_pkg('torchvision') # pragma: no-cover + Compose = object class SSLImagenetDataModule(LightningDataModule): # pragma: no cover @@ -134,7 +136,7 @@ def test_dataloader(self, num_images_per_class: int, add_normalize: bool = False ) return loader - def _default_transforms(self): + def _default_transforms(self) -> Compose: mnist_transforms = transform_lib.Compose([ transform_lib.ToTensor(), imagenet_normalization() diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py index 79420be149..f9ac77e140 100644 --- a/pl_bolts/datamodules/stl10_datamodule.py +++ b/pl_bolts/datamodules/stl10_datamodule.py @@ -13,8 +13,10 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib from torchvision.datasets import STL10 + from torchvision.transforms import Compose else: warn_missing_pkg('torchvision') # pragma: no-cover + Compose = object class STL10DataModule(LightningDataModule): # pragma: no cover @@ -299,7 +301,7 @@ def val_dataloader_labeled(self) -> DataLoader: ) return loader - def _default_transforms(self): + def _default_transforms(self) -> Compose: data_transforms = transform_lib.Compose([ transform_lib.ToTensor(), stl10_normalization() diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py index 2144f0f509..92e6723968 100644 --- a/pl_bolts/datamodules/vision_datamodule.py +++ b/pl_bolts/datamodules/vision_datamodule.py @@ -6,6 +6,13 @@ from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader, Dataset, random_split +from pl_bolts.utils import _TORCHVISION_AVAILABLE + +if _TORCHVISION_AVAILABLE: + from torchvision.transforms import Compose +else: + Compose = object + class VisionDataModule(LightningDataModule): @@ -115,7 +122,7 @@ def _get_splits(self, len_dataset: int) -> List[int]: return splits @abstractmethod - def default_transforms(self): + def default_transforms(self) -> Compose: """ Default transform for the dataset """ def train_dataloader(self) -> DataLoader: From 0bc9f7b99646e73696bbe75e2311463a607231c5 Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Fri, 18 Dec 2020 19:58:16 +0900 Subject: [PATCH 30/75] Adding types to datamodules --- pl_bolts/datamodules/binary_mnist_datamodule.py | 2 +- pl_bolts/datamodules/cifar10_datamodule.py | 6 +++++- pl_bolts/datamodules/fashion_mnist_datamodule.py | 2 +- pl_bolts/datamodules/mnist_datamodule.py | 2 +- pl_bolts/datamodules/sklearn_datamodule.py | 2 +- pl_bolts/datamodules/vision_datamodule.py | 6 +++--- 6 files changed, 12 insertions(+), 8 deletions(-) diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py index ad17360d08..28e44ce72c 100644 --- a/pl_bolts/datamodules/binary_mnist_datamodule.py +++ b/pl_bolts/datamodules/binary_mnist_datamodule.py @@ -58,7 +58,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ) -> None: + ): """ Args: data_dir: Where to save/load the data diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index 9dbf10b670..c5e880bf1e 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -114,7 +114,11 @@ def num_classes(self) -> int: """ return 10 +<<<<<<< HEAD def default_transforms(self) -> Compose: +======= + def default_transforms(self) -> transform_lib.Compose: +>>>>>>> Adding types to datamodules if self.normalize: cf10_transforms = transform_lib.Compose([transform_lib.ToTensor(), cifar10_normalization()]) else: @@ -155,7 +159,7 @@ def __init__( labels: Optional[Sequence] = (1, 5, 8), *args: Any, **kwargs: Any, - ) -> None: + ): """ Args: data_dir: where to save/load the data diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py index b37221bc74..97cb8de0cb 100644 --- a/pl_bolts/datamodules/fashion_mnist_datamodule.py +++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py @@ -59,7 +59,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ) -> None: + ): """ Args: data_dir: Where to save/load the data diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py index 711460023c..2ab81f6422 100644 --- a/pl_bolts/datamodules/mnist_datamodule.py +++ b/pl_bolts/datamodules/mnist_datamodule.py @@ -58,7 +58,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ) -> None: + ): """ Args: data_dir: Where to save/load the data diff --git a/pl_bolts/datamodules/sklearn_datamodule.py b/pl_bolts/datamodules/sklearn_datamodule.py index d9477acc0b..5c05ea3fce 100644 --- a/pl_bolts/datamodules/sklearn_datamodule.py +++ b/pl_bolts/datamodules/sklearn_datamodule.py @@ -204,7 +204,7 @@ def _init_datasets( x_val: np.ndarray, y_val: np.ndarray, x_test: np.ndarray, - y_test: np.ndarray +<<<<<<< HEAD ) -> None: self.train_dataset = SklearnDataset(X, y) self.val_dataset = SklearnDataset(x_val, y_val) diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py index 92e6723968..5346713f5b 100644 --- a/pl_bolts/datamodules/vision_datamodule.py +++ b/pl_bolts/datamodules/vision_datamodule.py @@ -36,7 +36,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ) -> None: + ): """ Args: data_dir: Where to save/load the data @@ -63,14 +63,14 @@ def __init__( self.pin_memory = pin_memory self.drop_last = drop_last - def prepare_data(self) -> None: + def prepare_data(self): """ Saves files to data_dir """ self.dataset_cls(self.data_dir, train=True, download=True) self.dataset_cls(self.data_dir, train=False, download=True) - def setup(self, stage: Optional[str] = None) -> None: + def setup(self, stage: Optional[str] = None): """ Creates train, val, and test dataset """ From 05fcef2aebae6d90d64cc964e5c3a460d4f51a25 Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Fri, 18 Dec 2020 20:20:20 +0900 Subject: [PATCH 31/75] Fixing typing imports --- pl_bolts/datamodules/vocdetection_datamodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py index 52d5065d97..2ebcad6cc5 100644 --- a/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/pl_bolts/datamodules/vocdetection_datamodule.py @@ -203,7 +203,7 @@ def val_dataloader(self, batch_size: int = 1, transforms=None) -> DataLoader: ) return loader - def _default_transforms(self): + def _default_transforms(self) -> T.Compose: if self.normalize: return ( lambda image, target: ( From d92604b10545309db49a81b2a4cbb1aff66654e6 Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Sun, 20 Dec 2020 03:03:00 +0900 Subject: [PATCH 32/75] Removing torchvision.transforms from return typing --- pl_bolts/datamodules/cifar10_datamodule.py | 4 ---- pl_bolts/datamodules/vocdetection_datamodule.py | 2 +- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index c5e880bf1e..cbe0333050 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -114,11 +114,7 @@ def num_classes(self) -> int: """ return 10 -<<<<<<< HEAD def default_transforms(self) -> Compose: -======= - def default_transforms(self) -> transform_lib.Compose: ->>>>>>> Adding types to datamodules if self.normalize: cf10_transforms = transform_lib.Compose([transform_lib.ToTensor(), cifar10_normalization()]) else: diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py index 2ebcad6cc5..52d5065d97 100644 --- a/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/pl_bolts/datamodules/vocdetection_datamodule.py @@ -203,7 +203,7 @@ def val_dataloader(self, batch_size: int = 1, transforms=None) -> DataLoader: ) return loader - def _default_transforms(self) -> T.Compose: + def _default_transforms(self): if self.normalize: return ( lambda image, target: ( From a9c641b817f6276720aa6abf11698d086dc8f339 Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Mon, 21 Dec 2020 10:13:11 +0900 Subject: [PATCH 33/75] Removing return typing --- pl_bolts/datamodules/cifar10_datamodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index cbe0333050..cdca1b61a6 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -73,7 +73,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ) -> None: + ): """ Args: data_dir: Where to save/load the data From 314329c571d58ef2918b9da3776f5347bc4d7a92 Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Mon, 21 Dec 2020 16:51:36 +0900 Subject: [PATCH 34/75] Add `None` return type --- pl_bolts/datamodules/binary_mnist_datamodule.py | 2 +- pl_bolts/datamodules/cifar10_datamodule.py | 4 ++-- pl_bolts/datamodules/fashion_mnist_datamodule.py | 2 +- pl_bolts/datamodules/mnist_datamodule.py | 2 +- pl_bolts/datamodules/sklearn_datamodule.py | 2 +- pl_bolts/datamodules/vision_datamodule.py | 6 +++--- 6 files changed, 9 insertions(+), 9 deletions(-) diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py index 28e44ce72c..ad17360d08 100644 --- a/pl_bolts/datamodules/binary_mnist_datamodule.py +++ b/pl_bolts/datamodules/binary_mnist_datamodule.py @@ -58,7 +58,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: Where to save/load the data diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index cdca1b61a6..9dbf10b670 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -73,7 +73,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: Where to save/load the data @@ -155,7 +155,7 @@ def __init__( labels: Optional[Sequence] = (1, 5, 8), *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: where to save/load the data diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py index 97cb8de0cb..b37221bc74 100644 --- a/pl_bolts/datamodules/fashion_mnist_datamodule.py +++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py @@ -59,7 +59,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: Where to save/load the data diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py index 2ab81f6422..711460023c 100644 --- a/pl_bolts/datamodules/mnist_datamodule.py +++ b/pl_bolts/datamodules/mnist_datamodule.py @@ -58,7 +58,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: Where to save/load the data diff --git a/pl_bolts/datamodules/sklearn_datamodule.py b/pl_bolts/datamodules/sklearn_datamodule.py index 5c05ea3fce..d9477acc0b 100644 --- a/pl_bolts/datamodules/sklearn_datamodule.py +++ b/pl_bolts/datamodules/sklearn_datamodule.py @@ -204,7 +204,7 @@ def _init_datasets( x_val: np.ndarray, y_val: np.ndarray, x_test: np.ndarray, -<<<<<<< HEAD + y_test: np.ndarray ) -> None: self.train_dataset = SklearnDataset(X, y) self.val_dataset = SklearnDataset(x_val, y_val) diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py index 5346713f5b..92e6723968 100644 --- a/pl_bolts/datamodules/vision_datamodule.py +++ b/pl_bolts/datamodules/vision_datamodule.py @@ -36,7 +36,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: Where to save/load the data @@ -63,14 +63,14 @@ def __init__( self.pin_memory = pin_memory self.drop_last = drop_last - def prepare_data(self): + def prepare_data(self) -> None: """ Saves files to data_dir """ self.dataset_cls(self.data_dir, train=True, download=True) self.dataset_cls(self.data_dir, train=False, download=True) - def setup(self, stage: Optional[str] = None): + def setup(self, stage: Optional[str] = None) -> None: """ Creates train, val, and test dataset """ From 14ea6b7da5ea2a15f44760cbfe6a131235ce2497 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <jirka.borovec@seznam.cz> Date: Tue, 5 Jan 2021 14:04:40 +0100 Subject: [PATCH 35/75] enable check --- setup.cfg | 3 --- 1 file changed, 3 deletions(-) diff --git a/setup.cfg b/setup.cfg index dcd35979f9..bda41d20f4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -61,9 +61,6 @@ show_error_codes = True disallow_untyped_defs = True ignore_missing_imports = True -[mypy-pl_bolts.datamodules.*] -ignore_errors = True - [mypy-pl_bolts.datasets.*] ignore_errors = True From 8a7c6f128a33dab0230734c30b0da7bfdb39e390 Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Fri, 18 Dec 2020 19:58:16 +0900 Subject: [PATCH 36/75] Adding types to datamodules --- .../datamodules/binary_mnist_datamodule.py | 4 ++-- pl_bolts/datamodules/cifar10_datamodule.py | 6 ++--- pl_bolts/datamodules/cityscapes_datamodule.py | 12 +++++----- .../datamodules/fashion_mnist_datamodule.py | 4 ++-- pl_bolts/datamodules/imagenet_datamodule.py | 14 +++++------ pl_bolts/datamodules/kitti_datamodule.py | 8 +++---- pl_bolts/datamodules/mnist_datamodule.py | 4 ++-- pl_bolts/datamodules/sklearn_datamodule.py | 24 ++++++++++++------- .../datamodules/ssl_imagenet_datamodule.py | 10 ++++---- pl_bolts/datamodules/stl10_datamodule.py | 16 ++++++------- pl_bolts/datamodules/vision_datamodule.py | 16 +++++++++---- 11 files changed, 67 insertions(+), 51 deletions(-) diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py index 4dea946bf9..6e47040f90 100644 --- a/pl_bolts/datamodules/binary_mnist_datamodule.py +++ b/pl_bolts/datamodules/binary_mnist_datamodule.py @@ -56,7 +56,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ) -> None: + ): """ Args: data_dir: Where to save/load the data @@ -98,7 +98,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self): + def default_transforms(self) -> transform_lib.Compose: if self.normalize: mnist_transforms = transform_lib.Compose([ transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5, ), std=(0.5, )) diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index afb2df8c9a..12aea1ec87 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -112,7 +112,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self): + def default_transforms(self) -> transform_lib.Compose: if self.normalize: cf10_transforms = transform_lib.Compose([transform_lib.ToTensor(), cifar10_normalization()]) else: @@ -146,14 +146,14 @@ class TinyCIFAR10DataModule(CIFAR10DataModule): def __init__( self, - data_dir: str, + data_dir: Optional[str], val_split: int = 50, num_workers: int = 16, num_samples: int = 100, labels: Optional[Sequence] = (1, 5, 8), *args: Any, **kwargs: Any, - ) -> None: + ): """ Args: data_dir: where to save/load the data diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py index ba6acb947d..4e1bbe0699 100644 --- a/pl_bolts/datamodules/cityscapes_datamodule.py +++ b/pl_bolts/datamodules/cityscapes_datamodule.py @@ -109,14 +109,14 @@ def __init__( self.target_transforms = None @property - def num_classes(self): + def num_classes(self) -> int: """ Return: 30 """ return 30 - def train_dataloader(self): + def train_dataloader(self) -> DataLoader: """ Cityscapes train set """ @@ -143,7 +143,7 @@ def train_dataloader(self): ) return loader - def val_dataloader(self): + def val_dataloader(self) -> DataLoader: """ Cityscapes val set """ @@ -170,7 +170,7 @@ def val_dataloader(self): ) return loader - def test_dataloader(self): + def test_dataloader(self) -> DataLoader: """ Cityscapes test set """ @@ -196,7 +196,7 @@ def test_dataloader(self): ) return loader - def _default_transforms(self): + def _default_transforms(self) -> transform_lib.Compose: cityscapes_transforms = transform_lib.Compose([ transform_lib.ToTensor(), transform_lib.Normalize( @@ -205,7 +205,7 @@ def _default_transforms(self): ]) return cityscapes_transforms - def _default_target_transforms(self): + def _default_target_transforms(self) -> transform_lib.Compose: cityscapes_target_trasnforms = transform_lib.Compose([ transform_lib.ToTensor(), transform_lib.Lambda(lambda t: t.squeeze()) ]) diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py index 8455d1f315..9e4022e218 100644 --- a/pl_bolts/datamodules/fashion_mnist_datamodule.py +++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py @@ -57,7 +57,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ) -> None: + ): """ Args: data_dir: Where to save/load the data @@ -93,7 +93,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self): + def default_transforms(self) -> transform_lib.Compose: if self.normalize: mnist_transforms = transform_lib.Compose([ transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5, ), std=(0.5, )) diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py index 60b6c32578..4c3dc2c41f 100644 --- a/pl_bolts/datamodules/imagenet_datamodule.py +++ b/pl_bolts/datamodules/imagenet_datamodule.py @@ -94,7 +94,7 @@ def __init__( self.num_samples = 1281167 - self.num_imgs_per_val_class * self.num_classes @property - def num_classes(self): + def num_classes(self) -> int: """ Return: @@ -103,7 +103,7 @@ def num_classes(self): """ return 1000 - def _verify_splits(self, data_dir, split): + def _verify_splits(self, data_dir: str, split: str): dirs = os.listdir(data_dir) if split not in dirs: @@ -142,7 +142,7 @@ def prepare_data(self): """ ) - def train_dataloader(self): + def train_dataloader(self) -> DataLoader: """ Uses the train split of imagenet2012 and puts away a portion of it for the validation split """ @@ -166,7 +166,7 @@ def train_dataloader(self): ) return loader - def val_dataloader(self): + def val_dataloader(self) -> DataLoader: """ Uses the part of the train split of imagenet2012 that was not used for training via `num_imgs_per_val_class` @@ -193,7 +193,7 @@ def val_dataloader(self): ) return loader - def test_dataloader(self): + def test_dataloader(self) -> DataLoader: """ Uses the validation split of imagenet2012 for testing """ @@ -212,7 +212,7 @@ def test_dataloader(self): ) return loader - def train_transform(self): + def train_transform(self) -> transform_lib.Compose: """ The standard imagenet transforms @@ -238,7 +238,7 @@ def train_transform(self): return preprocessing - def val_transform(self): + def val_transform(self) -> transform_lib.Compose: """ The standard imagenet transforms for validation diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py index e2cb6fa828..dec13b4514 100644 --- a/pl_bolts/datamodules/kitti_datamodule.py +++ b/pl_bolts/datamodules/kitti_datamodule.py @@ -97,7 +97,7 @@ def __init__( kitti_dataset, lengths=[train_len, val_len, test_len], generator=torch.Generator().manual_seed(self.seed) ) - def train_dataloader(self): + def train_dataloader(self) -> DataLoader: loader = DataLoader( self.trainset, batch_size=self.batch_size, @@ -108,7 +108,7 @@ def train_dataloader(self): ) return loader - def val_dataloader(self): + def val_dataloader(self) -> DataLoader: loader = DataLoader( self.valset, batch_size=self.batch_size, @@ -119,7 +119,7 @@ def val_dataloader(self): ) return loader - def test_dataloader(self): + def test_dataloader(self) -> DataLoader: loader = DataLoader( self.testset, batch_size=self.batch_size, @@ -130,7 +130,7 @@ def test_dataloader(self): ) return loader - def _default_transforms(self): + def _default_transforms(self) -> transform_lib.Compose: kitti_transforms = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753], std=[0.32064945, 0.32098866, 0.32325324]) diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py index b700b23123..87d1d72418 100644 --- a/pl_bolts/datamodules/mnist_datamodule.py +++ b/pl_bolts/datamodules/mnist_datamodule.py @@ -56,7 +56,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ) -> None: + ): """ Args: data_dir: Where to save/load the data @@ -92,7 +92,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self): + def default_transforms(self) -> transform_lib.Compose: if self.normalize: mnist_transforms = transform_lib.Compose([ transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5, ), std=(0.5, )) diff --git a/pl_bolts/datamodules/sklearn_datamodule.py b/pl_bolts/datamodules/sklearn_datamodule.py index 00e333fd30..ef983a609f 100644 --- a/pl_bolts/datamodules/sklearn_datamodule.py +++ b/pl_bolts/datamodules/sklearn_datamodule.py @@ -43,10 +43,10 @@ def __init__(self, X: np.ndarray, y: np.ndarray, X_transform: Any = None, y_tran self.X_transform = X_transform self.y_transform = y_transform - def __len__(self): + def __len__(self) -> int: return len(self.X) - def __getitem__(self, idx): + def __getitem__(self, idx) -> Tuple[np.ndarray, np.ndarray]: x = self.X[idx].astype(np.float32) y = self.Y[idx] @@ -91,10 +91,10 @@ def __init__(self, X: torch.Tensor, y: torch.Tensor, X_transform: Any = None, y_ self.X_transform = X_transform self.y_transform = y_transform - def __len__(self): + def __len__(self) -> int: return len(self.X) - def __getitem__(self, idx): + def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor]: x = self.X[idx].float() y = self.Y[idx] @@ -200,12 +200,20 @@ def __init__( self._init_datasets(X, y, x_val, y_val, x_test, y_test) - def _init_datasets(self, X, y, x_val, y_val, x_test, y_test): + def _init_datasets( + self, + X: np.ndarray, + y: np.ndarray, + x_val: np.ndarray, + y_val: np.ndarray, + x_test: np.ndarray, + y_test: np.ndarray + ): self.train_dataset = SklearnDataset(X, y) self.val_dataset = SklearnDataset(x_val, y_val) self.test_dataset = SklearnDataset(x_test, y_test) - def train_dataloader(self): + def train_dataloader(self) -> DataLoader: loader = DataLoader( self.train_dataset, batch_size=self.batch_size, @@ -216,7 +224,7 @@ def train_dataloader(self): ) return loader - def val_dataloader(self): + def val_dataloader(self) -> DataLoader: loader = DataLoader( self.val_dataset, batch_size=self.batch_size, @@ -227,7 +235,7 @@ def val_dataloader(self): ) return loader - def test_dataloader(self): + def test_dataloader(self) -> DataLoader: loader = DataLoader( self.test_dataset, batch_size=self.batch_size, diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py index 7949d218e5..4ede93041a 100644 --- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py +++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py @@ -46,10 +46,10 @@ def __init__( self.drop_last = drop_last @property - def num_classes(self): + def num_classes(self) -> int: return 1000 - def _verify_splits(self, data_dir, split): + def _verify_splits(self, data_dir: str, split: str): dirs = os.listdir(data_dir) if split not in dirs: @@ -83,7 +83,7 @@ def prepare_data(self): """ ) - def train_dataloader(self, num_images_per_class=-1, add_normalize=False): + def train_dataloader(self, num_images_per_class: int = -1, add_normalize: bool = False) -> DataLoader: transforms = self._default_transforms() if self.train_transforms is None else self.train_transforms dataset = UnlabeledImagenet( @@ -103,7 +103,7 @@ def train_dataloader(self, num_images_per_class=-1, add_normalize=False): ) return loader - def val_dataloader(self, num_images_per_class=50, add_normalize=False): + def val_dataloader(self, num_images_per_class: int = 50, add_normalize: bool = False) -> DataLoader: transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms dataset = UnlabeledImagenet( @@ -123,7 +123,7 @@ def val_dataloader(self, num_images_per_class=50, add_normalize=False): ) return loader - def test_dataloader(self, num_images_per_class, add_normalize=False): + def test_dataloader(self, num_images_per_class: int, add_normalize: bool = False) -> DataLoader: transforms = self._default_transforms() if self.test_transforms is None else self.test_transforms dataset = UnlabeledImagenet( diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py index 90f3434aa1..30411a5ea2 100644 --- a/pl_bolts/datamodules/stl10_datamodule.py +++ b/pl_bolts/datamodules/stl10_datamodule.py @@ -99,7 +99,7 @@ def __init__( self.num_unlabeled_samples = 100000 - unlabeled_val_split @property - def num_classes(self): + def num_classes(self) -> int: return 10 def prepare_data(self): @@ -110,7 +110,7 @@ def prepare_data(self): STL10(self.data_dir, split='train', download=True, transform=transform_lib.ToTensor()) STL10(self.data_dir, split='test', download=True, transform=transform_lib.ToTensor()) - def train_dataloader(self): + def train_dataloader(self) -> DataLoader: """ Loads the 'unlabeled' split minus a portion set aside for validation via `unlabeled_val_split`. """ @@ -132,7 +132,7 @@ def train_dataloader(self): ) return loader - def train_dataloader_mixed(self): + def train_dataloader_mixed(self) -> DataLoader: """ Loads a portion of the 'unlabeled' training data and 'train' (labeled) data. both portions have a subset removed for validation via `unlabeled_val_split` and `train_val_split` @@ -169,7 +169,7 @@ def train_dataloader_mixed(self): ) return loader - def val_dataloader(self): + def val_dataloader(self) -> DataLoader: """ Loads a portion of the 'unlabeled' training data set aside for validation The val dataset = (unlabeled - train_val_split) @@ -197,7 +197,7 @@ def val_dataloader(self): ) return loader - def val_dataloader_mixed(self): + def val_dataloader_mixed(self) -> DataLoader: """ Loads a portion of the 'unlabeled' training data set aside for validation along with the portion of the 'train' dataset to be used for validation @@ -239,7 +239,7 @@ def val_dataloader_mixed(self): ) return loader - def test_dataloader(self): + def test_dataloader(self) -> DataLoader: """ Loads the test split of STL10 @@ -260,7 +260,7 @@ def test_dataloader(self): ) return loader - def train_dataloader_labeled(self): + def train_dataloader_labeled(self) -> DataLoader: transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms dataset = STL10(self.data_dir, split='train', download=False, transform=transforms) @@ -279,7 +279,7 @@ def train_dataloader_labeled(self): ) return loader - def val_dataloader_labeled(self): + def val_dataloader_labeled(self) -> DataLoader: transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms dataset = STL10(self.data_dir, split='train', download=False, transform=transforms) labeled_length = len(dataset) diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py index 5a6f4af4c2..7ab8f1cccb 100644 --- a/pl_bolts/datamodules/vision_datamodule.py +++ b/pl_bolts/datamodules/vision_datamodule.py @@ -6,6 +6,14 @@ from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader, Dataset, random_split +from pl_bolts.utils import _TORCHVISION_AVAILABLE +from pl_bolts.utils.warnings import warn_missing_pkg + +if _TORCHVISION_AVAILABLE: + from torchvision import transforms as transform_lib +else: + warn_missing_pkg('torchvision') # pragma: no-cover + class VisionDataModule(LightningDataModule): @@ -29,7 +37,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ) -> None: + ): """ Args: data_dir: Where to save/load the data @@ -56,14 +64,14 @@ def __init__( self.pin_memory = pin_memory self.drop_last = drop_last - def prepare_data(self) -> None: + def prepare_data(self): """ Saves files to data_dir """ self.dataset_cls(self.data_dir, train=True, download=True) self.dataset_cls(self.data_dir, train=False, download=True) - def setup(self, stage: Optional[str] = None) -> None: + def setup(self, stage: Optional[str] = None): """ Creates train, val, and test dataset """ @@ -113,7 +121,7 @@ def _get_splits(self, len_dataset: int) -> List[int]: return splits @abstractmethod - def default_transforms(self): + def default_transforms(self) -> transform_lib.Compose: """ Default transform for the dataset """ def train_dataloader(self) -> DataLoader: From 3b0ee3cf129f8a72e6b0cf41f41a6d1c9c19cba8 Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Fri, 18 Dec 2020 20:20:20 +0900 Subject: [PATCH 37/75] Fixing typing imports --- pl_bolts/datamodules/async_dataloader.py | 14 +++++++++++--- pl_bolts/datamodules/cityscapes_datamodule.py | 2 ++ pl_bolts/datamodules/imagenet_datamodule.py | 2 +- pl_bolts/datamodules/kitti_datamodule.py | 3 ++- pl_bolts/datamodules/sklearn_datamodule.py | 2 +- .../datamodules/ssl_imagenet_datamodule.py | 1 + pl_bolts/datamodules/stl10_datamodule.py | 2 +- pl_bolts/datamodules/vision_datamodule.py | 7 +------ .../datamodules/vocdetection_datamodule.py | 18 ++++++++++-------- 9 files changed, 30 insertions(+), 21 deletions(-) diff --git a/pl_bolts/datamodules/async_dataloader.py b/pl_bolts/datamodules/async_dataloader.py index 7ded9d9ef1..38a0b9bb58 100644 --- a/pl_bolts/datamodules/async_dataloader.py +++ b/pl_bolts/datamodules/async_dataloader.py @@ -1,10 +1,11 @@ import re from queue import Queue from threading import Thread +from typing import Any, Optional, Union import torch from torch._six import container_abcs, string_classes -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, Dataset class AsynchronousLoader(object): @@ -26,7 +27,14 @@ class AsynchronousLoader(object): constructing one here """ - def __init__(self, data, device=torch.device('cuda', 0), q_size=10, num_batches=None, **kwargs): + def __init__( + self, + data: Union[DataLoader, Dataset], + device: torch.device = torch.device('cuda', 0), + q_size: int = 10, + num_batches: Optional[int] = None, + **kwargs: Any + ): if isinstance(data, torch.utils.data.DataLoader): self.dataloader = data else: @@ -105,5 +113,5 @@ def __next__(self): self.idx += 1 return out - def __len__(self): + def __len__(self) -> int: return self.num_batches diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py index 4e1bbe0699..ce5d4d52ee 100644 --- a/pl_bolts/datamodules/cityscapes_datamodule.py +++ b/pl_bolts/datamodules/cityscapes_datamodule.py @@ -1,3 +1,5 @@ +from typing import Any + from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py index 4c3dc2c41f..df4094d353 100644 --- a/pl_bolts/datamodules/imagenet_datamodule.py +++ b/pl_bolts/datamodules/imagenet_datamodule.py @@ -1,5 +1,5 @@ import os -from typing import Optional +from typing import Any, Optional from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py index dec13b4514..856f54a39e 100644 --- a/pl_bolts/datamodules/kitti_datamodule.py +++ b/pl_bolts/datamodules/kitti_datamodule.py @@ -1,4 +1,5 @@ import os +from typing import Any, Optional import torch from pytorch_lightning import LightningDataModule @@ -130,7 +131,7 @@ def test_dataloader(self) -> DataLoader: ) return loader - def _default_transforms(self) -> transform_lib.Compose: + def _default_transforms(self) -> transforms.Compose: kitti_transforms = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753], std=[0.32064945, 0.32098866, 0.32325324]) diff --git a/pl_bolts/datamodules/sklearn_datamodule.py b/pl_bolts/datamodules/sklearn_datamodule.py index ef983a609f..a2ffca5ee8 100644 --- a/pl_bolts/datamodules/sklearn_datamodule.py +++ b/pl_bolts/datamodules/sklearn_datamodule.py @@ -1,5 +1,5 @@ import math -from typing import Any +from typing import Any, Tuple import numpy as np import torch diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py index 4ede93041a..354cb4f02b 100644 --- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py +++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py @@ -1,4 +1,5 @@ import os +from typing import Any, Optional from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py index 30411a5ea2..8f5ac120eb 100644 --- a/pl_bolts/datamodules/stl10_datamodule.py +++ b/pl_bolts/datamodules/stl10_datamodule.py @@ -1,5 +1,5 @@ import os -from typing import Optional +from typing import Any, Optional import torch from pytorch_lightning import LightningDataModule diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py index 7ab8f1cccb..faac1663da 100644 --- a/pl_bolts/datamodules/vision_datamodule.py +++ b/pl_bolts/datamodules/vision_datamodule.py @@ -9,11 +9,6 @@ from pl_bolts.utils import _TORCHVISION_AVAILABLE from pl_bolts.utils.warnings import warn_missing_pkg -if _TORCHVISION_AVAILABLE: - from torchvision import transforms as transform_lib -else: - warn_missing_pkg('torchvision') # pragma: no-cover - class VisionDataModule(LightningDataModule): @@ -121,7 +116,7 @@ def _get_splits(self, len_dataset: int) -> List[int]: return splits @abstractmethod - def default_transforms(self) -> transform_lib.Compose: + def default_transforms(self): """ Default transform for the dataset """ def train_dataloader(self) -> DataLoader: diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py index 6065dcf076..cb54d75d2e 100644 --- a/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/pl_bolts/datamodules/vocdetection_datamodule.py @@ -1,3 +1,5 @@ +from typing import Any, Dict + import torch from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader @@ -17,7 +19,7 @@ class Compose(object): Like `torchvision.transforms.compose` but works for (image, target) """ - def __init__(self, transforms): + def __init__(self, transforms: T.Compose): self.transforms = transforms def __call__(self, image, target): @@ -55,7 +57,7 @@ def _collate_fn(batch): ) -def _prepare_voc_instance(image, target): +def _prepare_voc_instance(image, target: Dict[str, Any]): """ Prepares VOC dataset into appropriate target for fasterrcnn @@ -113,8 +115,8 @@ def __init__( shuffle: bool = False, pin_memory: bool = False, drop_last: bool = False, - *args, - **kwargs, + *args: Any, + **kwargs: Any, ): if not _TORCHVISION_AVAILABLE: raise ModuleNotFoundError( # pragma: no-cover @@ -132,7 +134,7 @@ def __init__( self.drop_last = drop_last @property - def num_classes(self): + def num_classes(self) -> int: """ Return: 21 @@ -146,7 +148,7 @@ def prepare_data(self): VOCDetection(self.data_dir, year=self.year, image_set="train", download=True) VOCDetection(self.data_dir, year=self.year, image_set="val", download=True) - def train_dataloader(self, batch_size=1, transforms=None): + def train_dataloader(self, batch_size: int = 1, transforms=None) -> DataLoader: """ VOCDetection train set uses the `train` subset @@ -172,7 +174,7 @@ def train_dataloader(self, batch_size=1, transforms=None): ) return loader - def val_dataloader(self, batch_size=1, transforms=None): + def val_dataloader(self, batch_size: int = 1, transforms=None) -> DataLoader: """ VOCDetection val set uses the `val` subset @@ -197,7 +199,7 @@ def val_dataloader(self, batch_size=1, transforms=None): ) return loader - def _default_transforms(self): + def _default_transforms(self) -> T.Compose: if self.normalize: return ( lambda image, target: ( From 3d1c9a11c23b8eb4bd727b7e0efc4e3a3347bc01 Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Sun, 20 Dec 2020 03:03:00 +0900 Subject: [PATCH 38/75] Removing torchvision.transforms from return typing --- pl_bolts/datamodules/binary_mnist_datamodule.py | 2 +- pl_bolts/datamodules/cifar10_datamodule.py | 2 +- pl_bolts/datamodules/cityscapes_datamodule.py | 4 ++-- pl_bolts/datamodules/fashion_mnist_datamodule.py | 2 +- pl_bolts/datamodules/imagenet_datamodule.py | 4 ++-- pl_bolts/datamodules/kitti_datamodule.py | 2 +- pl_bolts/datamodules/mnist_datamodule.py | 2 +- pl_bolts/datamodules/vocdetection_datamodule.py | 2 +- 8 files changed, 10 insertions(+), 10 deletions(-) diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py index 6e47040f90..85de4f0ef6 100644 --- a/pl_bolts/datamodules/binary_mnist_datamodule.py +++ b/pl_bolts/datamodules/binary_mnist_datamodule.py @@ -98,7 +98,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self) -> transform_lib.Compose: + def default_transforms(self): if self.normalize: mnist_transforms = transform_lib.Compose([ transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5, ), std=(0.5, )) diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index 12aea1ec87..b208172ed0 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -112,7 +112,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self) -> transform_lib.Compose: + def default_transforms(self): if self.normalize: cf10_transforms = transform_lib.Compose([transform_lib.ToTensor(), cifar10_normalization()]) else: diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py index ce5d4d52ee..a2d2aa7950 100644 --- a/pl_bolts/datamodules/cityscapes_datamodule.py +++ b/pl_bolts/datamodules/cityscapes_datamodule.py @@ -198,7 +198,7 @@ def test_dataloader(self) -> DataLoader: ) return loader - def _default_transforms(self) -> transform_lib.Compose: + def _default_transforms(self): cityscapes_transforms = transform_lib.Compose([ transform_lib.ToTensor(), transform_lib.Normalize( @@ -207,7 +207,7 @@ def _default_transforms(self) -> transform_lib.Compose: ]) return cityscapes_transforms - def _default_target_transforms(self) -> transform_lib.Compose: + def _default_target_transforms(self): cityscapes_target_trasnforms = transform_lib.Compose([ transform_lib.ToTensor(), transform_lib.Lambda(lambda t: t.squeeze()) ]) diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py index 9e4022e218..c8fd7232f8 100644 --- a/pl_bolts/datamodules/fashion_mnist_datamodule.py +++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py @@ -93,7 +93,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self) -> transform_lib.Compose: + def default_transforms(self): if self.normalize: mnist_transforms = transform_lib.Compose([ transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5, ), std=(0.5, )) diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py index df4094d353..80b6e6976a 100644 --- a/pl_bolts/datamodules/imagenet_datamodule.py +++ b/pl_bolts/datamodules/imagenet_datamodule.py @@ -212,7 +212,7 @@ def test_dataloader(self) -> DataLoader: ) return loader - def train_transform(self) -> transform_lib.Compose: + def train_transform(self): """ The standard imagenet transforms @@ -238,7 +238,7 @@ def train_transform(self) -> transform_lib.Compose: return preprocessing - def val_transform(self) -> transform_lib.Compose: + def val_transform(self): """ The standard imagenet transforms for validation diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py index 856f54a39e..6a052bdf56 100644 --- a/pl_bolts/datamodules/kitti_datamodule.py +++ b/pl_bolts/datamodules/kitti_datamodule.py @@ -131,7 +131,7 @@ def test_dataloader(self) -> DataLoader: ) return loader - def _default_transforms(self) -> transforms.Compose: + def _default_transforms(self): kitti_transforms = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753], std=[0.32064945, 0.32098866, 0.32325324]) diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py index 87d1d72418..5c8388facc 100644 --- a/pl_bolts/datamodules/mnist_datamodule.py +++ b/pl_bolts/datamodules/mnist_datamodule.py @@ -92,7 +92,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self) -> transform_lib.Compose: + def default_transforms(self): if self.normalize: mnist_transforms = transform_lib.Compose([ transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5, ), std=(0.5, )) diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py index cb54d75d2e..4011b6226a 100644 --- a/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/pl_bolts/datamodules/vocdetection_datamodule.py @@ -199,7 +199,7 @@ def val_dataloader(self, batch_size: int = 1, transforms=None) -> DataLoader: ) return loader - def _default_transforms(self) -> T.Compose: + def _default_transforms(self): if self.normalize: return ( lambda image, target: ( From 6cac90995387107381f90b269b336016441a2730 Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Sun, 20 Dec 2020 03:09:03 +0900 Subject: [PATCH 39/75] Remove more torchvision.transforms typing --- pl_bolts/datamodules/vocdetection_datamodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py index 4011b6226a..1bd6706650 100644 --- a/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/pl_bolts/datamodules/vocdetection_datamodule.py @@ -19,7 +19,7 @@ class Compose(object): Like `torchvision.transforms.compose` but works for (image, target) """ - def __init__(self, transforms: T.Compose): + def __init__(self, transforms): self.transforms = transforms def __call__(self, image, target): From c1ea0fb2bc34f0c5b2a083e848533b978a268408 Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Mon, 21 Dec 2020 10:13:11 +0900 Subject: [PATCH 40/75] Removing return typing --- pl_bolts/datamodules/cifar10_datamodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index b208172ed0..85ba4de6e7 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -71,7 +71,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ) -> None: + ): """ Args: data_dir: Where to save/load the data From c04caabd7b074a471dfb5608d3ddd51558482769 Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Mon, 21 Dec 2020 16:00:55 +0900 Subject: [PATCH 41/75] Add `None` for optional arguments --- pl_bolts/datamodules/cifar10_datamodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index 85ba4de6e7..534774684f 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -146,7 +146,7 @@ class TinyCIFAR10DataModule(CIFAR10DataModule): def __init__( self, - data_dir: Optional[str], + data_dir: Optional[str] = None, val_split: int = 50, num_workers: int = 16, num_samples: int = 100, From 5b6bf64b319b08a2c0651958226027194985d164 Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Mon, 21 Dec 2020 16:04:36 +0900 Subject: [PATCH 42/75] Remove unnecessary import --- pl_bolts/datamodules/vision_datamodule.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py index faac1663da..15648467e8 100644 --- a/pl_bolts/datamodules/vision_datamodule.py +++ b/pl_bolts/datamodules/vision_datamodule.py @@ -6,9 +6,6 @@ from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader, Dataset, random_split -from pl_bolts.utils import _TORCHVISION_AVAILABLE -from pl_bolts.utils.warnings import warn_missing_pkg - class VisionDataModule(LightningDataModule): From 7ce736dfdb1fe0d5167e6f62234878d922f13251 Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Mon, 21 Dec 2020 16:51:36 +0900 Subject: [PATCH 43/75] Add `None` return type --- pl_bolts/datamodules/async_dataloader.py | 4 ++-- pl_bolts/datamodules/binary_mnist_datamodule.py | 2 +- pl_bolts/datamodules/cifar10_datamodule.py | 4 ++-- pl_bolts/datamodules/cityscapes_datamodule.py | 2 +- pl_bolts/datamodules/experience_source.py | 4 ++-- pl_bolts/datamodules/fashion_mnist_datamodule.py | 2 +- pl_bolts/datamodules/imagenet_datamodule.py | 6 +++--- pl_bolts/datamodules/kitti_datamodule.py | 2 +- pl_bolts/datamodules/mnist_datamodule.py | 2 +- pl_bolts/datamodules/sklearn_datamodule.py | 14 ++++++++------ pl_bolts/datamodules/ssl_imagenet_datamodule.py | 6 +++--- pl_bolts/datamodules/stl10_datamodule.py | 4 ++-- pl_bolts/datamodules/vision_datamodule.py | 6 +++--- pl_bolts/datamodules/vocdetection_datamodule.py | 6 +++--- 14 files changed, 33 insertions(+), 31 deletions(-) diff --git a/pl_bolts/datamodules/async_dataloader.py b/pl_bolts/datamodules/async_dataloader.py index 38a0b9bb58..224f34d5ee 100644 --- a/pl_bolts/datamodules/async_dataloader.py +++ b/pl_bolts/datamodules/async_dataloader.py @@ -34,7 +34,7 @@ def __init__( q_size: int = 10, num_batches: Optional[int] = None, **kwargs: Any - ): + ) -> None: if isinstance(data, torch.utils.data.DataLoader): self.dataloader = data else: @@ -57,7 +57,7 @@ def __init__( self.np_str_obj_array_pattern = re.compile(r'[SaUO]') - def load_loop(self): # The loop that will load into the queue in the background + def load_loop(self) -> None: # The loop that will load into the queue in the background for i, sample in enumerate(self.dataloader): self.queue.put(self.load_instance(sample)) if i == len(self): diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py index 85de4f0ef6..4dea946bf9 100644 --- a/pl_bolts/datamodules/binary_mnist_datamodule.py +++ b/pl_bolts/datamodules/binary_mnist_datamodule.py @@ -56,7 +56,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: Where to save/load the data diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index 534774684f..2cb894d749 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -71,7 +71,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: Where to save/load the data @@ -153,7 +153,7 @@ def __init__( labels: Optional[Sequence] = (1, 5, 8), *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: where to save/load the data diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py index a2d2aa7950..7816236a27 100644 --- a/pl_bolts/datamodules/cityscapes_datamodule.py +++ b/pl_bolts/datamodules/cityscapes_datamodule.py @@ -73,7 +73,7 @@ def __init__( drop_last: bool = False, *args, **kwargs, - ): + ) -> None: """ Args: data_dir: where to load the data from path, i.e. where directory leftImg8bit and gtFine or gtCoarse diff --git a/pl_bolts/datamodules/experience_source.py b/pl_bolts/datamodules/experience_source.py index 6c85f76fd2..1bc7b0f8a8 100644 --- a/pl_bolts/datamodules/experience_source.py +++ b/pl_bolts/datamodules/experience_source.py @@ -27,7 +27,7 @@ class ExperienceSourceDataset(IterableDataset): The logic for the experience source and how the batch is generated is defined the Lightning model itself """ - def __init__(self, generate_batch: Callable): + def __init__(self, generate_batch: Callable) -> None: self.generate_batch = generate_batch def __iter__(self) -> Iterable: @@ -240,7 +240,7 @@ def pop_rewards_steps(self): class DiscountedExperienceSource(ExperienceSource): """Outputs experiences with a discounted reward over N steps""" - def __init__(self, env: Env, agent, n_steps: int = 1, gamma: float = 0.99): + def __init__(self, env: Env, agent, n_steps: int = 1, gamma: float = 0.99) -> None: super().__init__(env, agent, (n_steps + 1)) self.gamma = gamma self.steps = n_steps diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py index c8fd7232f8..8455d1f315 100644 --- a/pl_bolts/datamodules/fashion_mnist_datamodule.py +++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py @@ -57,7 +57,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: Where to save/load the data diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py index 80b6e6976a..61ff477e82 100644 --- a/pl_bolts/datamodules/imagenet_datamodule.py +++ b/pl_bolts/datamodules/imagenet_datamodule.py @@ -60,7 +60,7 @@ def __init__( drop_last: bool = False, *args, **kwargs, - ): + ) -> None: """ Args: data_dir: path to the imagenet dataset file @@ -103,7 +103,7 @@ def num_classes(self) -> int: """ return 1000 - def _verify_splits(self, data_dir: str, split: str): + def _verify_splits(self, data_dir: str, split: str) -> None: dirs = os.listdir(data_dir) if split not in dirs: @@ -112,7 +112,7 @@ def _verify_splits(self, data_dir: str, split: str): f' make sure the folder contains a subfolder named {split}' ) - def prepare_data(self): + def prepare_data(self) -> None: """ This method already assumes you have imagenet2012 downloaded. It validates the data using the meta.bin. diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py index 6a052bdf56..ee66d5c1dc 100644 --- a/pl_bolts/datamodules/kitti_datamodule.py +++ b/pl_bolts/datamodules/kitti_datamodule.py @@ -33,7 +33,7 @@ def __init__( drop_last: bool = False, *args, **kwargs, - ): + ) -> None: """ Kitti train, validation and test dataloaders. diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py index 5c8388facc..b700b23123 100644 --- a/pl_bolts/datamodules/mnist_datamodule.py +++ b/pl_bolts/datamodules/mnist_datamodule.py @@ -56,7 +56,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: Where to save/load the data diff --git a/pl_bolts/datamodules/sklearn_datamodule.py b/pl_bolts/datamodules/sklearn_datamodule.py index a2ffca5ee8..be517f47bc 100644 --- a/pl_bolts/datamodules/sklearn_datamodule.py +++ b/pl_bolts/datamodules/sklearn_datamodule.py @@ -28,8 +28,9 @@ class SklearnDataset(Dataset): >>> len(dataset) 506 """ - - def __init__(self, X: np.ndarray, y: np.ndarray, X_transform: Any = None, y_transform: Any = None): + def __init__( + self, X: np.ndarray, y: np.ndarray, X_transform: Any = None, y_transform: Any = None + ) -> None: """ Args: X: Numpy ndarray @@ -76,8 +77,9 @@ class TensorDataset(Dataset): >>> len(dataset) 10 """ - - def __init__(self, X: torch.Tensor, y: torch.Tensor, X_transform: Any = None, y_transform: Any = None): + def __init__( + self, X: torch.Tensor, y: torch.Tensor, X_transform: Any = None, y_transform: Any = None + ) -> None: """ Args: X: PyTorch tensor @@ -160,7 +162,7 @@ def __init__( drop_last=False, *args, **kwargs, - ): + ) -> None: super().__init__(*args, **kwargs) self.num_workers = num_workers @@ -208,7 +210,7 @@ def _init_datasets( y_val: np.ndarray, x_test: np.ndarray, y_test: np.ndarray - ): + ) -> None: self.train_dataset = SklearnDataset(X, y) self.val_dataset = SklearnDataset(x_val, y_val) self.test_dataset = SklearnDataset(x_test, y_test) diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py index 354cb4f02b..14656280ac 100644 --- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py +++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py @@ -30,7 +30,7 @@ def __init__( drop_last: bool = False, *args, **kwargs, - ): + ) -> None: super().__init__(*args, **kwargs) if not _TORCHVISION_AVAILABLE: @@ -50,7 +50,7 @@ def __init__( def num_classes(self) -> int: return 1000 - def _verify_splits(self, data_dir: str, split: str): + def _verify_splits(self, data_dir: str, split: str) -> None: dirs = os.listdir(data_dir) if split not in dirs: @@ -59,7 +59,7 @@ def _verify_splits(self, data_dir: str, split: str): f' folder contains a subfolder named {split}' ) - def prepare_data(self): + def prepare_data(self) -> None: # imagenet cannot be downloaded... must provide path to folder with the train/val splits self._verify_splits(self.data_dir, 'train') self._verify_splits(self.data_dir, 'val') diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py index 8f5ac120eb..5bf9a9b084 100644 --- a/pl_bolts/datamodules/stl10_datamodule.py +++ b/pl_bolts/datamodules/stl10_datamodule.py @@ -65,7 +65,7 @@ def __init__( drop_last: bool = False, *args, **kwargs, - ): + ) -> None: """ Args: data_dir: where to save/load the data @@ -102,7 +102,7 @@ def __init__( def num_classes(self) -> int: return 10 - def prepare_data(self): + def prepare_data(self) -> None: """ Downloads the unlabeled, train and test split """ diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py index 15648467e8..5a6f4af4c2 100644 --- a/pl_bolts/datamodules/vision_datamodule.py +++ b/pl_bolts/datamodules/vision_datamodule.py @@ -29,7 +29,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: Where to save/load the data @@ -56,14 +56,14 @@ def __init__( self.pin_memory = pin_memory self.drop_last = drop_last - def prepare_data(self): + def prepare_data(self) -> None: """ Saves files to data_dir """ self.dataset_cls(self.data_dir, train=True, download=True) self.dataset_cls(self.data_dir, train=False, download=True) - def setup(self, stage: Optional[str] = None): + def setup(self, stage: Optional[str] = None) -> None: """ Creates train, val, and test dataset """ diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py index 1bd6706650..be91eac99b 100644 --- a/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/pl_bolts/datamodules/vocdetection_datamodule.py @@ -19,7 +19,7 @@ class Compose(object): Like `torchvision.transforms.compose` but works for (image, target) """ - def __init__(self, transforms): + def __init__(self, transforms) -> None: self.transforms = transforms def __call__(self, image, target): @@ -117,7 +117,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: if not _TORCHVISION_AVAILABLE: raise ModuleNotFoundError( # pragma: no-cover 'You want to use VOC dataset loaded from `torchvision` which is not installed yet.' @@ -141,7 +141,7 @@ def num_classes(self) -> int: """ return 21 - def prepare_data(self): + def prepare_data(self) -> None: """ Saves VOCDetection files to data_dir """ From 7309adefb9b92aa1122d187a36e951e94681cf5d Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Tue, 5 Jan 2021 20:30:43 +0900 Subject: [PATCH 44/75] Add type for torchvision transforms --- pl_bolts/datamodules/binary_mnist_datamodule.py | 4 +++- pl_bolts/datamodules/cifar10_datamodule.py | 4 +++- pl_bolts/datamodules/cityscapes_datamodule.py | 6 ++++-- pl_bolts/datamodules/fashion_mnist_datamodule.py | 4 +++- pl_bolts/datamodules/imagenet_datamodule.py | 6 ++++-- pl_bolts/datamodules/kitti_datamodule.py | 4 +++- pl_bolts/datamodules/mnist_datamodule.py | 4 +++- pl_bolts/datamodules/ssl_imagenet_datamodule.py | 9 +++++++-- pl_bolts/datamodules/stl10_datamodule.py | 9 +++++++-- pl_bolts/datamodules/vision_datamodule.py | 9 ++++++++- 10 files changed, 45 insertions(+), 14 deletions(-) diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py index 4dea946bf9..cdd07ff1e2 100644 --- a/pl_bolts/datamodules/binary_mnist_datamodule.py +++ b/pl_bolts/datamodules/binary_mnist_datamodule.py @@ -7,8 +7,10 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib + from torchvision.transforms import Compose else: # pragma: no-cover warn_missing_pkg('torchvision') + Compose = object class BinaryMNISTDataModule(VisionDataModule): @@ -98,7 +100,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self): + def default_transforms(self) -> Compose: if self.normalize: mnist_transforms = transform_lib.Compose([ transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5, ), std=(0.5, )) diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index 2cb894d749..9dbf10b670 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -9,9 +9,11 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib from torchvision.datasets import CIFAR10 + from torchvision.transforms import Compose else: warn_missing_pkg('torchvision') # pragma: no-cover CIFAR10 = None + Compose = object class CIFAR10DataModule(VisionDataModule): @@ -112,7 +114,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self): + def default_transforms(self) -> Compose: if self.normalize: cf10_transforms = transform_lib.Compose([transform_lib.ToTensor(), cifar10_normalization()]) else: diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py index 7816236a27..3721f897af 100644 --- a/pl_bolts/datamodules/cityscapes_datamodule.py +++ b/pl_bolts/datamodules/cityscapes_datamodule.py @@ -9,8 +9,10 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib from torchvision.datasets import Cityscapes + from torchvision.transforms import Compose else: warn_missing_pkg('torchvision') # pragma: no-cover + Compose = object class CityscapesDataModule(LightningDataModule): @@ -198,7 +200,7 @@ def test_dataloader(self) -> DataLoader: ) return loader - def _default_transforms(self): + def _default_transforms(self) -> Compose: cityscapes_transforms = transform_lib.Compose([ transform_lib.ToTensor(), transform_lib.Normalize( @@ -207,7 +209,7 @@ def _default_transforms(self): ]) return cityscapes_transforms - def _default_target_transforms(self): + def _default_target_transforms(self) -> Compose: cityscapes_target_trasnforms = transform_lib.Compose([ transform_lib.ToTensor(), transform_lib.Lambda(lambda t: t.squeeze()) ]) diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py index 8455d1f315..b31a5aa792 100644 --- a/pl_bolts/datamodules/fashion_mnist_datamodule.py +++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py @@ -7,9 +7,11 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib from torchvision.datasets import FashionMNIST + from torchvision.transforms import Compose else: warn_missing_pkg('torchvision') # pragma: no-cover FashionMNIST = None + Compose = object class FashionMNISTDataModule(VisionDataModule): @@ -93,7 +95,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self): + def default_transforms(self) -> Compose: if self.normalize: mnist_transforms = transform_lib.Compose([ transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5, ), std=(0.5, )) diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py index 61ff477e82..b63611b060 100644 --- a/pl_bolts/datamodules/imagenet_datamodule.py +++ b/pl_bolts/datamodules/imagenet_datamodule.py @@ -11,8 +11,10 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib + from torchvision.transforms import Compose else: warn_missing_pkg('torchvision') # pragma: no-cover + Compose = object class ImagenetDataModule(LightningDataModule): @@ -212,7 +214,7 @@ def test_dataloader(self) -> DataLoader: ) return loader - def train_transform(self): + def train_transform(self) -> Compose: """ The standard imagenet transforms @@ -238,7 +240,7 @@ def train_transform(self): return preprocessing - def val_transform(self): + def val_transform(self) -> Compose: """ The standard imagenet transforms for validation diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py index ee66d5c1dc..cab7529e91 100644 --- a/pl_bolts/datamodules/kitti_datamodule.py +++ b/pl_bolts/datamodules/kitti_datamodule.py @@ -12,8 +12,10 @@ if _TORCHVISION_AVAILABLE: import torchvision.transforms as transforms + from torchvision.transforms import Compose else: # pragma: no cover warn_missing_pkg('torchvision') + Compose = object class KittiDataModule(LightningDataModule): @@ -131,7 +133,7 @@ def test_dataloader(self) -> DataLoader: ) return loader - def _default_transforms(self): + def _default_transforms(self) -> Compose: kitti_transforms = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753], std=[0.32064945, 0.32098866, 0.32325324]) diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py index b700b23123..d52315f41c 100644 --- a/pl_bolts/datamodules/mnist_datamodule.py +++ b/pl_bolts/datamodules/mnist_datamodule.py @@ -7,9 +7,11 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib from torchvision.datasets import MNIST + from torchvision.transforms import Compose else: warn_missing_pkg('torchvision') # pragma: no-cover MNIST = None + Compose = object class MNISTDataModule(VisionDataModule): @@ -92,7 +94,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self): + def default_transforms(self) -> Compose: if self.normalize: mnist_transforms = transform_lib.Compose([ transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5, ), std=(0.5, )) diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py index 14656280ac..6ee430428b 100644 --- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py +++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py @@ -11,8 +11,10 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib + from torchvision.transforms import Compose else: warn_missing_pkg('torchvision') # pragma: no-cover + Compose = object class SSLImagenetDataModule(LightningDataModule): # pragma: no cover @@ -144,6 +146,9 @@ def test_dataloader(self, num_images_per_class: int, add_normalize: bool = False ) return loader - def _default_transforms(self): - mnist_transforms = transform_lib.Compose([transform_lib.ToTensor(), imagenet_normalization()]) + def _default_transforms(self) -> Compose: + mnist_transforms = transform_lib.Compose([ + transform_lib.ToTensor(), + imagenet_normalization() + ]) return mnist_transforms diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py index 5bf9a9b084..45c0c040ea 100644 --- a/pl_bolts/datamodules/stl10_datamodule.py +++ b/pl_bolts/datamodules/stl10_datamodule.py @@ -13,8 +13,10 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib from torchvision.datasets import STL10 + from torchvision.transforms import Compose else: warn_missing_pkg('torchvision') # pragma: no-cover + Compose = object class STL10DataModule(LightningDataModule): # pragma: no cover @@ -298,6 +300,9 @@ def val_dataloader_labeled(self) -> DataLoader: ) return loader - def _default_transforms(self): - data_transforms = transform_lib.Compose([transform_lib.ToTensor(), stl10_normalization()]) + def _default_transforms(self) -> Compose: + data_transforms = transform_lib.Compose([ + transform_lib.ToTensor(), + stl10_normalization() + ]) return data_transforms diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py index 5a6f4af4c2..bab4c722dc 100644 --- a/pl_bolts/datamodules/vision_datamodule.py +++ b/pl_bolts/datamodules/vision_datamodule.py @@ -6,6 +6,13 @@ from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader, Dataset, random_split +from pl_bolts.utils import _TORCHVISION_AVAILABLE + +if _TORCHVISION_AVAILABLE: + from torchvision.transforms import Compose +else: + Compose = object + class VisionDataModule(LightningDataModule): @@ -113,7 +120,7 @@ def _get_splits(self, len_dataset: int) -> List[int]: return splits @abstractmethod - def default_transforms(self): + def default_transforms(self) -> Compose: """ Default transform for the dataset """ def train_dataloader(self) -> DataLoader: From cc154a7242215fd2feb0907727f856d8d59a2067 Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Fri, 18 Dec 2020 19:58:16 +0900 Subject: [PATCH 45/75] Adding types to datamodules --- pl_bolts/datamodules/binary_mnist_datamodule.py | 2 +- pl_bolts/datamodules/cifar10_datamodule.py | 6 +++++- pl_bolts/datamodules/fashion_mnist_datamodule.py | 2 +- pl_bolts/datamodules/mnist_datamodule.py | 2 +- pl_bolts/datamodules/sklearn_datamodule.py | 2 +- pl_bolts/datamodules/vision_datamodule.py | 6 +++--- 6 files changed, 12 insertions(+), 8 deletions(-) diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py index cdd07ff1e2..88ff0b359d 100644 --- a/pl_bolts/datamodules/binary_mnist_datamodule.py +++ b/pl_bolts/datamodules/binary_mnist_datamodule.py @@ -58,7 +58,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ) -> None: + ): """ Args: data_dir: Where to save/load the data diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index 9dbf10b670..c5e880bf1e 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -114,7 +114,11 @@ def num_classes(self) -> int: """ return 10 +<<<<<<< HEAD def default_transforms(self) -> Compose: +======= + def default_transforms(self) -> transform_lib.Compose: +>>>>>>> Adding types to datamodules if self.normalize: cf10_transforms = transform_lib.Compose([transform_lib.ToTensor(), cifar10_normalization()]) else: @@ -155,7 +159,7 @@ def __init__( labels: Optional[Sequence] = (1, 5, 8), *args: Any, **kwargs: Any, - ) -> None: + ): """ Args: data_dir: where to save/load the data diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py index b31a5aa792..8de6c99bd7 100644 --- a/pl_bolts/datamodules/fashion_mnist_datamodule.py +++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py @@ -59,7 +59,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ) -> None: + ): """ Args: data_dir: Where to save/load the data diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py index d52315f41c..5013b8b6b5 100644 --- a/pl_bolts/datamodules/mnist_datamodule.py +++ b/pl_bolts/datamodules/mnist_datamodule.py @@ -58,7 +58,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ) -> None: + ): """ Args: data_dir: Where to save/load the data diff --git a/pl_bolts/datamodules/sklearn_datamodule.py b/pl_bolts/datamodules/sklearn_datamodule.py index be517f47bc..ae8ad70d52 100644 --- a/pl_bolts/datamodules/sklearn_datamodule.py +++ b/pl_bolts/datamodules/sklearn_datamodule.py @@ -209,7 +209,7 @@ def _init_datasets( x_val: np.ndarray, y_val: np.ndarray, x_test: np.ndarray, - y_test: np.ndarray +<<<<<<< HEAD ) -> None: self.train_dataset = SklearnDataset(X, y) self.val_dataset = SklearnDataset(x_val, y_val) diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py index bab4c722dc..ba95c26a20 100644 --- a/pl_bolts/datamodules/vision_datamodule.py +++ b/pl_bolts/datamodules/vision_datamodule.py @@ -36,7 +36,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ) -> None: + ): """ Args: data_dir: Where to save/load the data @@ -63,14 +63,14 @@ def __init__( self.pin_memory = pin_memory self.drop_last = drop_last - def prepare_data(self) -> None: + def prepare_data(self): """ Saves files to data_dir """ self.dataset_cls(self.data_dir, train=True, download=True) self.dataset_cls(self.data_dir, train=False, download=True) - def setup(self, stage: Optional[str] = None) -> None: + def setup(self, stage: Optional[str] = None): """ Creates train, val, and test dataset """ From cff73307ca58a4df401a982475ae5534daf7a682 Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Fri, 18 Dec 2020 20:20:20 +0900 Subject: [PATCH 46/75] Fixing typing imports --- pl_bolts/datamodules/vocdetection_datamodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py index be91eac99b..9c6e4b501b 100644 --- a/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/pl_bolts/datamodules/vocdetection_datamodule.py @@ -199,7 +199,7 @@ def val_dataloader(self, batch_size: int = 1, transforms=None) -> DataLoader: ) return loader - def _default_transforms(self): + def _default_transforms(self) -> T.Compose: if self.normalize: return ( lambda image, target: ( From 3bbc1894237bcd616c00882d0f19a9c607fe93d0 Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Sun, 20 Dec 2020 03:03:00 +0900 Subject: [PATCH 47/75] Removing torchvision.transforms from return typing --- pl_bolts/datamodules/cifar10_datamodule.py | 4 ---- pl_bolts/datamodules/vocdetection_datamodule.py | 2 +- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index c5e880bf1e..cbe0333050 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -114,11 +114,7 @@ def num_classes(self) -> int: """ return 10 -<<<<<<< HEAD def default_transforms(self) -> Compose: -======= - def default_transforms(self) -> transform_lib.Compose: ->>>>>>> Adding types to datamodules if self.normalize: cf10_transforms = transform_lib.Compose([transform_lib.ToTensor(), cifar10_normalization()]) else: diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py index 9c6e4b501b..be91eac99b 100644 --- a/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/pl_bolts/datamodules/vocdetection_datamodule.py @@ -199,7 +199,7 @@ def val_dataloader(self, batch_size: int = 1, transforms=None) -> DataLoader: ) return loader - def _default_transforms(self) -> T.Compose: + def _default_transforms(self): if self.normalize: return ( lambda image, target: ( From 5b2401b6f8422d34593d9b518ad9da69706e4395 Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Mon, 21 Dec 2020 10:13:11 +0900 Subject: [PATCH 48/75] Removing return typing --- pl_bolts/datamodules/cifar10_datamodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index cbe0333050..cdca1b61a6 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -73,7 +73,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ) -> None: + ): """ Args: data_dir: Where to save/load the data From 7eb32b8e84c547e85feb33d808e98256de9fe450 Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Mon, 21 Dec 2020 16:51:36 +0900 Subject: [PATCH 49/75] Add `None` return type --- pl_bolts/datamodules/binary_mnist_datamodule.py | 2 +- pl_bolts/datamodules/cifar10_datamodule.py | 4 ++-- pl_bolts/datamodules/fashion_mnist_datamodule.py | 2 +- pl_bolts/datamodules/mnist_datamodule.py | 2 +- pl_bolts/datamodules/sklearn_datamodule.py | 2 +- pl_bolts/datamodules/vision_datamodule.py | 6 +++--- 6 files changed, 9 insertions(+), 9 deletions(-) diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py index 88ff0b359d..cdd07ff1e2 100644 --- a/pl_bolts/datamodules/binary_mnist_datamodule.py +++ b/pl_bolts/datamodules/binary_mnist_datamodule.py @@ -58,7 +58,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: Where to save/load the data diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index cdca1b61a6..9dbf10b670 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -73,7 +73,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: Where to save/load the data @@ -155,7 +155,7 @@ def __init__( labels: Optional[Sequence] = (1, 5, 8), *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: where to save/load the data diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py index 8de6c99bd7..b31a5aa792 100644 --- a/pl_bolts/datamodules/fashion_mnist_datamodule.py +++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py @@ -59,7 +59,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: Where to save/load the data diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py index 5013b8b6b5..d52315f41c 100644 --- a/pl_bolts/datamodules/mnist_datamodule.py +++ b/pl_bolts/datamodules/mnist_datamodule.py @@ -58,7 +58,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: Where to save/load the data diff --git a/pl_bolts/datamodules/sklearn_datamodule.py b/pl_bolts/datamodules/sklearn_datamodule.py index ae8ad70d52..be517f47bc 100644 --- a/pl_bolts/datamodules/sklearn_datamodule.py +++ b/pl_bolts/datamodules/sklearn_datamodule.py @@ -209,7 +209,7 @@ def _init_datasets( x_val: np.ndarray, y_val: np.ndarray, x_test: np.ndarray, -<<<<<<< HEAD + y_test: np.ndarray ) -> None: self.train_dataset = SklearnDataset(X, y) self.val_dataset = SklearnDataset(x_val, y_val) diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py index ba95c26a20..bab4c722dc 100644 --- a/pl_bolts/datamodules/vision_datamodule.py +++ b/pl_bolts/datamodules/vision_datamodule.py @@ -36,7 +36,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: Where to save/load the data @@ -63,14 +63,14 @@ def __init__( self.pin_memory = pin_memory self.drop_last = drop_last - def prepare_data(self): + def prepare_data(self) -> None: """ Saves files to data_dir """ self.dataset_cls(self.data_dir, train=True, download=True) self.dataset_cls(self.data_dir, train=False, download=True) - def setup(self, stage: Optional[str] = None): + def setup(self, stage: Optional[str] = None) -> None: """ Creates train, val, and test dataset """ From 47cca3231091322c155c49369f21a02fba5ecbbb Mon Sep 17 00:00:00 2001 From: Jirka Borovec <jirka.borovec@seznam.cz> Date: Tue, 5 Jan 2021 14:04:40 +0100 Subject: [PATCH 50/75] enable check --- setup.cfg | 3 --- 1 file changed, 3 deletions(-) diff --git a/setup.cfg b/setup.cfg index 080004f375..12525b46ae 100644 --- a/setup.cfg +++ b/setup.cfg @@ -72,9 +72,6 @@ show_error_codes = True disallow_untyped_defs = True ignore_missing_imports = True -[mypy-pl_bolts.datamodules.*] -ignore_errors = True - [mypy-pl_bolts.datasets.*] ignore_errors = True From 52e48113aa4fb16d44e2a524869f3c187b51b63f Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Fri, 18 Dec 2020 19:58:16 +0900 Subject: [PATCH 51/75] Adding types to datamodules --- pl_bolts/datamodules/binary_mnist_datamodule.py | 2 +- pl_bolts/datamodules/fashion_mnist_datamodule.py | 2 +- pl_bolts/datamodules/mnist_datamodule.py | 2 +- pl_bolts/datamodules/vision_datamodule.py | 6 +++--- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py index cdd07ff1e2..88ff0b359d 100644 --- a/pl_bolts/datamodules/binary_mnist_datamodule.py +++ b/pl_bolts/datamodules/binary_mnist_datamodule.py @@ -58,7 +58,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ) -> None: + ): """ Args: data_dir: Where to save/load the data diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py index b31a5aa792..8de6c99bd7 100644 --- a/pl_bolts/datamodules/fashion_mnist_datamodule.py +++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py @@ -59,7 +59,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ) -> None: + ): """ Args: data_dir: Where to save/load the data diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py index d52315f41c..5013b8b6b5 100644 --- a/pl_bolts/datamodules/mnist_datamodule.py +++ b/pl_bolts/datamodules/mnist_datamodule.py @@ -58,7 +58,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ) -> None: + ): """ Args: data_dir: Where to save/load the data diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py index bab4c722dc..ba95c26a20 100644 --- a/pl_bolts/datamodules/vision_datamodule.py +++ b/pl_bolts/datamodules/vision_datamodule.py @@ -36,7 +36,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ) -> None: + ): """ Args: data_dir: Where to save/load the data @@ -63,14 +63,14 @@ def __init__( self.pin_memory = pin_memory self.drop_last = drop_last - def prepare_data(self) -> None: + def prepare_data(self): """ Saves files to data_dir """ self.dataset_cls(self.data_dir, train=True, download=True) self.dataset_cls(self.data_dir, train=False, download=True) - def setup(self, stage: Optional[str] = None) -> None: + def setup(self, stage: Optional[str] = None): """ Creates train, val, and test dataset """ From 64f871edf3e27f7c2b048061ed9895c0b038135d Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Fri, 18 Dec 2020 20:20:20 +0900 Subject: [PATCH 52/75] Fixing typing imports --- pl_bolts/datamodules/vocdetection_datamodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py index be91eac99b..9c6e4b501b 100644 --- a/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/pl_bolts/datamodules/vocdetection_datamodule.py @@ -199,7 +199,7 @@ def val_dataloader(self, batch_size: int = 1, transforms=None) -> DataLoader: ) return loader - def _default_transforms(self): + def _default_transforms(self) -> T.Compose: if self.normalize: return ( lambda image, target: ( From bf6ee1cbadd30fa48b117c9d737483445636818b Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Sun, 20 Dec 2020 03:03:00 +0900 Subject: [PATCH 53/75] Removing torchvision.transforms from return typing --- pl_bolts/datamodules/vocdetection_datamodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py index 9c6e4b501b..be91eac99b 100644 --- a/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/pl_bolts/datamodules/vocdetection_datamodule.py @@ -199,7 +199,7 @@ def val_dataloader(self, batch_size: int = 1, transforms=None) -> DataLoader: ) return loader - def _default_transforms(self) -> T.Compose: + def _default_transforms(self): if self.normalize: return ( lambda image, target: ( From b5c6dccf886d7b869f08c8a8c24e42b895ab9298 Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Mon, 21 Dec 2020 10:13:11 +0900 Subject: [PATCH 54/75] Removing return typing --- pl_bolts/datamodules/cifar10_datamodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index 9dbf10b670..6c15801aae 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -73,7 +73,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ) -> None: + ): """ Args: data_dir: Where to save/load the data From 984a9629db081607e65e628c7d9f35d63befb46a Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Mon, 21 Dec 2020 16:51:36 +0900 Subject: [PATCH 55/75] Add `None` return type --- pl_bolts/datamodules/binary_mnist_datamodule.py | 2 +- pl_bolts/datamodules/cifar10_datamodule.py | 2 +- pl_bolts/datamodules/fashion_mnist_datamodule.py | 2 +- pl_bolts/datamodules/mnist_datamodule.py | 2 +- pl_bolts/datamodules/vision_datamodule.py | 6 +++--- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py index 88ff0b359d..cdd07ff1e2 100644 --- a/pl_bolts/datamodules/binary_mnist_datamodule.py +++ b/pl_bolts/datamodules/binary_mnist_datamodule.py @@ -58,7 +58,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: Where to save/load the data diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index 6c15801aae..9dbf10b670 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -73,7 +73,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: Where to save/load the data diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py index 8de6c99bd7..b31a5aa792 100644 --- a/pl_bolts/datamodules/fashion_mnist_datamodule.py +++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py @@ -59,7 +59,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: Where to save/load the data diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py index 5013b8b6b5..d52315f41c 100644 --- a/pl_bolts/datamodules/mnist_datamodule.py +++ b/pl_bolts/datamodules/mnist_datamodule.py @@ -58,7 +58,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: Where to save/load the data diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py index ba95c26a20..bab4c722dc 100644 --- a/pl_bolts/datamodules/vision_datamodule.py +++ b/pl_bolts/datamodules/vision_datamodule.py @@ -36,7 +36,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: Where to save/load the data @@ -63,14 +63,14 @@ def __init__( self.pin_memory = pin_memory self.drop_last = drop_last - def prepare_data(self): + def prepare_data(self) -> None: """ Saves files to data_dir """ self.dataset_cls(self.data_dir, train=True, download=True) self.dataset_cls(self.data_dir, train=False, download=True) - def setup(self, stage: Optional[str] = None): + def setup(self, stage: Optional[str] = None) -> None: """ Creates train, val, and test dataset """ From 7894471f3038f08df28578c6e92f94d4e6e308f6 Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Fri, 18 Dec 2020 19:58:16 +0900 Subject: [PATCH 56/75] Adding types to datamodules --- pl_bolts/datamodules/binary_mnist_datamodule.py | 2 +- pl_bolts/datamodules/cifar10_datamodule.py | 2 +- pl_bolts/datamodules/fashion_mnist_datamodule.py | 2 +- pl_bolts/datamodules/mnist_datamodule.py | 2 +- pl_bolts/datamodules/vision_datamodule.py | 6 +++--- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py index cdd07ff1e2..88ff0b359d 100644 --- a/pl_bolts/datamodules/binary_mnist_datamodule.py +++ b/pl_bolts/datamodules/binary_mnist_datamodule.py @@ -58,7 +58,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ) -> None: + ): """ Args: data_dir: Where to save/load the data diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index 9dbf10b670..cbe0333050 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -155,7 +155,7 @@ def __init__( labels: Optional[Sequence] = (1, 5, 8), *args: Any, **kwargs: Any, - ) -> None: + ): """ Args: data_dir: where to save/load the data diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py index b31a5aa792..8de6c99bd7 100644 --- a/pl_bolts/datamodules/fashion_mnist_datamodule.py +++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py @@ -59,7 +59,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ) -> None: + ): """ Args: data_dir: Where to save/load the data diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py index d52315f41c..5013b8b6b5 100644 --- a/pl_bolts/datamodules/mnist_datamodule.py +++ b/pl_bolts/datamodules/mnist_datamodule.py @@ -58,7 +58,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ) -> None: + ): """ Args: data_dir: Where to save/load the data diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py index bab4c722dc..ba95c26a20 100644 --- a/pl_bolts/datamodules/vision_datamodule.py +++ b/pl_bolts/datamodules/vision_datamodule.py @@ -36,7 +36,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ) -> None: + ): """ Args: data_dir: Where to save/load the data @@ -63,14 +63,14 @@ def __init__( self.pin_memory = pin_memory self.drop_last = drop_last - def prepare_data(self) -> None: + def prepare_data(self): """ Saves files to data_dir """ self.dataset_cls(self.data_dir, train=True, download=True) self.dataset_cls(self.data_dir, train=False, download=True) - def setup(self, stage: Optional[str] = None) -> None: + def setup(self, stage: Optional[str] = None): """ Creates train, val, and test dataset """ From 3443883e97cd37047ae258b851dd3eec7764416c Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Fri, 18 Dec 2020 20:20:20 +0900 Subject: [PATCH 57/75] Fixing typing imports --- pl_bolts/datamodules/vocdetection_datamodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py index be91eac99b..9c6e4b501b 100644 --- a/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/pl_bolts/datamodules/vocdetection_datamodule.py @@ -199,7 +199,7 @@ def val_dataloader(self, batch_size: int = 1, transforms=None) -> DataLoader: ) return loader - def _default_transforms(self): + def _default_transforms(self) -> T.Compose: if self.normalize: return ( lambda image, target: ( From 51f8f167ee745ec81af993408683642fa865e22b Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Sun, 20 Dec 2020 03:03:00 +0900 Subject: [PATCH 58/75] Removing torchvision.transforms from return typing --- pl_bolts/datamodules/vocdetection_datamodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py index 9c6e4b501b..be91eac99b 100644 --- a/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/pl_bolts/datamodules/vocdetection_datamodule.py @@ -199,7 +199,7 @@ def val_dataloader(self, batch_size: int = 1, transforms=None) -> DataLoader: ) return loader - def _default_transforms(self) -> T.Compose: + def _default_transforms(self): if self.normalize: return ( lambda image, target: ( From 3062dba63676eba615c569d834177bca6df1e3ff Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Mon, 21 Dec 2020 10:13:11 +0900 Subject: [PATCH 59/75] Removing return typing --- pl_bolts/datamodules/cifar10_datamodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index cbe0333050..cdca1b61a6 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -73,7 +73,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ) -> None: + ): """ Args: data_dir: Where to save/load the data From 53ebe33ff4f30936335b96421afe8606a7a630a9 Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Mon, 21 Dec 2020 16:51:36 +0900 Subject: [PATCH 60/75] Add `None` return type --- pl_bolts/datamodules/binary_mnist_datamodule.py | 2 +- pl_bolts/datamodules/cifar10_datamodule.py | 4 ++-- pl_bolts/datamodules/fashion_mnist_datamodule.py | 2 +- pl_bolts/datamodules/mnist_datamodule.py | 2 +- pl_bolts/datamodules/vision_datamodule.py | 6 +++--- 5 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py index 88ff0b359d..cdd07ff1e2 100644 --- a/pl_bolts/datamodules/binary_mnist_datamodule.py +++ b/pl_bolts/datamodules/binary_mnist_datamodule.py @@ -58,7 +58,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: Where to save/load the data diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index cdca1b61a6..9dbf10b670 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -73,7 +73,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: Where to save/load the data @@ -155,7 +155,7 @@ def __init__( labels: Optional[Sequence] = (1, 5, 8), *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: where to save/load the data diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py index 8de6c99bd7..b31a5aa792 100644 --- a/pl_bolts/datamodules/fashion_mnist_datamodule.py +++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py @@ -59,7 +59,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: Where to save/load the data diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py index 5013b8b6b5..d52315f41c 100644 --- a/pl_bolts/datamodules/mnist_datamodule.py +++ b/pl_bolts/datamodules/mnist_datamodule.py @@ -58,7 +58,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: Where to save/load the data diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py index ba95c26a20..bab4c722dc 100644 --- a/pl_bolts/datamodules/vision_datamodule.py +++ b/pl_bolts/datamodules/vision_datamodule.py @@ -36,7 +36,7 @@ def __init__( drop_last: bool = False, *args: Any, **kwargs: Any, - ): + ) -> None: """ Args: data_dir: Where to save/load the data @@ -63,14 +63,14 @@ def __init__( self.pin_memory = pin_memory self.drop_last = drop_last - def prepare_data(self): + def prepare_data(self) -> None: """ Saves files to data_dir """ self.dataset_cls(self.data_dir, train=True, download=True) self.dataset_cls(self.data_dir, train=False, download=True) - def setup(self, stage: Optional[str] = None): + def setup(self, stage: Optional[str] = None) -> None: """ Creates train, val, and test dataset """ From c15efdb5959b510e965ec742f6469ae18c016a70 Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Tue, 12 Jan 2021 19:52:18 +0900 Subject: [PATCH 61/75] Fix rebasing mistakes --- pl_bolts/datamodules/cityscapes_datamodule.py | 4 ++-- pl_bolts/datamodules/imagenet_datamodule.py | 4 ++-- pl_bolts/datamodules/kitti_datamodule.py | 4 ++-- pl_bolts/datamodules/ssl_imagenet_datamodule.py | 8 ++++---- pl_bolts/datamodules/stl10_datamodule.py | 4 ++-- 5 files changed, 12 insertions(+), 12 deletions(-) diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py index 3721f897af..41465840dc 100644 --- a/pl_bolts/datamodules/cityscapes_datamodule.py +++ b/pl_bolts/datamodules/cityscapes_datamodule.py @@ -73,8 +73,8 @@ def __init__( shuffle: bool = False, pin_memory: bool = False, drop_last: bool = False, - *args, - **kwargs, + *args: Any, + **kwargs: Any, ) -> None: """ Args: diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py index b63611b060..301894c074 100644 --- a/pl_bolts/datamodules/imagenet_datamodule.py +++ b/pl_bolts/datamodules/imagenet_datamodule.py @@ -60,8 +60,8 @@ def __init__( shuffle: bool = False, pin_memory: bool = False, drop_last: bool = False, - *args, - **kwargs, + *args: Any, + **kwargs: Any, ) -> None: """ Args: diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py index cab7529e91..461b2244a8 100644 --- a/pl_bolts/datamodules/kitti_datamodule.py +++ b/pl_bolts/datamodules/kitti_datamodule.py @@ -33,8 +33,8 @@ def __init__( shuffle: bool = False, pin_memory: bool = False, drop_last: bool = False, - *args, - **kwargs, + *args: Any, + **kwargs: Any, ) -> None: """ Kitti train, validation and test dataloaders. diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py index 116a1f1614..4ab0bcbd9b 100644 --- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py +++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py @@ -23,15 +23,15 @@ class SSLImagenetDataModule(LightningDataModule): # pragma: no cover def __init__( self, - data_dir, - meta_dir=None, + data_dir: str, + meta_dir: Optional[str] = None, num_workers=16, batch_size: int = 32, shuffle: bool = False, pin_memory: bool = False, drop_last: bool = False, - *args, - **kwargs, + *args: Any, + **kwargs: Any, ) -> None: super().__init__(*args, **kwargs) diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py index 133f3ed6f8..a32c89e5be 100644 --- a/pl_bolts/datamodules/stl10_datamodule.py +++ b/pl_bolts/datamodules/stl10_datamodule.py @@ -65,8 +65,8 @@ def __init__( shuffle: bool = False, pin_memory: bool = False, drop_last: bool = False, - *args, - **kwargs, + *args: Any, + **kwargs: Any, ) -> None: """ Args: From 7bc0c370b302b3f18e772c6a2f0eb39679ef881d Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Tue, 12 Jan 2021 19:54:42 +0900 Subject: [PATCH 62/75] Fix flake8 --- pl_bolts/datamodules/kitti_datamodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py index 461b2244a8..df07852085 100644 --- a/pl_bolts/datamodules/kitti_datamodule.py +++ b/pl_bolts/datamodules/kitti_datamodule.py @@ -24,7 +24,7 @@ class KittiDataModule(LightningDataModule): def __init__( self, - data_dir: str, + data_dir: Optional[str] = None, val_split: float = 0.2, test_split: float = 0.1, num_workers: int = 16, From a5f3e4f98d4d56d2c41f3e93da42e66b1e12e7ce Mon Sep 17 00:00:00 2001 From: Brian Ko <briankosw@gmail.com> Date: Tue, 12 Jan 2021 19:56:55 +0900 Subject: [PATCH 63/75] Fix yapf format --- pl_bolts/datamodules/async_dataloader.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pl_bolts/datamodules/async_dataloader.py b/pl_bolts/datamodules/async_dataloader.py index 224f34d5ee..72d823c9ae 100644 --- a/pl_bolts/datamodules/async_dataloader.py +++ b/pl_bolts/datamodules/async_dataloader.py @@ -28,12 +28,12 @@ class AsynchronousLoader(object): """ def __init__( - self, - data: Union[DataLoader, Dataset], - device: torch.device = torch.device('cuda', 0), - q_size: int = 10, - num_batches: Optional[int] = None, - **kwargs: Any + self, + data: Union[DataLoader, Dataset], + device: torch.device = torch.device('cuda', 0), + q_size: int = 10, + num_batches: Optional[int] = None, + **kwargs: Any ) -> None: if isinstance(data, torch.utils.data.DataLoader): self.dataloader = data From b9c910dde176d02c48e5f0e0f5a7a4e61b1d8cf3 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta <nitta@akihironitta.com> Date: Wed, 20 Jan 2021 20:11:53 +0900 Subject: [PATCH 64/75] Add types and skip mypy checks on some files --- pl_bolts/callbacks/byol_updates.py | 4 ++-- pl_bolts/callbacks/variational.py | 5 +++-- pl_bolts/datamodules/async_dataloader.py | 13 +++++++------ pl_bolts/datamodules/binary_mnist_datamodule.py | 2 +- pl_bolts/datamodules/cifar10_datamodule.py | 4 ++-- pl_bolts/datamodules/cityscapes_datamodule.py | 3 ++- pl_bolts/datamodules/experience_source.py | 2 +- pl_bolts/datamodules/fashion_mnist_datamodule.py | 2 +- pl_bolts/datamodules/imagenet_datamodule.py | 7 ++++--- pl_bolts/datamodules/kitti_datamodule.py | 1 + pl_bolts/datamodules/mnist_datamodule.py | 2 +- pl_bolts/datamodules/ssl_imagenet_datamodule.py | 9 +++++---- pl_bolts/datamodules/stl10_datamodule.py | 3 ++- pl_bolts/datamodules/vision_datamodule.py | 16 ++++++++-------- pl_bolts/datamodules/vocdetection_datamodule.py | 16 ++++++++-------- setup.cfg | 13 +++++++++++++ 16 files changed, 61 insertions(+), 41 deletions(-) diff --git a/pl_bolts/callbacks/byol_updates.py b/pl_bolts/callbacks/byol_updates.py index 8f47815521..2b4a1953f8 100644 --- a/pl_bolts/callbacks/byol_updates.py +++ b/pl_bolts/callbacks/byol_updates.py @@ -66,7 +66,7 @@ def update_tau(self, pl_module: LightningModule, trainer: Trainer) -> float: def update_weights(self, online_net: Union[Module, Tensor], target_net: Union[Module, Tensor]) -> None: # apply MA weight update for (name, online_p), (_, target_p) in zip( - online_net.named_parameters(), target_net.named_parameters() - ): # type: ignore[union-attr] + online_net.named_parameters(), target_net.named_parameters() # type: ignore[union-attr] + ): if 'weight' in name: target_p.data = self.current_tau * target_p.data + (1 - self.current_tau) * online_p.data diff --git a/pl_bolts/callbacks/variational.py b/pl_bolts/callbacks/variational.py index 5947f40be8..4d5f4c6e23 100644 --- a/pl_bolts/callbacks/variational.py +++ b/pl_bolts/callbacks/variational.py @@ -62,8 +62,9 @@ def __init__( def on_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: if (trainer.current_epoch + 1) % self.interpolate_epoch_interval == 0: images = self.interpolate_latent_space( - pl_module, latent_dim=pl_module.hparams.latent_dim - ) # type: ignore[union-attr] + pl_module, + latent_dim=pl_module.hparams.latent_dim # type: ignore[union-attr] + ) images = torch.cat(images, dim=0) # type: ignore[assignment] num_images = (self.range_end - self.range_start)**2 diff --git a/pl_bolts/datamodules/async_dataloader.py b/pl_bolts/datamodules/async_dataloader.py index 72d823c9ae..137429dc51 100644 --- a/pl_bolts/datamodules/async_dataloader.py +++ b/pl_bolts/datamodules/async_dataloader.py @@ -3,6 +3,7 @@ from threading import Thread from typing import Any, Optional, Union +import numpy as np import torch from torch._six import container_abcs, string_classes from torch.utils.data import DataLoader, Dataset @@ -33,7 +34,7 @@ def __init__( device: torch.device = torch.device('cuda', 0), q_size: int = 10, num_batches: Optional[int] = None, - **kwargs: Any + **kwargs: Any, ) -> None: if isinstance(data, torch.utils.data.DataLoader): self.dataloader = data @@ -51,7 +52,7 @@ def __init__( self.q_size = q_size self.load_stream = torch.cuda.Stream(device=device) - self.queue = Queue(maxsize=self.q_size) + self.queue: Queue = Queue(maxsize=self.q_size) self.idx = 0 @@ -64,7 +65,7 @@ def load_loop(self) -> None: # The loop that will load into the queue in the ba break # Recursive loading for each instance based on torch.utils.data.default_collate - def load_instance(self, sample): + def load_instance(self, sample: Any) -> Any: elem_type = type(sample) if torch.is_tensor(sample): @@ -88,16 +89,16 @@ def load_instance(self, sample): else: return sample - def __iter__(self): + def __iter__(self) -> "AsynchronousLoader": # We don't want to run the thread more than once # Start a new thread if we are at the beginning of a new epoch, and our current worker is dead - if (not hasattr(self, 'worker') or not self.worker.is_alive()) and self.queue.empty() and self.idx == 0: + if (not hasattr(self, 'worker') or not self.worker.is_alive()) and self.queue.empty() and self.idx == 0: # type: ignore[has-type] self.worker = Thread(target=self.load_loop) self.worker.daemon = True self.worker.start() return self - def __next__(self): + def __next__(self) -> torch.Tensor: # If we've reached the number of batches to return # or the queue is empty and the worker is dead then exit done = not self.worker.is_alive() and self.queue.empty() diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py index a72c02b53a..7b1b963f49 100644 --- a/pl_bolts/datamodules/binary_mnist_datamodule.py +++ b/pl_bolts/datamodules/binary_mnist_datamodule.py @@ -78,7 +78,7 @@ def __init__( "You want to use transforms loaded from `torchvision` which is not installed yet." ) - super().__init__( + super().__init__( # type: ignore[misc] data_dir=data_dir, val_split=val_split, num_workers=num_workers, diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index d7e3ed2d7a..1c17658a30 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -87,7 +87,7 @@ def __init__( returning them drop_last: If true drops the last incomplete batch """ - super().__init__( + super().__init__( # type: ignore[misc] data_dir=data_dir, val_split=val_split, num_workers=num_workers, @@ -166,7 +166,7 @@ def __init__( """ super().__init__(data_dir, val_split, num_workers, *args, **kwargs) - self.num_samples = num_samples + self.num_samples = num_samples # type: ignore[misc] self.labels = sorted(labels) if labels is not None else set(range(10)) self.extra_args = dict(num_samples=self.num_samples, labels=self.labels) diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py index eaf7cd5fa3..462d4ca982 100644 --- a/pl_bolts/datamodules/cityscapes_datamodule.py +++ b/pl_bolts/datamodules/cityscapes_datamodule.py @@ -1,3 +1,4 @@ +# type: ignore[override] from typing import Any from pytorch_lightning import LightningDataModule @@ -60,7 +61,7 @@ class CityscapesDataModule(LightningDataModule): """ name = 'Cityscapes' - extra_args = {} + extra_args: dict = {} def __init__( self, diff --git a/pl_bolts/datamodules/experience_source.py b/pl_bolts/datamodules/experience_source.py index 1bc7b0f8a8..50ed2a6a7b 100644 --- a/pl_bolts/datamodules/experience_source.py +++ b/pl_bolts/datamodules/experience_source.py @@ -299,5 +299,5 @@ def discount_rewards(self, experiences: Tuple[Experience]) -> float: """ total_reward = 0.0 for exp in reversed(experiences): - total_reward = (self.gamma * total_reward) + exp.reward + total_reward = (self.gamma * total_reward) + exp.reward # type: ignore[attr-defined] return total_reward diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py index eabecdbf60..b209e09f37 100644 --- a/pl_bolts/datamodules/fashion_mnist_datamodule.py +++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py @@ -78,7 +78,7 @@ def __init__( 'You want to use FashionMNIST dataset loaded from `torchvision` which is not installed yet.' ) - super().__init__( + super().__init__( # type: ignore[misc] data_dir=data_dir, val_split=val_split, num_workers=num_workers, diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py index 41ac3d42db..ab8e8b9921 100644 --- a/pl_bolts/datamodules/imagenet_datamodule.py +++ b/pl_bolts/datamodules/imagenet_datamodule.py @@ -1,3 +1,4 @@ +# type: ignore[override] import os from typing import Any, Optional @@ -158,7 +159,7 @@ def train_dataloader(self) -> DataLoader: split='train', transform=transforms ) - loader = DataLoader( + loader: DataLoader = DataLoader( dataset, batch_size=self.batch_size, shuffle=self.shuffle, @@ -185,7 +186,7 @@ def val_dataloader(self) -> DataLoader: split='val', transform=transforms ) - loader = DataLoader( + loader: DataLoader = DataLoader( dataset, batch_size=self.batch_size, shuffle=False, @@ -204,7 +205,7 @@ def test_dataloader(self) -> DataLoader: dataset = UnlabeledImagenet( self.data_dir, num_imgs_per_class=-1, meta_dir=self.meta_dir, split='test', transform=transforms ) - loader = DataLoader( + loader: DataLoader = DataLoader( dataset, batch_size=self.batch_size, shuffle=False, diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py index 34d64c3c00..893f224047 100644 --- a/pl_bolts/datamodules/kitti_datamodule.py +++ b/pl_bolts/datamodules/kitti_datamodule.py @@ -1,3 +1,4 @@ +# type: ignore[override] import os from typing import Any, Optional diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py index 4f8e9e19a0..c813cee685 100644 --- a/pl_bolts/datamodules/mnist_datamodule.py +++ b/pl_bolts/datamodules/mnist_datamodule.py @@ -77,7 +77,7 @@ def __init__( 'You want to use MNIST dataset loaded from `torchvision` which is not installed yet.' ) - super().__init__( + super().__init__( # type: ignore[misc] data_dir=data_dir, val_split=val_split, num_workers=num_workers, diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py index c8b05b904a..c21d1af46b 100644 --- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py +++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py @@ -1,3 +1,4 @@ +# type: ignore[override] import os from typing import Any, Optional @@ -25,7 +26,7 @@ def __init__( self, data_dir: str, meta_dir: Optional[str] = None, - num_workers=16, + num_workers: int = 16, batch_size: int = 32, shuffle: bool = False, pin_memory: bool = False, @@ -96,7 +97,7 @@ def train_dataloader(self, num_images_per_class: int = -1, add_normalize: bool = split='train', transform=transforms ) - loader = DataLoader( + loader: DataLoader = DataLoader( dataset, batch_size=self.batch_size, shuffle=self.shuffle, @@ -116,7 +117,7 @@ def val_dataloader(self, num_images_per_class: int = 50, add_normalize: bool = F split='val', transform=transforms ) - loader = DataLoader( + loader: DataLoader = DataLoader( dataset, batch_size=self.batch_size, shuffle=False, @@ -136,7 +137,7 @@ def test_dataloader(self, num_images_per_class: int, add_normalize: bool = False split='test', transform=transforms ) - loader = DataLoader( + loader: DataLoader = DataLoader( dataset, batch_size=self.batch_size, shuffle=False, diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py index 5aafa1380b..0433278af6 100644 --- a/pl_bolts/datamodules/stl10_datamodule.py +++ b/pl_bolts/datamodules/stl10_datamodule.py @@ -1,3 +1,4 @@ +# type: ignore[override] import os from typing import Any, Optional @@ -194,7 +195,7 @@ def val_dataloader(self) -> DataLoader: batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, - drpo_last=self.drop_last, + drop_last=self.drop_last, pin_memory=self.pin_memory ) return loader diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py index bab4c722dc..73ba424c96 100644 --- a/pl_bolts/datamodules/vision_datamodule.py +++ b/pl_bolts/datamodules/vision_datamodule.py @@ -16,12 +16,12 @@ class VisionDataModule(LightningDataModule): - EXTRA_ARGS = {} + EXTRA_ARGS: dict = {} name: str = "" #: Dataset class to use - dataset_cls = ... + dataset_cls: type #: A tuple describing the shape of the data - dims: tuple = ... + dims: tuple def __init__( self, @@ -63,7 +63,7 @@ def __init__( self.pin_memory = pin_memory self.drop_last = drop_last - def prepare_data(self) -> None: + def prepare_data(self, *args: Any, **kwargs: Any) -> None: """ Saves files to data_dir """ @@ -95,7 +95,7 @@ def _split_dataset(self, dataset: Dataset, train: bool = True) -> Dataset: """ Splits the dataset into train and validation set """ - len_dataset = len(dataset) + len_dataset = len(dataset) # type: ignore[arg-type] splits = self._get_splits(len_dataset) dataset_train, dataset_val = random_split(dataset, splits, generator=torch.Generator().manual_seed(self.seed)) @@ -123,15 +123,15 @@ def _get_splits(self, len_dataset: int) -> List[int]: def default_transforms(self) -> Compose: """ Default transform for the dataset """ - def train_dataloader(self) -> DataLoader: + def train_dataloader(self, *args: Any, **kwargs: Any) -> DataLoader: """ The train dataloader """ return self._data_loader(self.dataset_train, shuffle=self.shuffle) - def val_dataloader(self) -> DataLoader: + def val_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]: """ The val dataloader """ return self._data_loader(self.dataset_val) - def test_dataloader(self) -> DataLoader: + def test_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]: """ The test dataloader """ return self._data_loader(self.dataset_test) diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py index c3ba78452c..c704540312 100644 --- a/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/pl_bolts/datamodules/vocdetection_datamodule.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, List, Callable, Tuple, Optional import torch from pytorch_lightning import LightningDataModule @@ -19,11 +19,11 @@ class Compose(object): Like `torchvision.transforms.compose` but works for (image, target) """ - def __init__(self, transforms, image_transforms=None): + def __init__(self, transforms: List[Callable], image_transforms: Optional[Callable] = None) -> None: self.transforms = transforms self.image_transforms = image_transforms - def __call__(self, image, target): + def __call__(self, image: Any, target: Any) -> Tuple[torch.Tensor, torch.Tensor]: for t in self.transforms: image, target = t(image, target) if self.image_transforms: @@ -31,7 +31,7 @@ def __call__(self, image, target): return image, target -def _collate_fn(batch): +def _collate_fn(batch: List[torch.Tensor]) -> tuple: return tuple(zip(*batch)) @@ -60,7 +60,7 @@ def _collate_fn(batch): ) -def _prepare_voc_instance(image, target: Dict[str, Any]): +def _prepare_voc_instance(image: Any, target: Dict[str, Any]): """ Prepares VOC dataset into appropriate target for fasterrcnn @@ -151,7 +151,7 @@ def prepare_data(self) -> None: VOCDetection(self.data_dir, year=self.year, image_set="train", download=True) VOCDetection(self.data_dir, year=self.year, image_set="val", download=True) - def train_dataloader(self, batch_size: int = 1, image_transforms=None) -> DataLoader: + def train_dataloader(self, batch_size: int = 1, image_transforms: Union[List[Callable], Callable]=None) -> DataLoader: """ VOCDetection train set uses the `train` subset @@ -174,7 +174,7 @@ def train_dataloader(self, batch_size: int = 1, image_transforms=None) -> DataLo ) return loader - def val_dataloader(self, batch_size: int = 1, image_transforms=None) -> DataLoader: + def val_dataloader(self, batch_size: int = 1, image_transforms: Optional[List[Callable]] = None) -> DataLoader: """ VOCDetection val set uses the `val` subset @@ -197,7 +197,7 @@ def val_dataloader(self, batch_size: int = 1, image_transforms=None) -> DataLoad ) return loader - def _default_transforms(self): + def _default_transforms(self) -> transform_lib.Compose: if self.normalize: voc_transforms = transform_lib.Compose([ transform_lib.ToTensor(), diff --git a/setup.cfg b/setup.cfg index 12525b46ae..957564a046 100644 --- a/setup.cfg +++ b/setup.cfg @@ -75,6 +75,19 @@ ignore_missing_imports = True [mypy-pl_bolts.datasets.*] ignore_errors = True +[mypy-pl_bolts.datamodules] + # pl_bolts/datamodules/__init__.py + ignore_errors = True + +[mypy-pl_bolts.datamodules.experience_source] +ignore_errors = True + +[mypy-pl_bolts.datamodules.sklearn_datamodule] +ignore_errors = True + +[mypy-pl_bolts.datamodules.vocdetection_datamodule] +ignore_errors = True + [mypy-pl_bolts.losses.*] ignore_errors = True From 9e222d02f5a2f6721ca5b932d717061b3a23b1d0 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta <nitta@akihironitta.com> Date: Wed, 20 Jan 2021 20:22:57 +0900 Subject: [PATCH 65/75] Fix setup.cfg --- setup.cfg | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index 957564a046..5883253ce5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -76,8 +76,8 @@ ignore_missing_imports = True ignore_errors = True [mypy-pl_bolts.datamodules] - # pl_bolts/datamodules/__init__.py - ignore_errors = True +# pl_bolts/datamodules/__init__.py +ignore_errors = True [mypy-pl_bolts.datamodules.experience_source] ignore_errors = True From 0c54fdd87906a518c7d836bb2478496e57682995 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta <nitta@akihironitta.com> Date: Wed, 20 Jan 2021 20:27:24 +0900 Subject: [PATCH 66/75] Add missing import --- pl_bolts/datamodules/vocdetection_datamodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py index c704540312..eec63f1891 100644 --- a/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/pl_bolts/datamodules/vocdetection_datamodule.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Callable, Tuple, Optional +from typing import Any, Dict, List, Callable, Tuple, Optional, Union import torch from pytorch_lightning import LightningDataModule From 8b2e1964af51c771af83002e58539d970646d3e1 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta <nitta@akihironitta.com> Date: Wed, 20 Jan 2021 20:27:35 +0900 Subject: [PATCH 67/75] isort --- pl_bolts/datamodules/vocdetection_datamodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py index eec63f1891..3890ab640d 100644 --- a/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/pl_bolts/datamodules/vocdetection_datamodule.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Callable, Tuple, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch from pytorch_lightning import LightningDataModule From 9c5dd5cafc49d0732143babb0f8f198d07c75a00 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta <nitta@akihironitta.com> Date: Wed, 20 Jan 2021 20:34:32 +0900 Subject: [PATCH 68/75] yapf --- pl_bolts/callbacks/byol_updates.py | 3 ++- pl_bolts/datamodules/async_dataloader.py | 3 ++- pl_bolts/datamodules/vocdetection_datamodule.py | 4 +++- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/pl_bolts/callbacks/byol_updates.py b/pl_bolts/callbacks/byol_updates.py index 2b4a1953f8..2918d2aa75 100644 --- a/pl_bolts/callbacks/byol_updates.py +++ b/pl_bolts/callbacks/byol_updates.py @@ -66,7 +66,8 @@ def update_tau(self, pl_module: LightningModule, trainer: Trainer) -> float: def update_weights(self, online_net: Union[Module, Tensor], target_net: Union[Module, Tensor]) -> None: # apply MA weight update for (name, online_p), (_, target_p) in zip( - online_net.named_parameters(), target_net.named_parameters() # type: ignore[union-attr] + online_net.named_parameters(), + target_net.named_parameters() # type: ignore[union-attr] ): if 'weight' in name: target_p.data = self.current_tau * target_p.data + (1 - self.current_tau) * online_p.data diff --git a/pl_bolts/datamodules/async_dataloader.py b/pl_bolts/datamodules/async_dataloader.py index 137429dc51..b8a35ffa9b 100644 --- a/pl_bolts/datamodules/async_dataloader.py +++ b/pl_bolts/datamodules/async_dataloader.py @@ -92,7 +92,8 @@ def load_instance(self, sample: Any) -> Any: def __iter__(self) -> "AsynchronousLoader": # We don't want to run the thread more than once # Start a new thread if we are at the beginning of a new epoch, and our current worker is dead - if (not hasattr(self, 'worker') or not self.worker.is_alive()) and self.queue.empty() and self.idx == 0: # type: ignore[has-type] + if (not hasattr(self, 'worker') + or not self.worker.is_alive()) and self.queue.empty() and self.idx == 0: # type: ignore[has-type] self.worker = Thread(target=self.load_loop) self.worker.daemon = True self.worker.start() diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py index 3890ab640d..f70806a7a7 100644 --- a/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/pl_bolts/datamodules/vocdetection_datamodule.py @@ -151,7 +151,9 @@ def prepare_data(self) -> None: VOCDetection(self.data_dir, year=self.year, image_set="train", download=True) VOCDetection(self.data_dir, year=self.year, image_set="val", download=True) - def train_dataloader(self, batch_size: int = 1, image_transforms: Union[List[Callable], Callable]=None) -> DataLoader: + def train_dataloader( + self, batch_size: int = 1, image_transforms: Union[List[Callable], Callable] = None + ) -> DataLoader: """ VOCDetection train set uses the `train` subset From c6c97e1e524208828ff4bd6302072657ab728c83 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta <nitta@akihironitta.com> Date: Wed, 20 Jan 2021 20:41:33 +0900 Subject: [PATCH 69/75] mypy please... --- pl_bolts/callbacks/byol_updates.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/callbacks/byol_updates.py b/pl_bolts/callbacks/byol_updates.py index 2918d2aa75..1c4d3ba7d4 100644 --- a/pl_bolts/callbacks/byol_updates.py +++ b/pl_bolts/callbacks/byol_updates.py @@ -66,7 +66,7 @@ def update_tau(self, pl_module: LightningModule, trainer: Trainer) -> float: def update_weights(self, online_net: Union[Module, Tensor], target_net: Union[Module, Tensor]) -> None: # apply MA weight update for (name, online_p), (_, target_p) in zip( - online_net.named_parameters(), + online_net.named_parameters(), # type: ignore[union-attr] target_net.named_parameters() # type: ignore[union-attr] ): if 'weight' in name: From 4ac8a5ba8d00dd8cc60a87c30f7d4910fb205f48 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta <nitta@akihironitta.com> Date: Wed, 20 Jan 2021 20:52:19 +0900 Subject: [PATCH 70/75] Please be quiet mypy and flake8 --- pl_bolts/datamodules/async_dataloader.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pl_bolts/datamodules/async_dataloader.py b/pl_bolts/datamodules/async_dataloader.py index b8a35ffa9b..6034c29d3e 100644 --- a/pl_bolts/datamodules/async_dataloader.py +++ b/pl_bolts/datamodules/async_dataloader.py @@ -3,7 +3,6 @@ from threading import Thread from typing import Any, Optional, Union -import numpy as np import torch from torch._six import container_abcs, string_classes from torch.utils.data import DataLoader, Dataset @@ -92,8 +91,7 @@ def load_instance(self, sample: Any) -> Any: def __iter__(self) -> "AsynchronousLoader": # We don't want to run the thread more than once # Start a new thread if we are at the beginning of a new epoch, and our current worker is dead - if (not hasattr(self, 'worker') - or not self.worker.is_alive()) and self.queue.empty() and self.idx == 0: # type: ignore[has-type] + if (not hasattr(self, 'worker') or not self.worker.is_alive()) and self.queue.empty() and self.idx == 0: # type: ignore[has-type] # noqa: E501 self.worker = Thread(target=self.load_loop) self.worker.daemon = True self.worker.start() From e847eadf2c47ea9560019d4782f14e1a0e087c00 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta <nitta@akihironitta.com> Date: Wed, 20 Jan 2021 20:53:53 +0900 Subject: [PATCH 71/75] yapf... --- pl_bolts/datamodules/async_dataloader.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pl_bolts/datamodules/async_dataloader.py b/pl_bolts/datamodules/async_dataloader.py index 6034c29d3e..410425e3cb 100644 --- a/pl_bolts/datamodules/async_dataloader.py +++ b/pl_bolts/datamodules/async_dataloader.py @@ -91,7 +91,8 @@ def load_instance(self, sample: Any) -> Any: def __iter__(self) -> "AsynchronousLoader": # We don't want to run the thread more than once # Start a new thread if we are at the beginning of a new epoch, and our current worker is dead - if (not hasattr(self, 'worker') or not self.worker.is_alive()) and self.queue.empty() and self.idx == 0: # type: ignore[has-type] # noqa: E501 + if (not hasattr(self, 'worker') or + not self.worker.is_alive()) and self.queue.empty() and self.idx == 0: # type: ignore[has-type] # noqa: E501 self.worker = Thread(target=self.load_loop) self.worker.daemon = True self.worker.start() From 1839438d6f194e8ec3cab5a13d9e27f3b69a8e4b Mon Sep 17 00:00:00 2001 From: Akihiro Nitta <nitta@akihironitta.com> Date: Wed, 20 Jan 2021 21:05:26 +0900 Subject: [PATCH 72/75] Disable all of yapf, flake8, and mypy --- pl_bolts/datamodules/async_dataloader.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pl_bolts/datamodules/async_dataloader.py b/pl_bolts/datamodules/async_dataloader.py index 410425e3cb..24fa820d67 100644 --- a/pl_bolts/datamodules/async_dataloader.py +++ b/pl_bolts/datamodules/async_dataloader.py @@ -91,9 +91,11 @@ def load_instance(self, sample: Any) -> Any: def __iter__(self) -> "AsynchronousLoader": # We don't want to run the thread more than once # Start a new thread if we are at the beginning of a new epoch, and our current worker is dead - if (not hasattr(self, 'worker') or - not self.worker.is_alive()) and self.queue.empty() and self.idx == 0: # type: ignore[has-type] # noqa: E501 + + # yapf: disable + if (not hasattr(self, 'worker') or not self.worker.is_alive()) and self.queue.empty() and self.idx == 0: # type: ignore[has-type] # noqa: E501 self.worker = Thread(target=self.load_loop) + # yapf: enable self.worker.daemon = True self.worker.start() return self From 097df6d3e9a8b7e4a286913066dcdc203330ef71 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta <nitta@akihironitta.com> Date: Wed, 20 Jan 2021 21:12:04 +0900 Subject: [PATCH 73/75] Use Callable --- pl_bolts/datamodules/vocdetection_datamodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py index f70806a7a7..97b63cc86e 100644 --- a/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/pl_bolts/datamodules/vocdetection_datamodule.py @@ -199,7 +199,7 @@ def val_dataloader(self, batch_size: int = 1, image_transforms: Optional[List[Ca ) return loader - def _default_transforms(self) -> transform_lib.Compose: + def _default_transforms(self) -> Callable: if self.normalize: voc_transforms = transform_lib.Compose([ transform_lib.ToTensor(), From 9e00c5dfd366d821ed600795c4f57e3fe9eda996 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta <nitta@akihironitta.com> Date: Wed, 20 Jan 2021 21:41:13 +0900 Subject: [PATCH 74/75] Use Callable --- pl_bolts/datamodules/binary_mnist_datamodule.py | 6 ++---- pl_bolts/datamodules/cifar10_datamodule.py | 6 ++---- pl_bolts/datamodules/cityscapes_datamodule.py | 8 +++----- pl_bolts/datamodules/fashion_mnist_datamodule.py | 6 ++---- pl_bolts/datamodules/imagenet_datamodule.py | 6 ++---- pl_bolts/datamodules/kitti_datamodule.py | 6 ++---- pl_bolts/datamodules/mnist_datamodule.py | 6 ++---- pl_bolts/datamodules/ssl_imagenet_datamodule.py | 6 ++---- pl_bolts/datamodules/stl10_datamodule.py | 4 +--- pl_bolts/datamodules/vision_datamodule.py | 11 ++--------- 10 files changed, 20 insertions(+), 45 deletions(-) diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py index 7b1b963f49..c43065984b 100644 --- a/pl_bolts/datamodules/binary_mnist_datamodule.py +++ b/pl_bolts/datamodules/binary_mnist_datamodule.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Union +from typing import Any, Callable, Optional, Union from pl_bolts.datamodules.vision_datamodule import VisionDataModule from pl_bolts.datasets.mnist_dataset import BinaryMNIST @@ -7,10 +7,8 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib - from torchvision.transforms import Compose else: # pragma: no cover warn_missing_pkg('torchvision') - Compose = object class BinaryMNISTDataModule(VisionDataModule): @@ -100,7 +98,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self) -> Compose: + def default_transforms(self) -> Callable: if self.normalize: mnist_transforms = transform_lib.Compose([ transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5, ), std=(0.5, )) diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index 1c17658a30..e54eb37deb 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Sequence, Union +from typing import Any, Callable, Optional, Sequence, Union from pl_bolts.datamodules.vision_datamodule import VisionDataModule from pl_bolts.datasets.cifar10_dataset import TrialCIFAR10 @@ -9,11 +9,9 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib from torchvision.datasets import CIFAR10 - from torchvision.transforms import Compose else: # pragma: no cover warn_missing_pkg('torchvision') CIFAR10 = None - Compose = object class CIFAR10DataModule(VisionDataModule): @@ -114,7 +112,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self) -> Compose: + def default_transforms(self) -> Callable: if self.normalize: cf10_transforms = transform_lib.Compose([transform_lib.ToTensor(), cifar10_normalization()]) else: diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py index 462d4ca982..3f1c223baf 100644 --- a/pl_bolts/datamodules/cityscapes_datamodule.py +++ b/pl_bolts/datamodules/cityscapes_datamodule.py @@ -1,5 +1,5 @@ # type: ignore[override] -from typing import Any +from typing import Any, Callable from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader @@ -10,10 +10,8 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib from torchvision.datasets import Cityscapes - from torchvision.transforms import Compose else: # pragma: no cover warn_missing_pkg('torchvision') - Compose = object class CityscapesDataModule(LightningDataModule): @@ -201,7 +199,7 @@ def test_dataloader(self) -> DataLoader: ) return loader - def _default_transforms(self) -> Compose: + def _default_transforms(self) -> Callable: cityscapes_transforms = transform_lib.Compose([ transform_lib.ToTensor(), transform_lib.Normalize( @@ -210,7 +208,7 @@ def _default_transforms(self) -> Compose: ]) return cityscapes_transforms - def _default_target_transforms(self) -> Compose: + def _default_target_transforms(self) -> Callable: cityscapes_target_trasnforms = transform_lib.Compose([ transform_lib.ToTensor(), transform_lib.Lambda(lambda t: t.squeeze()) ]) diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py index b209e09f37..f945e00912 100644 --- a/pl_bolts/datamodules/fashion_mnist_datamodule.py +++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Union +from typing import Any, Callable, Optional, Union from pl_bolts.datamodules.vision_datamodule import VisionDataModule from pl_bolts.utils import _TORCHVISION_AVAILABLE @@ -7,11 +7,9 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib from torchvision.datasets import FashionMNIST - from torchvision.transforms import Compose else: # pragma: no cover warn_missing_pkg('torchvision') FashionMNIST = None - Compose = object class FashionMNISTDataModule(VisionDataModule): @@ -100,7 +98,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self) -> Compose: + def default_transforms(self) -> Callable: if self.normalize: mnist_transforms = transform_lib.Compose([ transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5, ), std=(0.5, )) diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py index ab8e8b9921..066aa3cd0a 100644 --- a/pl_bolts/datamodules/imagenet_datamodule.py +++ b/pl_bolts/datamodules/imagenet_datamodule.py @@ -12,10 +12,8 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib - from torchvision.transforms import Compose else: # pragma: no cover warn_missing_pkg('torchvision') - Compose = object class ImagenetDataModule(LightningDataModule): @@ -215,7 +213,7 @@ def test_dataloader(self) -> DataLoader: ) return loader - def train_transform(self) -> Compose: + def train_transform(self) -> Callable: """ The standard imagenet transforms @@ -241,7 +239,7 @@ def train_transform(self) -> Compose: return preprocessing - def val_transform(self) -> Compose: + def val_transform(self) -> Callable: """ The standard imagenet transforms for validation diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py index 893f224047..cd6d198185 100644 --- a/pl_bolts/datamodules/kitti_datamodule.py +++ b/pl_bolts/datamodules/kitti_datamodule.py @@ -1,6 +1,6 @@ # type: ignore[override] import os -from typing import Any, Optional +from typing import Any, Callable, Optional import torch from pytorch_lightning import LightningDataModule @@ -13,10 +13,8 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transforms - from torchvision.transforms import Compose else: # pragma: no cover warn_missing_pkg('torchvision') - Compose = object class KittiDataModule(LightningDataModule): @@ -129,7 +127,7 @@ def test_dataloader(self) -> DataLoader: ) return loader - def _default_transforms(self) -> Compose: + def _default_transforms(self) -> Callable: kitti_transforms = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753], std=[0.32064945, 0.32098866, 0.32325324]) diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py index c813cee685..0889d71d09 100644 --- a/pl_bolts/datamodules/mnist_datamodule.py +++ b/pl_bolts/datamodules/mnist_datamodule.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Union +from typing import Any, Callable, Optional, Union from pl_bolts.datamodules.vision_datamodule import VisionDataModule from pl_bolts.utils import _TORCHVISION_AVAILABLE @@ -7,11 +7,9 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib from torchvision.datasets import MNIST - from torchvision.transforms import Compose else: # pragma: no cover warn_missing_pkg('torchvision') MNIST = None - Compose = object class MNISTDataModule(VisionDataModule): @@ -99,7 +97,7 @@ def num_classes(self) -> int: """ return 10 - def default_transforms(self) -> Compose: + def default_transforms(self) -> Callable: if self.normalize: mnist_transforms = transform_lib.Compose([ transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5, ), std=(0.5, )) diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py index c21d1af46b..fc14dd2cae 100644 --- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py +++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py @@ -1,6 +1,6 @@ # type: ignore[override] import os -from typing import Any, Optional +from typing import Any, Callable, Optional from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader @@ -12,10 +12,8 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib - from torchvision.transforms import Compose else: # pragma: no cover warn_missing_pkg('torchvision') - Compose = object class SSLImagenetDataModule(LightningDataModule): # pragma: no cover @@ -147,6 +145,6 @@ def test_dataloader(self, num_images_per_class: int, add_normalize: bool = False ) return loader - def _default_transforms(self) -> Compose: + def _default_transforms(self) -> Callable: mnist_transforms = transform_lib.Compose([transform_lib.ToTensor(), imagenet_normalization()]) return mnist_transforms diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py index 0433278af6..f7b9f9963f 100644 --- a/pl_bolts/datamodules/stl10_datamodule.py +++ b/pl_bolts/datamodules/stl10_datamodule.py @@ -14,10 +14,8 @@ if _TORCHVISION_AVAILABLE: from torchvision import transforms as transform_lib from torchvision.datasets import STL10 - from torchvision.transforms import Compose else: # pragma: no cover warn_missing_pkg('torchvision') - Compose = object class STL10DataModule(LightningDataModule): # pragma: no cover @@ -301,6 +299,6 @@ def val_dataloader_labeled(self) -> DataLoader: ) return loader - def _default_transforms(self) -> Compose: + def _default_transforms(self) -> Callable: data_transforms = transform_lib.Compose([transform_lib.ToTensor(), stl10_normalization()]) return data_transforms diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py index 73ba424c96..d15d9c59d4 100644 --- a/pl_bolts/datamodules/vision_datamodule.py +++ b/pl_bolts/datamodules/vision_datamodule.py @@ -1,18 +1,11 @@ import os from abc import abstractmethod -from typing import Any, List, Optional, Union +from typing import Any, Callable, List, Optional, Union import torch from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader, Dataset, random_split -from pl_bolts.utils import _TORCHVISION_AVAILABLE - -if _TORCHVISION_AVAILABLE: - from torchvision.transforms import Compose -else: - Compose = object - class VisionDataModule(LightningDataModule): @@ -120,7 +113,7 @@ def _get_splits(self, len_dataset: int) -> List[int]: return splits @abstractmethod - def default_transforms(self) -> Compose: + def default_transforms(self) -> Callable: """ Default transform for the dataset """ def train_dataloader(self, *args: Any, **kwargs: Any) -> DataLoader: From af563a57d64c76113c2b4daff02759422f62d90c Mon Sep 17 00:00:00 2001 From: Akihiro Nitta <nitta@akihironitta.com> Date: Wed, 20 Jan 2021 21:44:38 +0900 Subject: [PATCH 75/75] Add missing import --- pl_bolts/datamodules/imagenet_datamodule.py | 2 +- pl_bolts/datamodules/stl10_datamodule.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py index 066aa3cd0a..b9cd811335 100644 --- a/pl_bolts/datamodules/imagenet_datamodule.py +++ b/pl_bolts/datamodules/imagenet_datamodule.py @@ -1,6 +1,6 @@ # type: ignore[override] import os -from typing import Any, Optional +from typing import Any, Callable, Optional from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py index f7b9f9963f..43ff3ebb6a 100644 --- a/pl_bolts/datamodules/stl10_datamodule.py +++ b/pl_bolts/datamodules/stl10_datamodule.py @@ -1,6 +1,6 @@ # type: ignore[override] import os -from typing import Any, Optional +from typing import Any, Callable, Optional import torch from pytorch_lightning import LightningDataModule