From 34f5f07b5cbed147064a25b0541e4dc38c831060 Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Fri, 18 Dec 2020 19:58:16 +0900
Subject: [PATCH 01/75] Add types to datamodules

---
 .../datamodules/binary_mnist_datamodule.py    |  4 +-
 pl_bolts/datamodules/cifar10_datamodule.py    |  6 +--
 pl_bolts/datamodules/cityscapes_datamodule.py | 16 ++++----
 .../datamodules/fashion_mnist_datamodule.py   |  4 +-
 pl_bolts/datamodules/imagenet_datamodule.py   | 18 ++++-----
 pl_bolts/datamodules/kitti_datamodule.py      | 14 +++----
 pl_bolts/datamodules/mnist_datamodule.py      |  4 +-
 pl_bolts/datamodules/sklearn_datamodule.py    | 38 +++++++++++--------
 .../datamodules/ssl_imagenet_datamodule.py    | 22 +++++------
 pl_bolts/datamodules/stl10_datamodule.py      | 22 +++++------
 pl_bolts/datamodules/vision_datamodule.py     | 16 ++++++--
 11 files changed, 90 insertions(+), 74 deletions(-)

diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py
index 142b3d54ef..c713abe107 100644
--- a/pl_bolts/datamodules/binary_mnist_datamodule.py
+++ b/pl_bolts/datamodules/binary_mnist_datamodule.py
@@ -56,7 +56,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ):
         """
         Args:
             data_dir: Where to save/load the data
@@ -98,7 +98,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self):
+    def default_transforms(self) -> transform_lib.Compose:
         if self.normalize:
             mnist_transforms = transform_lib.Compose(
                 [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py
index afb2df8c9a..12aea1ec87 100644
--- a/pl_bolts/datamodules/cifar10_datamodule.py
+++ b/pl_bolts/datamodules/cifar10_datamodule.py
@@ -112,7 +112,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self):
+    def default_transforms(self) -> transform_lib.Compose:
         if self.normalize:
             cf10_transforms = transform_lib.Compose([transform_lib.ToTensor(), cifar10_normalization()])
         else:
@@ -146,14 +146,14 @@ class TinyCIFAR10DataModule(CIFAR10DataModule):
 
     def __init__(
         self,
-        data_dir: str,
+        data_dir: Optional[str],
         val_split: int = 50,
         num_workers: int = 16,
         num_samples: int = 100,
         labels: Optional[Sequence] = (1, 5, 8),
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ):
         """
         Args:
             data_dir: where to save/load the data
diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py
index b851617225..d27bfc3196 100644
--- a/pl_bolts/datamodules/cityscapes_datamodule.py
+++ b/pl_bolts/datamodules/cityscapes_datamodule.py
@@ -69,8 +69,8 @@ def __init__(
             shuffle: bool = False,
             pin_memory: bool = False,
             drop_last: bool = False,
-            *args,
-            **kwargs,
+            *args: Any,
+            **kwargs: Any,
     ):
         """
         Args:
@@ -109,14 +109,14 @@ def __init__(
         self.target_transforms = None
 
     @property
-    def num_classes(self):
+    def num_classes(self) -> int:
         """
         Return:
             30
         """
         return 30
 
-    def train_dataloader(self):
+    def train_dataloader(self) -> DataLoader:
         """
         Cityscapes train set
         """
@@ -141,7 +141,7 @@ def train_dataloader(self):
         )
         return loader
 
-    def val_dataloader(self):
+    def val_dataloader(self) -> DataLoader:
         """
         Cityscapes val set
         """
@@ -166,7 +166,7 @@ def val_dataloader(self):
         )
         return loader
 
-    def test_dataloader(self):
+    def test_dataloader(self) -> DataLoader:
         """
         Cityscapes test set
         """
@@ -190,7 +190,7 @@ def test_dataloader(self):
         )
         return loader
 
-    def _default_transforms(self):
+    def _default_transforms(self) -> transform_lib.Compose:
         cityscapes_transforms = transform_lib.Compose([
             transform_lib.ToTensor(),
             transform_lib.Normalize(
@@ -200,7 +200,7 @@ def _default_transforms(self):
         ])
         return cityscapes_transforms
 
-    def _default_target_transforms(self):
+    def _default_target_transforms(self) -> transform_lib.Compose:
         cityscapes_target_trasnforms = transform_lib.Compose([
             transform_lib.ToTensor(),
             transform_lib.Lambda(lambda t: t.squeeze())
diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py
index 833c4599a6..9a73e6a637 100644
--- a/pl_bolts/datamodules/fashion_mnist_datamodule.py
+++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py
@@ -57,7 +57,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ):
         """
         Args:
             data_dir: Where to save/load the data
@@ -93,7 +93,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self):
+    def default_transforms(self) -> transform_lib.Compose:
         if self.normalize:
             mnist_transforms = transform_lib.Compose(
                 [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py
index 38546e29ee..a31b637ba9 100644
--- a/pl_bolts/datamodules/imagenet_datamodule.py
+++ b/pl_bolts/datamodules/imagenet_datamodule.py
@@ -58,8 +58,8 @@ def __init__(
             shuffle: bool = False,
             pin_memory: bool = False,
             drop_last: bool = False,
-            *args,
-            **kwargs,
+            *args: Any,
+            **kwargs: Any,
     ):
         """
         Args:
@@ -94,7 +94,7 @@ def __init__(
         self.num_samples = 1281167 - self.num_imgs_per_val_class * self.num_classes
 
     @property
-    def num_classes(self):
+    def num_classes(self) -> int:
         """
         Return:
 
@@ -103,7 +103,7 @@ def num_classes(self):
         """
         return 1000
 
-    def _verify_splits(self, data_dir, split):
+    def _verify_splits(self, data_dir: str, split: str):
         dirs = os.listdir(data_dir)
 
         if split not in dirs:
@@ -138,7 +138,7 @@ def prepare_data(self):
                 UnlabeledImagenet.generate_meta_bins(path)
                 """)
 
-    def train_dataloader(self):
+    def train_dataloader(self) -> DataLoader:
         """
         Uses the train split of imagenet2012 and puts away a portion of it for the validation split
         """
@@ -160,7 +160,7 @@ def train_dataloader(self):
         )
         return loader
 
-    def val_dataloader(self):
+    def val_dataloader(self) -> DataLoader:
         """
         Uses the part of the train split of imagenet2012  that was not used for training via `num_imgs_per_val_class`
 
@@ -185,7 +185,7 @@ def val_dataloader(self):
         )
         return loader
 
-    def test_dataloader(self):
+    def test_dataloader(self) -> DataLoader:
         """
         Uses the validation split of imagenet2012 for testing
         """
@@ -206,7 +206,7 @@ def test_dataloader(self):
         )
         return loader
 
-    def train_transform(self):
+    def train_transform(self) -> transform_lib.Compose:
         """
         The standard imagenet transforms
 
@@ -232,7 +232,7 @@ def train_transform(self):
 
         return preprocessing
 
-    def val_transform(self):
+    def val_transform(self) -> transform_lib.Compose:
         """
         The standard imagenet transforms for validation
 
diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py
index 433e7fffed..a50528028b 100644
--- a/pl_bolts/datamodules/kitti_datamodule.py
+++ b/pl_bolts/datamodules/kitti_datamodule.py
@@ -21,7 +21,7 @@ class KittiDataModule(LightningDataModule):
 
     def __init__(
             self,
-            data_dir: str,
+            data_dir: Optional[str],
             val_split: float = 0.2,
             test_split: float = 0.1,
             num_workers: int = 16,
@@ -30,8 +30,8 @@ def __init__(
             shuffle: bool = False,
             pin_memory: bool = False,
             drop_last: bool = False,
-            *args,
-            **kwargs,
+            *args: Any,
+            **kwargs: Any,
     ):
         """
         Kitti train, validation and test dataloaders.
@@ -100,7 +100,7 @@ def __init__(
                                                                 lengths=[train_len, val_len, test_len],
                                                                 generator=torch.Generator().manual_seed(self.seed))
 
-    def train_dataloader(self):
+    def train_dataloader(self) -> DataLoader:
         loader = DataLoader(
             self.trainset,
             batch_size=self.batch_size,
@@ -111,7 +111,7 @@ def train_dataloader(self):
         )
         return loader
 
-    def val_dataloader(self):
+    def val_dataloader(self) -> DataLoader:
         loader = DataLoader(
             self.valset,
             batch_size=self.batch_size,
@@ -122,7 +122,7 @@ def val_dataloader(self):
         )
         return loader
 
-    def test_dataloader(self):
+    def test_dataloader(self) -> DataLoader:
         loader = DataLoader(
             self.testset,
             batch_size=self.batch_size,
@@ -133,7 +133,7 @@ def test_dataloader(self):
         )
         return loader
 
-    def _default_transforms(self):
+    def _default_transforms(self) -> transform_lib.Compose:
         kitti_transforms = transforms.Compose([
             transforms.ToTensor(),
             transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753],
diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py
index 1dd5e927b6..76a0438a0c 100644
--- a/pl_bolts/datamodules/mnist_datamodule.py
+++ b/pl_bolts/datamodules/mnist_datamodule.py
@@ -56,7 +56,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ):
         """
         Args:
             data_dir: Where to save/load the data
@@ -92,7 +92,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self):
+    def default_transforms(self) -> transform_lib.Compose:
         if self.normalize:
             mnist_transforms = transform_lib.Compose(
                 [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
diff --git a/pl_bolts/datamodules/sklearn_datamodule.py b/pl_bolts/datamodules/sklearn_datamodule.py
index ed262b10c8..dcfd559441 100644
--- a/pl_bolts/datamodules/sklearn_datamodule.py
+++ b/pl_bolts/datamodules/sklearn_datamodule.py
@@ -42,10 +42,10 @@ def __init__(self, X: np.ndarray, y: np.ndarray, X_transform: Any = None, y_tran
         self.X_transform = X_transform
         self.y_transform = y_transform
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.X)
 
-    def __getitem__(self, idx):
+    def __getitem__(self, idx) -> Tuple[np.ndarray, np.ndarray]:
         x = self.X[idx].astype(np.float32)
         y = self.Y[idx]
 
@@ -89,10 +89,10 @@ def __init__(self, X: torch.Tensor, y: torch.Tensor, X_transform: Any = None, y_
         self.X_transform = X_transform
         self.y_transform = y_transform
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.X)
 
-    def __getitem__(self, idx):
+    def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor]:
         x = self.X[idx].float()
         y = self.Y[idx]
 
@@ -145,14 +145,14 @@ def __init__(
             x_val=None, y_val=None,
             x_test=None, y_test=None,
             val_split=0.2, test_split=0.1,
-            num_workers=2,
-            random_state=1234,
-            shuffle=True,
+            num_workers:int = 2,
+            random_state: int = 1234,
+            shuffle: bool = True,
             batch_size: int = 16,
-            pin_memory=False,
-            drop_last=False,
-            *args,
-            **kwargs,
+            pin_memory: bool = False,
+            drop_last: bool = False,
+            *args: Any,
+            **kwargs: Any,
     ):
 
         super().__init__(*args, **kwargs)
@@ -193,12 +193,20 @@ def __init__(
 
         self._init_datasets(X, y, x_val, y_val, x_test, y_test)
 
-    def _init_datasets(self, X, y, x_val, y_val, x_test, y_test):
+    def _init_datasets(
+        self,
+        X: np.ndarray,
+        y: np.ndarray,
+        x_val: np.ndarray,
+        y_val: np.ndarray,
+        x_test: np.ndarray,
+        y_test: np.ndarray
+    ):
         self.train_dataset = SklearnDataset(X, y)
         self.val_dataset = SklearnDataset(x_val, y_val)
         self.test_dataset = SklearnDataset(x_test, y_test)
 
-    def train_dataloader(self):
+    def train_dataloader(self) -> DataLoader:
         loader = DataLoader(
             self.train_dataset,
             batch_size=self.batch_size,
@@ -209,7 +217,7 @@ def train_dataloader(self):
         )
         return loader
 
-    def val_dataloader(self):
+    def val_dataloader(self) -> DataLoader:
         loader = DataLoader(
             self.val_dataset,
             batch_size=self.batch_size,
@@ -220,7 +228,7 @@ def val_dataloader(self):
         )
         return loader
 
-    def test_dataloader(self):
+    def test_dataloader(self) -> DataLoader:
         loader = DataLoader(
             self.test_dataset,
             batch_size=self.batch_size,
diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py
index 03e459fd5e..50315245af 100644
--- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py
+++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py
@@ -20,15 +20,15 @@ class SSLImagenetDataModule(LightningDataModule):  # pragma: no cover
 
     def __init__(
             self,
-            data_dir,
-            meta_dir=None,
-            num_workers=16,
+            data_dir: str,
+            meta_dir: Optional[str] = None,
+            num_workers: int = 16,
             batch_size: int = 32,
             shuffle: bool = False,
             pin_memory: bool = False,
             drop_last: bool = False,
-            *args,
-            **kwargs,
+            *args: Any,
+            **kwargs: Any,
     ):
         super().__init__(*args, **kwargs)
 
@@ -46,10 +46,10 @@ def __init__(
         self.drop_last = drop_last
 
     @property
-    def num_classes(self):
+    def num_classes(self) -> int:
         return 1000
 
-    def _verify_splits(self, data_dir, split):
+    def _verify_splits(self, data_dir: str, split: str):
         dirs = os.listdir(data_dir)
 
         if split not in dirs:
@@ -79,7 +79,7 @@ def prepare_data(self):
                 UnlabeledImagenet.generate_meta_bins(path)
                 """)
 
-    def train_dataloader(self, num_images_per_class=-1, add_normalize=False):
+    def train_dataloader(self, num_images_per_class: int = -1, add_normalize: bool = False) -> DataLoader:
         transforms = self._default_transforms() if self.train_transforms is None else self.train_transforms
 
         dataset = UnlabeledImagenet(self.data_dir,
@@ -97,7 +97,7 @@ def train_dataloader(self, num_images_per_class=-1, add_normalize=False):
         )
         return loader
 
-    def val_dataloader(self, num_images_per_class=50, add_normalize=False):
+    def val_dataloader(self, num_images_per_class: int = 50, add_normalize: bool = False) -> DataLoader:
         transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms
 
         dataset = UnlabeledImagenet(self.data_dir,
@@ -115,7 +115,7 @@ def val_dataloader(self, num_images_per_class=50, add_normalize=False):
         )
         return loader
 
-    def test_dataloader(self, num_images_per_class, add_normalize=False):
+    def test_dataloader(self, num_images_per_class: int, add_normalize: bool = False) -> DataLoader:
         transforms = self._default_transforms() if self.test_transforms is None else self.test_transforms
 
         dataset = UnlabeledImagenet(self.data_dir,
@@ -133,7 +133,7 @@ def test_dataloader(self, num_images_per_class, add_normalize=False):
         )
         return loader
 
-    def _default_transforms(self):
+    def _default_transforms(self) -> transform_lib.Compose:
         mnist_transforms = transform_lib.Compose([
             transform_lib.ToTensor(),
             imagenet_normalization()
diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py
index c666db9b9b..3b29995a1b 100644
--- a/pl_bolts/datamodules/stl10_datamodule.py
+++ b/pl_bolts/datamodules/stl10_datamodule.py
@@ -63,8 +63,8 @@ def __init__(
             shuffle: bool = False,
             pin_memory: bool = False,
             drop_last: bool = False,
-            *args,
-            **kwargs,
+            *args: Any,
+            **kwargs: Any,
     ):
         """
         Args:
@@ -99,7 +99,7 @@ def __init__(
         self.num_unlabeled_samples = 100000 - unlabeled_val_split
 
     @property
-    def num_classes(self):
+    def num_classes(self) -> int:
         return 10
 
     def prepare_data(self):
@@ -110,7 +110,7 @@ def prepare_data(self):
         STL10(self.data_dir, split='train', download=True, transform=transform_lib.ToTensor())
         STL10(self.data_dir, split='test', download=True, transform=transform_lib.ToTensor())
 
-    def train_dataloader(self):
+    def train_dataloader(self) -> DataLoader:
         """
         Loads the 'unlabeled' split minus a portion set aside for validation via `unlabeled_val_split`.
         """
@@ -131,7 +131,7 @@ def train_dataloader(self):
         )
         return loader
 
-    def train_dataloader_mixed(self):
+    def train_dataloader_mixed(self) -> DataLoader:
         """
         Loads a portion of the 'unlabeled' training data and 'train' (labeled) data.
         both portions have a subset removed for validation via `unlabeled_val_split` and `train_val_split`
@@ -169,7 +169,7 @@ def train_dataloader_mixed(self):
         )
         return loader
 
-    def val_dataloader(self):
+    def val_dataloader(self) -> DataLoader:
         """
         Loads a portion of the 'unlabeled' training data set aside for validation
         The val dataset = (unlabeled - train_val_split)
@@ -196,7 +196,7 @@ def val_dataloader(self):
         )
         return loader
 
-    def val_dataloader_mixed(self):
+    def val_dataloader_mixed(self) -> DataLoader:
         """
         Loads a portion of the 'unlabeled' training data set aside for validation along with
         the portion of the 'train' dataset to be used for validation
@@ -239,7 +239,7 @@ def val_dataloader_mixed(self):
         )
         return loader
 
-    def test_dataloader(self):
+    def test_dataloader(self) -> DataLoader:
         """
         Loads the test split of STL10
 
@@ -260,7 +260,7 @@ def test_dataloader(self):
         )
         return loader
 
-    def train_dataloader_labeled(self):
+    def train_dataloader_labeled(self) -> DataLoader:
         transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms
 
         dataset = STL10(self.data_dir, split='train', download=False, transform=transforms)
@@ -278,7 +278,7 @@ def train_dataloader_labeled(self):
         )
         return loader
 
-    def val_dataloader_labeled(self):
+    def val_dataloader_labeled(self) -> DataLoader:
         transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms
         dataset = STL10(self.data_dir,
                         split='train',
@@ -299,7 +299,7 @@ def val_dataloader_labeled(self):
         )
         return loader
 
-    def _default_transforms(self):
+    def _default_transforms(self) -> transform_lib.Compose:
         data_transforms = transform_lib.Compose([
             transform_lib.ToTensor(),
             stl10_normalization()
diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py
index 2144f0f509..5b8c508904 100644
--- a/pl_bolts/datamodules/vision_datamodule.py
+++ b/pl_bolts/datamodules/vision_datamodule.py
@@ -6,6 +6,14 @@
 from pytorch_lightning import LightningDataModule
 from torch.utils.data import DataLoader, Dataset, random_split
 
+from pl_bolts.utils import _TORCHVISION_AVAILABLE
+from pl_bolts.utils.warnings import warn_missing_pkg
+
+if _TORCHVISION_AVAILABLE:
+    from torchvision import transforms as transform_lib
+else:
+    warn_missing_pkg('torchvision')  # pragma: no-cover
+
 
 class VisionDataModule(LightningDataModule):
 
@@ -29,7 +37,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ):
         """
         Args:
             data_dir: Where to save/load the data
@@ -56,14 +64,14 @@ def __init__(
         self.pin_memory = pin_memory
         self.drop_last = drop_last
 
-    def prepare_data(self) -> None:
+    def prepare_data(self):
         """
         Saves files to data_dir
         """
         self.dataset_cls(self.data_dir, train=True, download=True)
         self.dataset_cls(self.data_dir, train=False, download=True)
 
-    def setup(self, stage: Optional[str] = None) -> None:
+    def setup(self, stage: Optional[str] = None):
         """
         Creates train, val, and test dataset
         """
@@ -115,7 +123,7 @@ def _get_splits(self, len_dataset: int) -> List[int]:
         return splits
 
     @abstractmethod
-    def default_transforms(self):
+    def default_transforms(self) -> transform_lib.Compose:
         """ Default transform for the dataset """
 
     def train_dataloader(self) -> DataLoader:

From 2b55b328900ad96ce504ca2ff3f244cbe97c0597 Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Fri, 18 Dec 2020 20:20:20 +0900
Subject: [PATCH 02/75] Fix typing imports

---
 pl_bolts/datamodules/async_dataloader.py       | 14 +++++++++++---
 pl_bolts/datamodules/cityscapes_datamodule.py  |  2 ++
 pl_bolts/datamodules/imagenet_datamodule.py    |  2 +-
 pl_bolts/datamodules/kitti_datamodule.py       |  3 ++-
 pl_bolts/datamodules/sklearn_datamodule.py     |  4 ++--
 .../datamodules/ssl_imagenet_datamodule.py     |  1 +
 pl_bolts/datamodules/stl10_datamodule.py       |  2 +-
 pl_bolts/datamodules/vision_datamodule.py      |  7 +------
 .../datamodules/vocdetection_datamodule.py     | 18 ++++++++++--------
 9 files changed, 31 insertions(+), 22 deletions(-)

diff --git a/pl_bolts/datamodules/async_dataloader.py b/pl_bolts/datamodules/async_dataloader.py
index 7ded9d9ef1..38a0b9bb58 100644
--- a/pl_bolts/datamodules/async_dataloader.py
+++ b/pl_bolts/datamodules/async_dataloader.py
@@ -1,10 +1,11 @@
 import re
 from queue import Queue
 from threading import Thread
+from typing import Any, Optional, Union
 
 import torch
 from torch._six import container_abcs, string_classes
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Dataset
 
 
 class AsynchronousLoader(object):
@@ -26,7 +27,14 @@ class AsynchronousLoader(object):
             constructing one here
     """
 
-    def __init__(self, data, device=torch.device('cuda', 0), q_size=10, num_batches=None, **kwargs):
+    def __init__(
+            self,
+            data: Union[DataLoader, Dataset],
+            device: torch.device = torch.device('cuda', 0),
+            q_size: int = 10,
+            num_batches: Optional[int] = None,
+            **kwargs: Any
+    ):
         if isinstance(data, torch.utils.data.DataLoader):
             self.dataloader = data
         else:
@@ -105,5 +113,5 @@ def __next__(self):
         self.idx += 1
         return out
 
-    def __len__(self):
+    def __len__(self) -> int:
         return self.num_batches
diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py
index d27bfc3196..17812a0ac5 100644
--- a/pl_bolts/datamodules/cityscapes_datamodule.py
+++ b/pl_bolts/datamodules/cityscapes_datamodule.py
@@ -1,3 +1,5 @@
+from typing import Any
+
 from pytorch_lightning import LightningDataModule
 from torch.utils.data import DataLoader
 
diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py
index a31b637ba9..829c485aed 100644
--- a/pl_bolts/datamodules/imagenet_datamodule.py
+++ b/pl_bolts/datamodules/imagenet_datamodule.py
@@ -1,5 +1,5 @@
 import os
-from typing import Optional
+from typing import Any, Optional
 
 from pytorch_lightning import LightningDataModule
 from torch.utils.data import DataLoader
diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py
index a50528028b..9a82f0b7ec 100644
--- a/pl_bolts/datamodules/kitti_datamodule.py
+++ b/pl_bolts/datamodules/kitti_datamodule.py
@@ -1,4 +1,5 @@
 import os
+from typing import Any, Optional
 
 import torch
 from pytorch_lightning import LightningDataModule
@@ -133,7 +134,7 @@ def test_dataloader(self) -> DataLoader:
         )
         return loader
 
-    def _default_transforms(self) -> transform_lib.Compose:
+    def _default_transforms(self) -> transforms.Compose:
         kitti_transforms = transforms.Compose([
             transforms.ToTensor(),
             transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753],
diff --git a/pl_bolts/datamodules/sklearn_datamodule.py b/pl_bolts/datamodules/sklearn_datamodule.py
index dcfd559441..e80e1dfc9a 100644
--- a/pl_bolts/datamodules/sklearn_datamodule.py
+++ b/pl_bolts/datamodules/sklearn_datamodule.py
@@ -1,5 +1,5 @@
 import math
-from typing import Any
+from typing import Any, Tuple
 
 import numpy as np
 import torch
@@ -145,7 +145,7 @@ def __init__(
             x_val=None, y_val=None,
             x_test=None, y_test=None,
             val_split=0.2, test_split=0.1,
-            num_workers:int = 2,
+            num_workers: int = 2,
             random_state: int = 1234,
             shuffle: bool = True,
             batch_size: int = 16,
diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py
index 50315245af..1584583101 100644
--- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py
+++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py
@@ -1,4 +1,5 @@
 import os
+from typing import Any, Optional
 
 from pytorch_lightning import LightningDataModule
 from torch.utils.data import DataLoader
diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py
index 3b29995a1b..8d46cfd7bf 100644
--- a/pl_bolts/datamodules/stl10_datamodule.py
+++ b/pl_bolts/datamodules/stl10_datamodule.py
@@ -1,5 +1,5 @@
 import os
-from typing import Optional
+from typing import Any, Optional
 
 import torch
 from pytorch_lightning import LightningDataModule
diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py
index 5b8c508904..cdcefcb2eb 100644
--- a/pl_bolts/datamodules/vision_datamodule.py
+++ b/pl_bolts/datamodules/vision_datamodule.py
@@ -9,11 +9,6 @@
 from pl_bolts.utils import _TORCHVISION_AVAILABLE
 from pl_bolts.utils.warnings import warn_missing_pkg
 
-if _TORCHVISION_AVAILABLE:
-    from torchvision import transforms as transform_lib
-else:
-    warn_missing_pkg('torchvision')  # pragma: no-cover
-
 
 class VisionDataModule(LightningDataModule):
 
@@ -123,7 +118,7 @@ def _get_splits(self, len_dataset: int) -> List[int]:
         return splits
 
     @abstractmethod
-    def default_transforms(self) -> transform_lib.Compose:
+    def default_transforms(self):
         """ Default transform for the dataset """
 
     def train_dataloader(self) -> DataLoader:
diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py
index 34dd86811e..b9071f17be 100644
--- a/pl_bolts/datamodules/vocdetection_datamodule.py
+++ b/pl_bolts/datamodules/vocdetection_datamodule.py
@@ -1,3 +1,5 @@
+from typing import Any, Dict
+
 import torch
 from pytorch_lightning import LightningDataModule
 from torch.utils.data import DataLoader
@@ -17,7 +19,7 @@ class Compose(object):
     Like `torchvision.transforms.compose` but works for (image, target)
     """
 
-    def __init__(self, transforms):
+    def __init__(self, transforms: T.Compose):
         self.transforms = transforms
 
     def __call__(self, image, target):
@@ -55,7 +57,7 @@ def _collate_fn(batch):
 )
 
 
-def _prepare_voc_instance(image, target):
+def _prepare_voc_instance(image, target: Dict[str, Any]):
     """
     Prepares VOC dataset into appropriate target for fasterrcnn
 
@@ -114,8 +116,8 @@ def __init__(
         shuffle: bool = False,
         pin_memory: bool = False,
         drop_last: bool = False,
-        *args,
-        **kwargs,
+        *args: Any,
+        **kwargs: Any,
     ):
         if not _TORCHVISION_AVAILABLE:
             raise ModuleNotFoundError(  # pragma: no-cover
@@ -133,7 +135,7 @@ def __init__(
         self.drop_last = drop_last
 
     @property
-    def num_classes(self):
+    def num_classes(self) -> int:
         """
         Return:
             21
@@ -147,7 +149,7 @@ def prepare_data(self):
         VOCDetection(self.data_dir, year=self.year, image_set="train", download=True)
         VOCDetection(self.data_dir, year=self.year, image_set="val", download=True)
 
-    def train_dataloader(self, batch_size=1, transforms=None):
+    def train_dataloader(self, batch_size: int = 1, transforms=None) -> DataLoader:
         """
         VOCDetection train set uses the `train` subset
 
@@ -175,7 +177,7 @@ def train_dataloader(self, batch_size=1, transforms=None):
         )
         return loader
 
-    def val_dataloader(self, batch_size=1, transforms=None):
+    def val_dataloader(self, batch_size: int = 1, transforms=None) -> DataLoader:
         """
         VOCDetection val set uses the `val` subset
 
@@ -202,7 +204,7 @@ def val_dataloader(self, batch_size=1, transforms=None):
         )
         return loader
 
-    def _default_transforms(self):
+    def _default_transforms(self) -> T.Compose:
         if self.normalize:
             return (
                 lambda image, target: (

From ac3377dea28a0bce8051683388d2924ebe15f24c Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Sun, 20 Dec 2020 03:03:00 +0900
Subject: [PATCH 03/75] Remove torchvision.transforms from return type annotations

---
 pl_bolts/datamodules/binary_mnist_datamodule.py  | 2 +-
 pl_bolts/datamodules/cifar10_datamodule.py       | 2 +-
 pl_bolts/datamodules/cityscapes_datamodule.py    | 4 ++--
 pl_bolts/datamodules/fashion_mnist_datamodule.py | 2 +-
 pl_bolts/datamodules/imagenet_datamodule.py      | 4 ++--
 pl_bolts/datamodules/kitti_datamodule.py         | 2 +-
 pl_bolts/datamodules/mnist_datamodule.py         | 2 +-
 pl_bolts/datamodules/ssl_imagenet_datamodule.py  | 2 +-
 pl_bolts/datamodules/stl10_datamodule.py         | 2 +-
 pl_bolts/datamodules/vocdetection_datamodule.py  | 2 +-
 10 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py
index c713abe107..8dc02ec95e 100644
--- a/pl_bolts/datamodules/binary_mnist_datamodule.py
+++ b/pl_bolts/datamodules/binary_mnist_datamodule.py
@@ -98,7 +98,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self) -> transform_lib.Compose:
+    def default_transforms(self):
         if self.normalize:
             mnist_transforms = transform_lib.Compose(
                 [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py
index 12aea1ec87..b208172ed0 100644
--- a/pl_bolts/datamodules/cifar10_datamodule.py
+++ b/pl_bolts/datamodules/cifar10_datamodule.py
@@ -112,7 +112,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self) -> transform_lib.Compose:
+    def default_transforms(self):
         if self.normalize:
             cf10_transforms = transform_lib.Compose([transform_lib.ToTensor(), cifar10_normalization()])
         else:
diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py
index 17812a0ac5..dc8b866fba 100644
--- a/pl_bolts/datamodules/cityscapes_datamodule.py
+++ b/pl_bolts/datamodules/cityscapes_datamodule.py
@@ -192,7 +192,7 @@ def test_dataloader(self) -> DataLoader:
         )
         return loader
 
-    def _default_transforms(self) -> transform_lib.Compose:
+    def _default_transforms(self):
         cityscapes_transforms = transform_lib.Compose([
             transform_lib.ToTensor(),
             transform_lib.Normalize(
@@ -202,7 +202,7 @@ def _default_transforms(self) -> transform_lib.Compose:
         ])
         return cityscapes_transforms
 
-    def _default_target_transforms(self) -> transform_lib.Compose:
+    def _default_target_transforms(self):
         cityscapes_target_trasnforms = transform_lib.Compose([
             transform_lib.ToTensor(),
             transform_lib.Lambda(lambda t: t.squeeze())
diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py
index 9a73e6a637..a128ddfaab 100644
--- a/pl_bolts/datamodules/fashion_mnist_datamodule.py
+++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py
@@ -93,7 +93,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self) -> transform_lib.Compose:
+    def default_transforms(self):
         if self.normalize:
             mnist_transforms = transform_lib.Compose(
                 [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py
index 829c485aed..f8d8262108 100644
--- a/pl_bolts/datamodules/imagenet_datamodule.py
+++ b/pl_bolts/datamodules/imagenet_datamodule.py
@@ -206,7 +206,7 @@ def test_dataloader(self) -> DataLoader:
         )
         return loader
 
-    def train_transform(self) -> transform_lib.Compose:
+    def train_transform(self):
         """
         The standard imagenet transforms
 
@@ -232,7 +232,7 @@ def train_transform(self) -> transform_lib.Compose:
 
         return preprocessing
 
-    def val_transform(self) -> transform_lib.Compose:
+    def val_transform(self):
         """
         The standard imagenet transforms for validation
 
diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py
index 9a82f0b7ec..06778fdbfc 100644
--- a/pl_bolts/datamodules/kitti_datamodule.py
+++ b/pl_bolts/datamodules/kitti_datamodule.py
@@ -134,7 +134,7 @@ def test_dataloader(self) -> DataLoader:
         )
         return loader
 
-    def _default_transforms(self) -> transforms.Compose:
+    def _default_transforms(self):
         kitti_transforms = transforms.Compose([
             transforms.ToTensor(),
             transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753],
diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py
index 76a0438a0c..c57fe8ca82 100644
--- a/pl_bolts/datamodules/mnist_datamodule.py
+++ b/pl_bolts/datamodules/mnist_datamodule.py
@@ -92,7 +92,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self) -> transform_lib.Compose:
+    def default_transforms(self):
         if self.normalize:
             mnist_transforms = transform_lib.Compose(
                 [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py
index 1584583101..96041fd4d9 100644
--- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py
+++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py
@@ -134,7 +134,7 @@ def test_dataloader(self, num_images_per_class: int, add_normalize: bool = False
         )
         return loader
 
-    def _default_transforms(self) -> transform_lib.Compose:
+    def _default_transforms(self):
         mnist_transforms = transform_lib.Compose([
             transform_lib.ToTensor(),
             imagenet_normalization()
diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py
index 8d46cfd7bf..5cd680a535 100644
--- a/pl_bolts/datamodules/stl10_datamodule.py
+++ b/pl_bolts/datamodules/stl10_datamodule.py
@@ -299,7 +299,7 @@ def val_dataloader_labeled(self) -> DataLoader:
         )
         return loader
 
-    def _default_transforms(self) -> transform_lib.Compose:
+    def _default_transforms(self):
         data_transforms = transform_lib.Compose([
             transform_lib.ToTensor(),
             stl10_normalization()
diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py
index b9071f17be..3bea4ec2d4 100644
--- a/pl_bolts/datamodules/vocdetection_datamodule.py
+++ b/pl_bolts/datamodules/vocdetection_datamodule.py
@@ -204,7 +204,7 @@ def val_dataloader(self, batch_size: int = 1, transforms=None) -> DataLoader:
         )
         return loader
 
-    def _default_transforms(self) -> T.Compose:
+    def _default_transforms(self):
         if self.normalize:
             return (
                 lambda image, target: (

From a4c39c787fc83f998f42dfdab038d4a82a079ca2 Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Sun, 20 Dec 2020 03:09:03 +0900
Subject: [PATCH 04/75] Remove more torchvision.transforms type annotations

---
 pl_bolts/datamodules/vocdetection_datamodule.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py
index 3bea4ec2d4..448df864b6 100644
--- a/pl_bolts/datamodules/vocdetection_datamodule.py
+++ b/pl_bolts/datamodules/vocdetection_datamodule.py
@@ -19,7 +19,7 @@ class Compose(object):
     Like `torchvision.transforms.compose` but works for (image, target)
     """
 
-    def __init__(self, transforms: T.Compose):
+    def __init__(self, transforms):
         self.transforms = transforms
 
     def __call__(self, image, target):

From ffa0cb9fe6e3b2cd6f730520d55a494280f9d7f5 Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Mon, 21 Dec 2020 10:13:11 +0900
Subject: [PATCH 05/75] Remove return type annotation

---
 pl_bolts/datamodules/cifar10_datamodule.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py
index b208172ed0..85ba4de6e7 100644
--- a/pl_bolts/datamodules/cifar10_datamodule.py
+++ b/pl_bolts/datamodules/cifar10_datamodule.py
@@ -71,7 +71,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ):
         """
         Args:
             data_dir: Where to save/load the data

From 5e6c5d406c0151bf4629fce60bf28dcf268fd0b9 Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Mon, 21 Dec 2020 16:00:55 +0900
Subject: [PATCH 06/75] Add `None` for optional arguments

---
 pl_bolts/datamodules/cifar10_datamodule.py | 2 +-
 pl_bolts/datamodules/kitti_datamodule.py   | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py
index 85ba4de6e7..534774684f 100644
--- a/pl_bolts/datamodules/cifar10_datamodule.py
+++ b/pl_bolts/datamodules/cifar10_datamodule.py
@@ -146,7 +146,7 @@ class TinyCIFAR10DataModule(CIFAR10DataModule):
 
     def __init__(
         self,
-        data_dir: Optional[str],
+        data_dir: Optional[str] = None,
         val_split: int = 50,
         num_workers: int = 16,
         num_samples: int = 100,
diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py
index 06778fdbfc..0067a1e53d 100644
--- a/pl_bolts/datamodules/kitti_datamodule.py
+++ b/pl_bolts/datamodules/kitti_datamodule.py
@@ -22,7 +22,7 @@ class KittiDataModule(LightningDataModule):
 
     def __init__(
             self,
-            data_dir: Optional[str],
+            data_dir: Optional[str] = None,
             val_split: float = 0.2,
             test_split: float = 0.1,
             num_workers: int = 16,

From 3a5a0ab24859431edb150df1934ff74b6b2e3b9f Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Mon, 21 Dec 2020 16:04:36 +0900
Subject: [PATCH 07/75] Remove unnecessary import

---
 pl_bolts/datamodules/vision_datamodule.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py
index cdcefcb2eb..06ddc7ab18 100644
--- a/pl_bolts/datamodules/vision_datamodule.py
+++ b/pl_bolts/datamodules/vision_datamodule.py
@@ -6,7 +6,6 @@
 from pytorch_lightning import LightningDataModule
 from torch.utils.data import DataLoader, Dataset, random_split
 
-from pl_bolts.utils import _TORCHVISION_AVAILABLE
 from pl_bolts.utils.warnings import warn_missing_pkg
 
 

From 30579ed0a42703a7207c8b9cc0afc8143891ec32 Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Mon, 21 Dec 2020 16:04:36 +0900
Subject: [PATCH 08/75] Remove unnecessary import

---
 pl_bolts/datamodules/vision_datamodule.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py
index cdcefcb2eb..42252c4edf 100644
--- a/pl_bolts/datamodules/vision_datamodule.py
+++ b/pl_bolts/datamodules/vision_datamodule.py
@@ -6,9 +6,6 @@
 from pytorch_lightning import LightningDataModule
 from torch.utils.data import DataLoader, Dataset, random_split
 
-from pl_bolts.utils import _TORCHVISION_AVAILABLE
-from pl_bolts.utils.warnings import warn_missing_pkg
-
 
 class VisionDataModule(LightningDataModule):
 

From c6759311a1a24b0cc6cb8d350d96cdc582f26317 Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Mon, 21 Dec 2020 16:51:36 +0900
Subject: [PATCH 09/75] Add `None` return type

---
 pl_bolts/datamodules/async_dataloader.py         |  4 ++--
 pl_bolts/datamodules/binary_mnist_datamodule.py  |  2 +-
 pl_bolts/datamodules/cifar10_datamodule.py       |  4 ++--
 pl_bolts/datamodules/cityscapes_datamodule.py    |  2 +-
 pl_bolts/datamodules/experience_source.py        |  4 ++--
 pl_bolts/datamodules/fashion_mnist_datamodule.py |  2 +-
 pl_bolts/datamodules/imagenet_datamodule.py      |  6 +++---
 pl_bolts/datamodules/kitti_datamodule.py         |  2 +-
 pl_bolts/datamodules/mnist_datamodule.py         |  2 +-
 pl_bolts/datamodules/sklearn_datamodule.py       | 12 ++++++++----
 pl_bolts/datamodules/ssl_imagenet_datamodule.py  |  6 +++---
 pl_bolts/datamodules/stl10_datamodule.py         |  4 ++--
 pl_bolts/datamodules/vision_datamodule.py        |  6 +++---
 pl_bolts/datamodules/vocdetection_datamodule.py  |  6 +++---
 14 files changed, 33 insertions(+), 29 deletions(-)

diff --git a/pl_bolts/datamodules/async_dataloader.py b/pl_bolts/datamodules/async_dataloader.py
index 38a0b9bb58..224f34d5ee 100644
--- a/pl_bolts/datamodules/async_dataloader.py
+++ b/pl_bolts/datamodules/async_dataloader.py
@@ -34,7 +34,7 @@ def __init__(
             q_size: int = 10,
             num_batches: Optional[int] = None,
             **kwargs: Any
-    ):
+    ) -> None:
         if isinstance(data, torch.utils.data.DataLoader):
             self.dataloader = data
         else:
@@ -57,7 +57,7 @@ def __init__(
 
         self.np_str_obj_array_pattern = re.compile(r'[SaUO]')
 
-    def load_loop(self):  # The loop that will load into the queue in the background
+    def load_loop(self) -> None:  # The loop that will load into the queue in the background
         for i, sample in enumerate(self.dataloader):
             self.queue.put(self.load_instance(sample))
             if i == len(self):
diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py
index 8dc02ec95e..142b3d54ef 100644
--- a/pl_bolts/datamodules/binary_mnist_datamodule.py
+++ b/pl_bolts/datamodules/binary_mnist_datamodule.py
@@ -56,7 +56,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: Where to save/load the data
diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py
index 534774684f..2cb894d749 100644
--- a/pl_bolts/datamodules/cifar10_datamodule.py
+++ b/pl_bolts/datamodules/cifar10_datamodule.py
@@ -71,7 +71,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: Where to save/load the data
@@ -153,7 +153,7 @@ def __init__(
         labels: Optional[Sequence] = (1, 5, 8),
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: where to save/load the data
diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py
index dc8b866fba..61c1ae2bef 100644
--- a/pl_bolts/datamodules/cityscapes_datamodule.py
+++ b/pl_bolts/datamodules/cityscapes_datamodule.py
@@ -73,7 +73,7 @@ def __init__(
             drop_last: bool = False,
             *args: Any,
             **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: where to load the data from path, i.e. where directory leftImg8bit and gtFine or gtCoarse
diff --git a/pl_bolts/datamodules/experience_source.py b/pl_bolts/datamodules/experience_source.py
index 5fe1332dfd..fac4e82f25 100644
--- a/pl_bolts/datamodules/experience_source.py
+++ b/pl_bolts/datamodules/experience_source.py
@@ -30,7 +30,7 @@ class ExperienceSourceDataset(IterableDataset):
     The logic for the experience source and how the batch is generated is defined the Lightning model itself
     """
 
-    def __init__(self, generate_batch: Callable):
+    def __init__(self, generate_batch: Callable) -> None:
         self.generate_batch = generate_batch
 
     def __iter__(self) -> Iterable:
@@ -243,7 +243,7 @@ def pop_rewards_steps(self):
 class DiscountedExperienceSource(ExperienceSource):
     """Outputs experiences with a discounted reward over N steps"""
 
-    def __init__(self, env: Env, agent, n_steps: int = 1, gamma: float = 0.99):
+    def __init__(self, env: Env, agent, n_steps: int = 1, gamma: float = 0.99) -> None:
         super().__init__(env, agent, (n_steps + 1))
         self.gamma = gamma
         self.steps = n_steps
diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py
index a128ddfaab..833c4599a6 100644
--- a/pl_bolts/datamodules/fashion_mnist_datamodule.py
+++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py
@@ -57,7 +57,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: Where to save/load the data
diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py
index f8d8262108..6f06913f9f 100644
--- a/pl_bolts/datamodules/imagenet_datamodule.py
+++ b/pl_bolts/datamodules/imagenet_datamodule.py
@@ -60,7 +60,7 @@ def __init__(
             drop_last: bool = False,
             *args: Any,
             **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: path to the imagenet dataset file
@@ -103,14 +103,14 @@ def num_classes(self) -> int:
         """
         return 1000
 
-    def _verify_splits(self, data_dir: str, split: str):
+    def _verify_splits(self, data_dir: str, split: str) -> None:
         dirs = os.listdir(data_dir)
 
         if split not in dirs:
             raise FileNotFoundError(f'a {split} Imagenet split was not found in {data_dir},'
                                     f' make sure the folder contains a subfolder named {split}')
 
-    def prepare_data(self):
+    def prepare_data(self) -> None:
         """
         This method already assumes you have imagenet2012 downloaded.
         It validates the data using the meta.bin.
diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py
index 0067a1e53d..3cf26dc762 100644
--- a/pl_bolts/datamodules/kitti_datamodule.py
+++ b/pl_bolts/datamodules/kitti_datamodule.py
@@ -33,7 +33,7 @@ def __init__(
             drop_last: bool = False,
             *args: Any,
             **kwargs: Any,
-    ):
+    ) -> None:
         """
         Kitti train, validation and test dataloaders.
 
diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py
index c57fe8ca82..1dd5e927b6 100644
--- a/pl_bolts/datamodules/mnist_datamodule.py
+++ b/pl_bolts/datamodules/mnist_datamodule.py
@@ -56,7 +56,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: Where to save/load the data
diff --git a/pl_bolts/datamodules/sklearn_datamodule.py b/pl_bolts/datamodules/sklearn_datamodule.py
index e80e1dfc9a..d9477acc0b 100644
--- a/pl_bolts/datamodules/sklearn_datamodule.py
+++ b/pl_bolts/datamodules/sklearn_datamodule.py
@@ -28,7 +28,9 @@ class SklearnDataset(Dataset):
         >>> len(dataset)
         506
     """
-    def __init__(self, X: np.ndarray, y: np.ndarray, X_transform: Any = None, y_transform: Any = None):
+    def __init__(
+        self, X: np.ndarray, y: np.ndarray, X_transform: Any = None, y_transform: Any = None
+    ) -> None:
         """
         Args:
             X: Numpy ndarray
@@ -75,7 +77,9 @@ class TensorDataset(Dataset):
         >>> len(dataset)
         10
     """
-    def __init__(self, X: torch.Tensor, y: torch.Tensor, X_transform: Any = None, y_transform: Any = None):
+    def __init__(
+        self, X: torch.Tensor, y: torch.Tensor, X_transform: Any = None, y_transform: Any = None
+    ) -> None:
         """
         Args:
             X: PyTorch tensor
@@ -153,7 +157,7 @@ def __init__(
             drop_last: bool = False,
             *args: Any,
             **kwargs: Any,
-    ):
+    ) -> None:
 
         super().__init__(*args, **kwargs)
         self.num_workers = num_workers
@@ -201,7 +205,7 @@ def _init_datasets(
         y_val: np.ndarray,
         x_test: np.ndarray,
         y_test: np.ndarray
-    ):
+    ) -> None:
         self.train_dataset = SklearnDataset(X, y)
         self.val_dataset = SklearnDataset(x_val, y_val)
         self.test_dataset = SklearnDataset(x_test, y_test)
diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py
index 96041fd4d9..3dbda03527 100644
--- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py
+++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py
@@ -30,7 +30,7 @@ def __init__(
             drop_last: bool = False,
             *args: Any,
             **kwargs: Any,
-    ):
+    ) -> None:
         super().__init__(*args, **kwargs)
 
         if not _TORCHVISION_AVAILABLE:
@@ -50,14 +50,14 @@ def __init__(
     def num_classes(self) -> int:
         return 1000
 
-    def _verify_splits(self, data_dir: str, split: str):
+    def _verify_splits(self, data_dir: str, split: str) -> None:
         dirs = os.listdir(data_dir)
 
         if split not in dirs:
             raise FileNotFoundError(f'a {split} Imagenet split was not found in {data_dir}, make sure the'
                                     f' folder contains a subfolder named {split}')
 
-    def prepare_data(self):
+    def prepare_data(self) -> None:
         # imagenet cannot be downloaded... must provide path to folder with the train/val splits
         self._verify_splits(self.data_dir, 'train')
         self._verify_splits(self.data_dir, 'val')
diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py
index 5cd680a535..79420be149 100644
--- a/pl_bolts/datamodules/stl10_datamodule.py
+++ b/pl_bolts/datamodules/stl10_datamodule.py
@@ -65,7 +65,7 @@ def __init__(
             drop_last: bool = False,
             *args: Any,
             **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: where to save/load the data
@@ -102,7 +102,7 @@ def __init__(
     def num_classes(self) -> int:
         return 10
 
-    def prepare_data(self):
+    def prepare_data(self) -> None:
         """
         Downloads the unlabeled, train and test split
         """
diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py
index 42252c4edf..2144f0f509 100644
--- a/pl_bolts/datamodules/vision_datamodule.py
+++ b/pl_bolts/datamodules/vision_datamodule.py
@@ -29,7 +29,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: Where to save/load the data
@@ -56,14 +56,14 @@ def __init__(
         self.pin_memory = pin_memory
         self.drop_last = drop_last
 
-    def prepare_data(self):
+    def prepare_data(self) -> None:
         """
         Saves files to data_dir
         """
         self.dataset_cls(self.data_dir, train=True, download=True)
         self.dataset_cls(self.data_dir, train=False, download=True)
 
-    def setup(self, stage: Optional[str] = None):
+    def setup(self, stage: Optional[str] = None) -> None:
         """
         Creates train, val, and test dataset
         """
diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py
index 448df864b6..e7fc989330 100644
--- a/pl_bolts/datamodules/vocdetection_datamodule.py
+++ b/pl_bolts/datamodules/vocdetection_datamodule.py
@@ -19,7 +19,7 @@ class Compose(object):
     Like `torchvision.transforms.compose` but works for (image, target)
     """
 
-    def __init__(self, transforms):
+    def __init__(self, transforms) -> None:
         self.transforms = transforms
 
     def __call__(self, image, target):
@@ -118,7 +118,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         if not _TORCHVISION_AVAILABLE:
             raise ModuleNotFoundError(  # pragma: no-cover
                 'You want to use VOC dataset loaded from `torchvision` which is not installed yet.'
@@ -142,7 +142,7 @@ def num_classes(self) -> int:
         """
         return 21
 
-    def prepare_data(self):
+    def prepare_data(self) -> None:
         """
         Saves VOCDetection files to data_dir
         """

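Taken together, patches 01-09 settle on a consistent annotation style: `__init__` and other procedures return `None`, dataloader methods return `DataLoader`, transform builders return a torchvision `Compose`, and `num_classes` returns `int`. Below is a minimal sketch of that style on a toy datamodule; the class, names, and data are illustrative assumptions, not code from the repository.

```python
# Minimal sketch (assumed names and data, not repository code) of the
# annotation style introduced by the preceding patches.
from typing import Any, Optional

import torch
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, TensorDataset
from torchvision import transforms as transform_lib


class TypedToyDataModule(LightningDataModule):

    def __init__(self, data_dir: Optional[str] = None, batch_size: int = 32, **kwargs: Any) -> None:
        # Procedures and __init__ are annotated with an explicit `-> None`.
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

    @property
    def num_classes(self) -> int:
        # Plain ints for class counts, mirroring the datamodules above.
        return 10

    def default_transforms(self) -> transform_lib.Compose:
        # Transform builders return a torchvision Compose.
        return transform_lib.Compose([transform_lib.ToTensor()])

    def train_dataloader(self) -> DataLoader:
        # Dataloader methods return a DataLoader.
        dataset = TensorDataset(torch.randn(8, 3), torch.zeros(8, dtype=torch.long))
        return DataLoader(dataset, batch_size=self.batch_size)
```
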
From 267649c125273644deeb31746e54d5c5d95d8357 Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Tue, 5 Jan 2021 20:30:43 +0900
Subject: [PATCH 10/75] Add type for torchvision transforms

---
 pl_bolts/datamodules/binary_mnist_datamodule.py  | 4 +++-
 pl_bolts/datamodules/cifar10_datamodule.py       | 4 +++-
 pl_bolts/datamodules/cityscapes_datamodule.py    | 6 ++++--
 pl_bolts/datamodules/fashion_mnist_datamodule.py | 4 +++-
 pl_bolts/datamodules/imagenet_datamodule.py      | 6 ++++--
 pl_bolts/datamodules/kitti_datamodule.py         | 4 +++-
 pl_bolts/datamodules/mnist_datamodule.py         | 4 +++-
 pl_bolts/datamodules/ssl_imagenet_datamodule.py  | 4 +++-
 pl_bolts/datamodules/stl10_datamodule.py         | 4 +++-
 pl_bolts/datamodules/vision_datamodule.py        | 9 ++++++++-
 10 files changed, 37 insertions(+), 12 deletions(-)

diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py
index 142b3d54ef..ad17360d08 100644
--- a/pl_bolts/datamodules/binary_mnist_datamodule.py
+++ b/pl_bolts/datamodules/binary_mnist_datamodule.py
@@ -7,8 +7,10 @@
 
 if _TORCHVISION_AVAILABLE:
     from torchvision import transforms as transform_lib
+    from torchvision.transforms import Compose
 else:  # pragma: no-cover
     warn_missing_pkg('torchvision')
+    Compose = object
 
 
 class BinaryMNISTDataModule(VisionDataModule):
@@ -98,7 +100,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self):
+    def default_transforms(self) -> Compose:
         if self.normalize:
             mnist_transforms = transform_lib.Compose(
                 [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py
index 2cb894d749..9dbf10b670 100644
--- a/pl_bolts/datamodules/cifar10_datamodule.py
+++ b/pl_bolts/datamodules/cifar10_datamodule.py
@@ -9,9 +9,11 @@
 if _TORCHVISION_AVAILABLE:
     from torchvision import transforms as transform_lib
     from torchvision.datasets import CIFAR10
+    from torchvision.transforms import Compose
 else:
     warn_missing_pkg('torchvision')  # pragma: no-cover
     CIFAR10 = None
+    Compose = object
 
 
 class CIFAR10DataModule(VisionDataModule):
@@ -112,7 +114,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self):
+    def default_transforms(self) -> Compose:
         if self.normalize:
             cf10_transforms = transform_lib.Compose([transform_lib.ToTensor(), cifar10_normalization()])
         else:
diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py
index 61c1ae2bef..130e333976 100644
--- a/pl_bolts/datamodules/cityscapes_datamodule.py
+++ b/pl_bolts/datamodules/cityscapes_datamodule.py
@@ -9,8 +9,10 @@
 if _TORCHVISION_AVAILABLE:
     from torchvision import transforms as transform_lib
     from torchvision.datasets import Cityscapes
+    from torchvision.transforms import Compose
 else:
     warn_missing_pkg('torchvision')  # pragma: no-cover
+    Compose = object
 
 
 class CityscapesDataModule(LightningDataModule):
@@ -192,7 +194,7 @@ def test_dataloader(self) -> DataLoader:
         )
         return loader
 
-    def _default_transforms(self):
+    def _default_transforms(self) -> Compose:
         cityscapes_transforms = transform_lib.Compose([
             transform_lib.ToTensor(),
             transform_lib.Normalize(
@@ -202,7 +204,7 @@ def _default_transforms(self):
         ])
         return cityscapes_transforms
 
-    def _default_target_transforms(self):
+    def _default_target_transforms(self) -> Compose:
         cityscapes_target_trasnforms = transform_lib.Compose([
             transform_lib.ToTensor(),
             transform_lib.Lambda(lambda t: t.squeeze())
diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py
index 833c4599a6..b37221bc74 100644
--- a/pl_bolts/datamodules/fashion_mnist_datamodule.py
+++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py
@@ -7,9 +7,11 @@
 if _TORCHVISION_AVAILABLE:
     from torchvision import transforms as transform_lib
     from torchvision.datasets import FashionMNIST
+    from torchvision.transforms import Compose
 else:
     warn_missing_pkg('torchvision')  # pragma: no-cover
     FashionMNIST = None
+    Compose = object
 
 
 class FashionMNISTDataModule(VisionDataModule):
@@ -93,7 +95,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self):
+    def default_transforms(self) -> Compose:
         if self.normalize:
             mnist_transforms = transform_lib.Compose(
                 [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py
index 6f06913f9f..db2fc68c0b 100644
--- a/pl_bolts/datamodules/imagenet_datamodule.py
+++ b/pl_bolts/datamodules/imagenet_datamodule.py
@@ -11,8 +11,10 @@
 
 if _TORCHVISION_AVAILABLE:
     from torchvision import transforms as transform_lib
+    from torchvision.transforms import Compose
 else:
     warn_missing_pkg('torchvision')  # pragma: no-cover
+    Compose = object
 
 
 class ImagenetDataModule(LightningDataModule):
@@ -206,7 +208,7 @@ def test_dataloader(self) -> DataLoader:
         )
         return loader
 
-    def train_transform(self):
+    def train_transform(self) -> Compose:
         """
         The standard imagenet transforms
 
@@ -232,7 +234,7 @@ def train_transform(self):
 
         return preprocessing
 
-    def val_transform(self):
+    def val_transform(self) -> Compose:
         """
         The standard imagenet transforms for validation
 
diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py
index 3cf26dc762..b63040f3bf 100644
--- a/pl_bolts/datamodules/kitti_datamodule.py
+++ b/pl_bolts/datamodules/kitti_datamodule.py
@@ -12,8 +12,10 @@
 
 if _TORCHVISION_AVAILABLE:
     import torchvision.transforms as transforms
+    from torchvision.transforms import Compose
 else:  # pragma: no cover
     warn_missing_pkg('torchvision')
+    Compose = object
 
 
 class KittiDataModule(LightningDataModule):
@@ -134,7 +136,7 @@ def test_dataloader(self) -> DataLoader:
         )
         return loader
 
-    def _default_transforms(self):
+    def _default_transforms(self) -> Compose:
         kitti_transforms = transforms.Compose([
             transforms.ToTensor(),
             transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753],
diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py
index 1dd5e927b6..711460023c 100644
--- a/pl_bolts/datamodules/mnist_datamodule.py
+++ b/pl_bolts/datamodules/mnist_datamodule.py
@@ -7,9 +7,11 @@
 if _TORCHVISION_AVAILABLE:
     from torchvision import transforms as transform_lib
     from torchvision.datasets import MNIST
+    from torchvision.transforms import Compose
 else:
     warn_missing_pkg('torchvision')  # pragma: no-cover
     MNIST = None
+    Compose = object
 
 
 class MNISTDataModule(VisionDataModule):
@@ -92,7 +94,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self):
+    def default_transforms(self) -> Compose:
         if self.normalize:
             mnist_transforms = transform_lib.Compose(
                 [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py
index 3dbda03527..d575eb2d01 100644
--- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py
+++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py
@@ -11,8 +11,10 @@
 
 if _TORCHVISION_AVAILABLE:
     from torchvision import transforms as transform_lib
+    from torchvision.transforms import Compose
 else:
     warn_missing_pkg('torchvision')  # pragma: no-cover
+    Compose = object
 
 
 class SSLImagenetDataModule(LightningDataModule):  # pragma: no cover
@@ -134,7 +136,7 @@ def test_dataloader(self, num_images_per_class: int, add_normalize: bool = False
         )
         return loader
 
-    def _default_transforms(self):
+    def _default_transforms(self) -> Compose:
         mnist_transforms = transform_lib.Compose([
             transform_lib.ToTensor(),
             imagenet_normalization()
diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py
index 79420be149..f9ac77e140 100644
--- a/pl_bolts/datamodules/stl10_datamodule.py
+++ b/pl_bolts/datamodules/stl10_datamodule.py
@@ -13,8 +13,10 @@
 if _TORCHVISION_AVAILABLE:
     from torchvision import transforms as transform_lib
     from torchvision.datasets import STL10
+    from torchvision.transforms import Compose
 else:
     warn_missing_pkg('torchvision')  # pragma: no-cover
+    Compose = object
 
 
 class STL10DataModule(LightningDataModule):  # pragma: no cover
@@ -299,7 +301,7 @@ def val_dataloader_labeled(self) -> DataLoader:
         )
         return loader
 
-    def _default_transforms(self):
+    def _default_transforms(self) -> Compose:
         data_transforms = transform_lib.Compose([
             transform_lib.ToTensor(),
             stl10_normalization()
diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py
index 2144f0f509..92e6723968 100644
--- a/pl_bolts/datamodules/vision_datamodule.py
+++ b/pl_bolts/datamodules/vision_datamodule.py
@@ -6,6 +6,13 @@
 from pytorch_lightning import LightningDataModule
 from torch.utils.data import DataLoader, Dataset, random_split
 
+from pl_bolts.utils import _TORCHVISION_AVAILABLE
+
+if _TORCHVISION_AVAILABLE:
+    from torchvision.transforms import Compose
+else:
+    Compose = object
+
 
 class VisionDataModule(LightningDataModule):
 
@@ -115,7 +122,7 @@ def _get_splits(self, len_dataset: int) -> List[int]:
         return splits
 
     @abstractmethod
-    def default_transforms(self):
+    def default_transforms(self) -> Compose:
         """ Default transform for the dataset """
 
     def train_dataloader(self) -> DataLoader:
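
Editor's note: patch 10 introduces the optional-import guard shown above: `Compose` is imported when torchvision is available and is bound to `object` otherwise, so the `-> Compose` annotations still resolve at import time. A condensed sketch of the same pattern follows; the free function `default_mnist_transforms` is a hypothetical stand-in for the `default_transforms` methods in the patch and assumes torchvision is actually installed when it is called.

    from pl_bolts.utils import _TORCHVISION_AVAILABLE
    from pl_bolts.utils.warnings import warn_missing_pkg

    if _TORCHVISION_AVAILABLE:
        from torchvision import transforms as transform_lib
        from torchvision.transforms import Compose
    else:  # pragma: no cover
        warn_missing_pkg('torchvision')
        Compose = object  # placeholder so `-> Compose` annotations still import


    def default_mnist_transforms() -> Compose:
        """Hypothetical free function mirroring `default_transforms` above."""
        return transform_lib.Compose([
            transform_lib.ToTensor(),
            transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
        ])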

From 3938fff80783b67ed26245a73dce63e7c52f42fb Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Fri, 18 Dec 2020 19:58:16 +0900
Subject: [PATCH 11/75] Adding types to datamodules

---
 .../datamodules/binary_mnist_datamodule.py    |  4 +-
 pl_bolts/datamodules/cifar10_datamodule.py    |  6 +--
 pl_bolts/datamodules/cityscapes_datamodule.py | 16 ++++----
 .../datamodules/fashion_mnist_datamodule.py   |  4 +-
 pl_bolts/datamodules/imagenet_datamodule.py   | 18 ++++-----
 pl_bolts/datamodules/kitti_datamodule.py      | 14 +++----
 pl_bolts/datamodules/mnist_datamodule.py      |  4 +-
 pl_bolts/datamodules/sklearn_datamodule.py    | 38 +++++++++++--------
 .../datamodules/ssl_imagenet_datamodule.py    | 22 +++++------
 pl_bolts/datamodules/stl10_datamodule.py      | 22 +++++------
 pl_bolts/datamodules/vision_datamodule.py     | 16 ++++++--
 11 files changed, 90 insertions(+), 74 deletions(-)

diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py
index 142b3d54ef..c713abe107 100644
--- a/pl_bolts/datamodules/binary_mnist_datamodule.py
+++ b/pl_bolts/datamodules/binary_mnist_datamodule.py
@@ -56,7 +56,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ):
         """
         Args:
             data_dir: Where to save/load the data
@@ -98,7 +98,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self):
+    def default_transforms(self) -> transform_lib.Compose:
         if self.normalize:
             mnist_transforms = transform_lib.Compose(
                 [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py
index afb2df8c9a..12aea1ec87 100644
--- a/pl_bolts/datamodules/cifar10_datamodule.py
+++ b/pl_bolts/datamodules/cifar10_datamodule.py
@@ -112,7 +112,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self):
+    def default_transforms(self) -> transform_lib.Compose:
         if self.normalize:
             cf10_transforms = transform_lib.Compose([transform_lib.ToTensor(), cifar10_normalization()])
         else:
@@ -146,14 +146,14 @@ class TinyCIFAR10DataModule(CIFAR10DataModule):
 
     def __init__(
         self,
-        data_dir: str,
+        data_dir: Optional[str],
         val_split: int = 50,
         num_workers: int = 16,
         num_samples: int = 100,
         labels: Optional[Sequence] = (1, 5, 8),
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ):
         """
         Args:
             data_dir: where to save/load the data
diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py
index b851617225..d27bfc3196 100644
--- a/pl_bolts/datamodules/cityscapes_datamodule.py
+++ b/pl_bolts/datamodules/cityscapes_datamodule.py
@@ -69,8 +69,8 @@ def __init__(
             shuffle: bool = False,
             pin_memory: bool = False,
             drop_last: bool = False,
-            *args,
-            **kwargs,
+            *args: Any,
+            **kwargs: Any,
     ):
         """
         Args:
@@ -109,14 +109,14 @@ def __init__(
         self.target_transforms = None
 
     @property
-    def num_classes(self):
+    def num_classes(self) -> int:
         """
         Return:
             30
         """
         return 30
 
-    def train_dataloader(self):
+    def train_dataloader(self) -> DataLoader:
         """
         Cityscapes train set
         """
@@ -141,7 +141,7 @@ def train_dataloader(self):
         )
         return loader
 
-    def val_dataloader(self):
+    def val_dataloader(self) -> DataLoader:
         """
         Cityscapes val set
         """
@@ -166,7 +166,7 @@ def val_dataloader(self):
         )
         return loader
 
-    def test_dataloader(self):
+    def test_dataloader(self) -> DataLoader:
         """
         Cityscapes test set
         """
@@ -190,7 +190,7 @@ def test_dataloader(self):
         )
         return loader
 
-    def _default_transforms(self):
+    def _default_transforms(self) -> transform_lib.Compose:
         cityscapes_transforms = transform_lib.Compose([
             transform_lib.ToTensor(),
             transform_lib.Normalize(
@@ -200,7 +200,7 @@ def _default_transforms(self):
         ])
         return cityscapes_transforms
 
-    def _default_target_transforms(self):
+    def _default_target_transforms(self) -> transform_lib.Compose:
         cityscapes_target_trasnforms = transform_lib.Compose([
             transform_lib.ToTensor(),
             transform_lib.Lambda(lambda t: t.squeeze())
diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py
index 833c4599a6..9a73e6a637 100644
--- a/pl_bolts/datamodules/fashion_mnist_datamodule.py
+++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py
@@ -57,7 +57,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ):
         """
         Args:
             data_dir: Where to save/load the data
@@ -93,7 +93,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self):
+    def default_transforms(self) -> transform_lib.Compose:
         if self.normalize:
             mnist_transforms = transform_lib.Compose(
                 [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py
index 38546e29ee..a31b637ba9 100644
--- a/pl_bolts/datamodules/imagenet_datamodule.py
+++ b/pl_bolts/datamodules/imagenet_datamodule.py
@@ -58,8 +58,8 @@ def __init__(
             shuffle: bool = False,
             pin_memory: bool = False,
             drop_last: bool = False,
-            *args,
-            **kwargs,
+            *args: Any,
+            **kwargs: Any,
     ):
         """
         Args:
@@ -94,7 +94,7 @@ def __init__(
         self.num_samples = 1281167 - self.num_imgs_per_val_class * self.num_classes
 
     @property
-    def num_classes(self):
+    def num_classes(self) -> int:
         """
         Return:
 
@@ -103,7 +103,7 @@ def num_classes(self):
         """
         return 1000
 
-    def _verify_splits(self, data_dir, split):
+    def _verify_splits(self, data_dir: str, split: str):
         dirs = os.listdir(data_dir)
 
         if split not in dirs:
@@ -138,7 +138,7 @@ def prepare_data(self):
                 UnlabeledImagenet.generate_meta_bins(path)
                 """)
 
-    def train_dataloader(self):
+    def train_dataloader(self) -> DataLoader:
         """
         Uses the train split of imagenet2012 and puts away a portion of it for the validation split
         """
@@ -160,7 +160,7 @@ def train_dataloader(self):
         )
         return loader
 
-    def val_dataloader(self):
+    def val_dataloader(self) -> DataLoader:
         """
         Uses the part of the train split of imagenet2012  that was not used for training via `num_imgs_per_val_class`
 
@@ -185,7 +185,7 @@ def val_dataloader(self):
         )
         return loader
 
-    def test_dataloader(self):
+    def test_dataloader(self) -> DataLoader:
         """
         Uses the validation split of imagenet2012 for testing
         """
@@ -206,7 +206,7 @@ def test_dataloader(self):
         )
         return loader
 
-    def train_transform(self):
+    def train_transform(self) -> transform_lib.Compose:
         """
         The standard imagenet transforms
 
@@ -232,7 +232,7 @@ def train_transform(self):
 
         return preprocessing
 
-    def val_transform(self):
+    def val_transform(self) -> transform_lib.Compose:
         """
         The standard imagenet transforms for validation
 
diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py
index 433e7fffed..a50528028b 100644
--- a/pl_bolts/datamodules/kitti_datamodule.py
+++ b/pl_bolts/datamodules/kitti_datamodule.py
@@ -21,7 +21,7 @@ class KittiDataModule(LightningDataModule):
 
     def __init__(
             self,
-            data_dir: str,
+            data_dir: Optional[str],
             val_split: float = 0.2,
             test_split: float = 0.1,
             num_workers: int = 16,
@@ -30,8 +30,8 @@ def __init__(
             shuffle: bool = False,
             pin_memory: bool = False,
             drop_last: bool = False,
-            *args,
-            **kwargs,
+            *args: Any,
+            **kwargs: Any,
     ):
         """
         Kitti train, validation and test dataloaders.
@@ -100,7 +100,7 @@ def __init__(
                                                                 lengths=[train_len, val_len, test_len],
                                                                 generator=torch.Generator().manual_seed(self.seed))
 
-    def train_dataloader(self):
+    def train_dataloader(self) -> DataLoader:
         loader = DataLoader(
             self.trainset,
             batch_size=self.batch_size,
@@ -111,7 +111,7 @@ def train_dataloader(self):
         )
         return loader
 
-    def val_dataloader(self):
+    def val_dataloader(self) -> DataLoader:
         loader = DataLoader(
             self.valset,
             batch_size=self.batch_size,
@@ -122,7 +122,7 @@ def val_dataloader(self):
         )
         return loader
 
-    def test_dataloader(self):
+    def test_dataloader(self) -> DataLoader:
         loader = DataLoader(
             self.testset,
             batch_size=self.batch_size,
@@ -133,7 +133,7 @@ def test_dataloader(self):
         )
         return loader
 
-    def _default_transforms(self):
+    def _default_transforms(self) -> transform_lib.Compose:
         kitti_transforms = transforms.Compose([
             transforms.ToTensor(),
             transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753],
diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py
index 1dd5e927b6..76a0438a0c 100644
--- a/pl_bolts/datamodules/mnist_datamodule.py
+++ b/pl_bolts/datamodules/mnist_datamodule.py
@@ -56,7 +56,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ):
         """
         Args:
             data_dir: Where to save/load the data
@@ -92,7 +92,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self):
+    def default_transforms(self) -> transform_lib.Compose:
         if self.normalize:
             mnist_transforms = transform_lib.Compose(
                 [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
diff --git a/pl_bolts/datamodules/sklearn_datamodule.py b/pl_bolts/datamodules/sklearn_datamodule.py
index ed262b10c8..dcfd559441 100644
--- a/pl_bolts/datamodules/sklearn_datamodule.py
+++ b/pl_bolts/datamodules/sklearn_datamodule.py
@@ -42,10 +42,10 @@ def __init__(self, X: np.ndarray, y: np.ndarray, X_transform: Any = None, y_tran
         self.X_transform = X_transform
         self.y_transform = y_transform
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.X)
 
-    def __getitem__(self, idx):
+    def __getitem__(self, idx) -> Tuple[np.ndarray, np.ndarray]:
         x = self.X[idx].astype(np.float32)
         y = self.Y[idx]
 
@@ -89,10 +89,10 @@ def __init__(self, X: torch.Tensor, y: torch.Tensor, X_transform: Any = None, y_
         self.X_transform = X_transform
         self.y_transform = y_transform
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.X)
 
-    def __getitem__(self, idx):
+    def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor]:
         x = self.X[idx].float()
         y = self.Y[idx]
 
@@ -145,14 +145,14 @@ def __init__(
             x_val=None, y_val=None,
             x_test=None, y_test=None,
             val_split=0.2, test_split=0.1,
-            num_workers=2,
-            random_state=1234,
-            shuffle=True,
+            num_workers:int = 2,
+            random_state: int = 1234,
+            shuffle: bool = True,
             batch_size: int = 16,
-            pin_memory=False,
-            drop_last=False,
-            *args,
-            **kwargs,
+            pin_memory: bool = False,
+            drop_last: bool = False,
+            *args: Any,
+            **kwargs: Any,
     ):
 
         super().__init__(*args, **kwargs)
@@ -193,12 +193,20 @@ def __init__(
 
         self._init_datasets(X, y, x_val, y_val, x_test, y_test)
 
-    def _init_datasets(self, X, y, x_val, y_val, x_test, y_test):
+    def _init_datasets(
+        self,
+        X: np.ndarray,
+        y: np.ndarray,
+        x_val: np.ndarray,
+        y_val: np.ndarray,
+        x_test: np.ndarray,
+        y_test: np.ndarray
+    ):
         self.train_dataset = SklearnDataset(X, y)
         self.val_dataset = SklearnDataset(x_val, y_val)
         self.test_dataset = SklearnDataset(x_test, y_test)
 
-    def train_dataloader(self):
+    def train_dataloader(self) -> DataLoader:
         loader = DataLoader(
             self.train_dataset,
             batch_size=self.batch_size,
@@ -209,7 +217,7 @@ def train_dataloader(self):
         )
         return loader
 
-    def val_dataloader(self):
+    def val_dataloader(self) -> DataLoader:
         loader = DataLoader(
             self.val_dataset,
             batch_size=self.batch_size,
@@ -220,7 +228,7 @@ def val_dataloader(self):
         )
         return loader
 
-    def test_dataloader(self):
+    def test_dataloader(self) -> DataLoader:
         loader = DataLoader(
             self.test_dataset,
             batch_size=self.batch_size,
diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py
index 03e459fd5e..50315245af 100644
--- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py
+++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py
@@ -20,15 +20,15 @@ class SSLImagenetDataModule(LightningDataModule):  # pragma: no cover
 
     def __init__(
             self,
-            data_dir,
-            meta_dir=None,
-            num_workers=16,
+            data_dir: str,
+            meta_dir: Optional[str] = None,
+            num_workers: int = 16,
             batch_size: int = 32,
             shuffle: bool = False,
             pin_memory: bool = False,
             drop_last: bool = False,
-            *args,
-            **kwargs,
+            *args: Any,
+            **kwargs: Any,
     ):
         super().__init__(*args, **kwargs)
 
@@ -46,10 +46,10 @@ def __init__(
         self.drop_last = drop_last
 
     @property
-    def num_classes(self):
+    def num_classes(self) -> int:
         return 1000
 
-    def _verify_splits(self, data_dir, split):
+    def _verify_splits(self, data_dir: str, split: str):
         dirs = os.listdir(data_dir)
 
         if split not in dirs:
@@ -79,7 +79,7 @@ def prepare_data(self):
                 UnlabeledImagenet.generate_meta_bins(path)
                 """)
 
-    def train_dataloader(self, num_images_per_class=-1, add_normalize=False):
+    def train_dataloader(self, num_images_per_class: int = -1, add_normalize: bool = False) -> DataLoader:
         transforms = self._default_transforms() if self.train_transforms is None else self.train_transforms
 
         dataset = UnlabeledImagenet(self.data_dir,
@@ -97,7 +97,7 @@ def train_dataloader(self, num_images_per_class=-1, add_normalize=False):
         )
         return loader
 
-    def val_dataloader(self, num_images_per_class=50, add_normalize=False):
+    def val_dataloader(self, num_images_per_class: int = 50, add_normalize: bool = False) -> DataLoader:
         transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms
 
         dataset = UnlabeledImagenet(self.data_dir,
@@ -115,7 +115,7 @@ def val_dataloader(self, num_images_per_class=50, add_normalize=False):
         )
         return loader
 
-    def test_dataloader(self, num_images_per_class, add_normalize=False):
+    def test_dataloader(self, num_images_per_class: int, add_normalize: bool = False) -> DataLoader:
         transforms = self._default_transforms() if self.test_transforms is None else self.test_transforms
 
         dataset = UnlabeledImagenet(self.data_dir,
@@ -133,7 +133,7 @@ def test_dataloader(self, num_images_per_class, add_normalize=False):
         )
         return loader
 
-    def _default_transforms(self):
+    def _default_transforms(self) -> transform_lib.Compose:
         mnist_transforms = transform_lib.Compose([
             transform_lib.ToTensor(),
             imagenet_normalization()
diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py
index c666db9b9b..3b29995a1b 100644
--- a/pl_bolts/datamodules/stl10_datamodule.py
+++ b/pl_bolts/datamodules/stl10_datamodule.py
@@ -63,8 +63,8 @@ def __init__(
             shuffle: bool = False,
             pin_memory: bool = False,
             drop_last: bool = False,
-            *args,
-            **kwargs,
+            *args: Any,
+            **kwargs: Any,
     ):
         """
         Args:
@@ -99,7 +99,7 @@ def __init__(
         self.num_unlabeled_samples = 100000 - unlabeled_val_split
 
     @property
-    def num_classes(self):
+    def num_classes(self) -> int:
         return 10
 
     def prepare_data(self):
@@ -110,7 +110,7 @@ def prepare_data(self):
         STL10(self.data_dir, split='train', download=True, transform=transform_lib.ToTensor())
         STL10(self.data_dir, split='test', download=True, transform=transform_lib.ToTensor())
 
-    def train_dataloader(self):
+    def train_dataloader(self) -> DataLoader:
         """
         Loads the 'unlabeled' split minus a portion set aside for validation via `unlabeled_val_split`.
         """
@@ -131,7 +131,7 @@ def train_dataloader(self):
         )
         return loader
 
-    def train_dataloader_mixed(self):
+    def train_dataloader_mixed(self) -> DataLoader:
         """
         Loads a portion of the 'unlabeled' training data and 'train' (labeled) data.
         both portions have a subset removed for validation via `unlabeled_val_split` and `train_val_split`
@@ -169,7 +169,7 @@ def train_dataloader_mixed(self):
         )
         return loader
 
-    def val_dataloader(self):
+    def val_dataloader(self) -> DataLoader:
         """
         Loads a portion of the 'unlabeled' training data set aside for validation
         The val dataset = (unlabeled - train_val_split)
@@ -196,7 +196,7 @@ def val_dataloader(self):
         )
         return loader
 
-    def val_dataloader_mixed(self):
+    def val_dataloader_mixed(self) -> DataLoader:
         """
         Loads a portion of the 'unlabeled' training data set aside for validation along with
         the portion of the 'train' dataset to be used for validation
@@ -239,7 +239,7 @@ def val_dataloader_mixed(self):
         )
         return loader
 
-    def test_dataloader(self):
+    def test_dataloader(self) -> DataLoader:
         """
         Loads the test split of STL10
 
@@ -260,7 +260,7 @@ def test_dataloader(self):
         )
         return loader
 
-    def train_dataloader_labeled(self):
+    def train_dataloader_labeled(self) -> DataLoader:
         transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms
 
         dataset = STL10(self.data_dir, split='train', download=False, transform=transforms)
@@ -278,7 +278,7 @@ def train_dataloader_labeled(self):
         )
         return loader
 
-    def val_dataloader_labeled(self):
+    def val_dataloader_labeled(self) -> DataLoader:
         transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms
         dataset = STL10(self.data_dir,
                         split='train',
@@ -299,7 +299,7 @@ def val_dataloader_labeled(self):
         )
         return loader
 
-    def _default_transforms(self):
+    def _default_transforms(self) -> transform_lib.Compose:
         data_transforms = transform_lib.Compose([
             transform_lib.ToTensor(),
             stl10_normalization()
diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py
index 2144f0f509..5b8c508904 100644
--- a/pl_bolts/datamodules/vision_datamodule.py
+++ b/pl_bolts/datamodules/vision_datamodule.py
@@ -6,6 +6,14 @@
 from pytorch_lightning import LightningDataModule
 from torch.utils.data import DataLoader, Dataset, random_split
 
+from pl_bolts.utils import _TORCHVISION_AVAILABLE
+from pl_bolts.utils.warnings import warn_missing_pkg
+
+if _TORCHVISION_AVAILABLE:
+    from torchvision import transforms as transform_lib
+else:
+    warn_missing_pkg('torchvision')  # pragma: no-cover
+
 
 class VisionDataModule(LightningDataModule):
 
@@ -29,7 +37,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ):
         """
         Args:
             data_dir: Where to save/load the data
@@ -56,14 +64,14 @@ def __init__(
         self.pin_memory = pin_memory
         self.drop_last = drop_last
 
-    def prepare_data(self) -> None:
+    def prepare_data(self):
         """
         Saves files to data_dir
         """
         self.dataset_cls(self.data_dir, train=True, download=True)
         self.dataset_cls(self.data_dir, train=False, download=True)
 
-    def setup(self, stage: Optional[str] = None) -> None:
+    def setup(self, stage: Optional[str] = None):
         """
         Creates train, val, and test dataset
         """
@@ -115,7 +123,7 @@ def _get_splits(self, len_dataset: int) -> List[int]:
         return splits
 
     @abstractmethod
-    def default_transforms(self):
+    def default_transforms(self) -> transform_lib.Compose:
         """ Default transform for the dataset """
 
     def train_dataloader(self) -> DataLoader:
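
Editor's note: besides the dataloader hooks now returning `DataLoader`, patch 11 types the sklearn-style dataset protocol: `__len__` returns `int` and `__getitem__` returns a `Tuple` of arrays. A small self-contained sketch of that pattern is below; the class name `ToyArrayDataset` is illustrative and is not part of the patch.

    from typing import Tuple

    import numpy as np
    from torch.utils.data import Dataset


    class ToyArrayDataset(Dataset):
        """Hypothetical dataset mirroring the typed SklearnDataset protocol."""

        def __init__(self, X: np.ndarray, y: np.ndarray) -> None:
            self.X = X
            self.y = y

        def __len__(self) -> int:
            return len(self.X)

        def __getitem__(self, idx: int) -> Tuple[np.ndarray, np.ndarray]:
            # Cast features to float32, as the SklearnDataset in the patch does.
            return self.X[idx].astype(np.float32), self.y[idx]


    ds = ToyArrayDataset(np.random.rand(10, 3), np.arange(10))
    x, y = ds[0]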

From cae3f461afe23ef67f8e11b80ef19cce72905e93 Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Fri, 18 Dec 2020 20:20:20 +0900
Subject: [PATCH 12/75] Fixing typing imports

---
 pl_bolts/datamodules/async_dataloader.py       | 14 +++++++++++---
 pl_bolts/datamodules/cityscapes_datamodule.py  |  2 ++
 pl_bolts/datamodules/imagenet_datamodule.py    |  2 +-
 pl_bolts/datamodules/kitti_datamodule.py       |  3 ++-
 pl_bolts/datamodules/sklearn_datamodule.py     |  4 ++--
 .../datamodules/ssl_imagenet_datamodule.py     |  1 +
 pl_bolts/datamodules/stl10_datamodule.py       |  2 +-
 pl_bolts/datamodules/vision_datamodule.py      |  7 +------
 .../datamodules/vocdetection_datamodule.py     | 18 ++++++++++--------
 9 files changed, 31 insertions(+), 22 deletions(-)

diff --git a/pl_bolts/datamodules/async_dataloader.py b/pl_bolts/datamodules/async_dataloader.py
index 7ded9d9ef1..38a0b9bb58 100644
--- a/pl_bolts/datamodules/async_dataloader.py
+++ b/pl_bolts/datamodules/async_dataloader.py
@@ -1,10 +1,11 @@
 import re
 from queue import Queue
 from threading import Thread
+from typing import Any, Optional, Union
 
 import torch
 from torch._six import container_abcs, string_classes
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Dataset
 
 
 class AsynchronousLoader(object):
@@ -26,7 +27,14 @@ class AsynchronousLoader(object):
             constructing one here
     """
 
-    def __init__(self, data, device=torch.device('cuda', 0), q_size=10, num_batches=None, **kwargs):
+    def __init__(
+            self,
+            data: Union[DataLoader, Dataset],
+            device: torch.device = torch.device('cuda', 0),
+            q_size: int = 10,
+            num_batches: Optional[int] = None,
+            **kwargs: Any
+    ):
         if isinstance(data, torch.utils.data.DataLoader):
             self.dataloader = data
         else:
@@ -105,5 +113,5 @@ def __next__(self):
         self.idx += 1
         return out
 
-    def __len__(self):
+    def __len__(self) -> int:
         return self.num_batches
diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py
index d27bfc3196..17812a0ac5 100644
--- a/pl_bolts/datamodules/cityscapes_datamodule.py
+++ b/pl_bolts/datamodules/cityscapes_datamodule.py
@@ -1,3 +1,5 @@
+from typing import Any
+
 from pytorch_lightning import LightningDataModule
 from torch.utils.data import DataLoader
 
diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py
index a31b637ba9..829c485aed 100644
--- a/pl_bolts/datamodules/imagenet_datamodule.py
+++ b/pl_bolts/datamodules/imagenet_datamodule.py
@@ -1,5 +1,5 @@
 import os
-from typing import Optional
+from typing import Any, Optional
 
 from pytorch_lightning import LightningDataModule
 from torch.utils.data import DataLoader
diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py
index a50528028b..9a82f0b7ec 100644
--- a/pl_bolts/datamodules/kitti_datamodule.py
+++ b/pl_bolts/datamodules/kitti_datamodule.py
@@ -1,4 +1,5 @@
 import os
+from typing import Any, Optional
 
 import torch
 from pytorch_lightning import LightningDataModule
@@ -133,7 +134,7 @@ def test_dataloader(self) -> DataLoader:
         )
         return loader
 
-    def _default_transforms(self) -> transform_lib.Compose:
+    def _default_transforms(self) -> transforms.Compose:
         kitti_transforms = transforms.Compose([
             transforms.ToTensor(),
             transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753],
diff --git a/pl_bolts/datamodules/sklearn_datamodule.py b/pl_bolts/datamodules/sklearn_datamodule.py
index dcfd559441..e80e1dfc9a 100644
--- a/pl_bolts/datamodules/sklearn_datamodule.py
+++ b/pl_bolts/datamodules/sklearn_datamodule.py
@@ -1,5 +1,5 @@
 import math
-from typing import Any
+from typing import Any, Tuple
 
 import numpy as np
 import torch
@@ -145,7 +145,7 @@ def __init__(
             x_val=None, y_val=None,
             x_test=None, y_test=None,
             val_split=0.2, test_split=0.1,
-            num_workers:int = 2,
+            num_workers: int = 2,
             random_state: int = 1234,
             shuffle: bool = True,
             batch_size: int = 16,
diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py
index 50315245af..1584583101 100644
--- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py
+++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py
@@ -1,4 +1,5 @@
 import os
+from typing import Any, Optional
 
 from pytorch_lightning import LightningDataModule
 from torch.utils.data import DataLoader
diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py
index 3b29995a1b..8d46cfd7bf 100644
--- a/pl_bolts/datamodules/stl10_datamodule.py
+++ b/pl_bolts/datamodules/stl10_datamodule.py
@@ -1,5 +1,5 @@
 import os
-from typing import Optional
+from typing import Any, Optional
 
 import torch
 from pytorch_lightning import LightningDataModule
diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py
index 5b8c508904..cdcefcb2eb 100644
--- a/pl_bolts/datamodules/vision_datamodule.py
+++ b/pl_bolts/datamodules/vision_datamodule.py
@@ -9,11 +9,6 @@
 from pl_bolts.utils import _TORCHVISION_AVAILABLE
 from pl_bolts.utils.warnings import warn_missing_pkg
 
-if _TORCHVISION_AVAILABLE:
-    from torchvision import transforms as transform_lib
-else:
-    warn_missing_pkg('torchvision')  # pragma: no-cover
-
 
 class VisionDataModule(LightningDataModule):
 
@@ -123,7 +118,7 @@ def _get_splits(self, len_dataset: int) -> List[int]:
         return splits
 
     @abstractmethod
-    def default_transforms(self) -> transform_lib.Compose:
+    def default_transforms(self):
         """ Default transform for the dataset """
 
     def train_dataloader(self) -> DataLoader:
diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py
index a2087f9448..3f3767b4f8 100644
--- a/pl_bolts/datamodules/vocdetection_datamodule.py
+++ b/pl_bolts/datamodules/vocdetection_datamodule.py
@@ -1,3 +1,5 @@
+from typing import Any, Dict
+
 import torch
 from pytorch_lightning import LightningDataModule
 from torch.utils.data import DataLoader
@@ -17,7 +19,7 @@ class Compose(object):
     Like `torchvision.transforms.compose` but works for (image, target)
     """
 
-    def __init__(self, transforms):
+    def __init__(self, transforms: T.Compose):
         self.transforms = transforms
 
     def __call__(self, image, target):
@@ -55,7 +57,7 @@ def _collate_fn(batch):
 )
 
 
-def _prepare_voc_instance(image, target):
+def _prepare_voc_instance(image, target: Dict[str, Any]):
     """
     Prepares VOC dataset into appropriate target for fasterrcnn
 
@@ -113,8 +115,8 @@ def __init__(
         shuffle: bool = False,
         pin_memory: bool = False,
         drop_last: bool = False,
-        *args,
-        **kwargs,
+        *args: Any,
+        **kwargs: Any,
     ):
         if not _TORCHVISION_AVAILABLE:
             raise ModuleNotFoundError(  # pragma: no-cover
@@ -132,7 +134,7 @@ def __init__(
         self.drop_last = drop_last
 
     @property
-    def num_classes(self):
+    def num_classes(self) -> int:
         """
         Return:
             21
@@ -146,7 +148,7 @@ def prepare_data(self):
         VOCDetection(self.data_dir, year=self.year, image_set="train", download=True)
         VOCDetection(self.data_dir, year=self.year, image_set="val", download=True)
 
-    def train_dataloader(self, batch_size=1, transforms=None):
+    def train_dataloader(self, batch_size: int = 1, transforms=None) -> DataLoader:
         """
         VOCDetection train set uses the `train` subset
 
@@ -174,7 +176,7 @@ def train_dataloader(self, batch_size=1, transforms=None):
         )
         return loader
 
-    def val_dataloader(self, batch_size=1, transforms=None):
+    def val_dataloader(self, batch_size: int = 1, transforms=None) -> DataLoader:
         """
         VOCDetection val set uses the `val` subset
 
@@ -201,7 +203,7 @@ def val_dataloader(self, batch_size=1, transforms=None):
         )
         return loader
 
-    def _default_transforms(self):
+    def _default_transforms(self) -> T.Compose:
         if self.normalize:
             return (
                 lambda image, target: (
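
Editor's note: the `AsynchronousLoader.__init__` hunk in patch 12 is a good example of `Union` and `Optional` parameter typing: `data` accepts either a `DataLoader` or a bare `Dataset`, and `num_batches` may be omitted and inferred from the loader length. A hedged sketch of just that signature is below; the wrapper class name and the `cpu` default device are assumptions made so the example runs without a GPU.

    from typing import Any, Optional, Union

    import torch
    from torch.utils.data import DataLoader, Dataset


    class LoaderWrapper:
        """Hypothetical wrapper showing the Union/Optional constructor typing."""

        def __init__(
                self,
                data: Union[DataLoader, Dataset],
                device: torch.device = torch.device('cpu'),
                num_batches: Optional[int] = None,
                **kwargs: Any,
        ) -> None:
            # Wrap a bare Dataset in a DataLoader, as AsynchronousLoader does.
            self.dataloader = data if isinstance(data, DataLoader) else DataLoader(data, **kwargs)
            self.device = device
            self.num_batches = num_batches if num_batches is not None else len(self.dataloader)

        def __len__(self) -> int:
            return self.num_batches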

From 7af027e43bf96fe8d171c1a2c34da174acbbd7ae Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Sun, 20 Dec 2020 03:03:00 +0900
Subject: [PATCH 13/75] Removing torchvision.transforms from return typing

---
 pl_bolts/datamodules/binary_mnist_datamodule.py  | 2 +-
 pl_bolts/datamodules/cifar10_datamodule.py       | 2 +-
 pl_bolts/datamodules/cityscapes_datamodule.py    | 4 ++--
 pl_bolts/datamodules/fashion_mnist_datamodule.py | 2 +-
 pl_bolts/datamodules/imagenet_datamodule.py      | 4 ++--
 pl_bolts/datamodules/kitti_datamodule.py         | 2 +-
 pl_bolts/datamodules/mnist_datamodule.py         | 2 +-
 pl_bolts/datamodules/ssl_imagenet_datamodule.py  | 2 +-
 pl_bolts/datamodules/stl10_datamodule.py         | 2 +-
 pl_bolts/datamodules/vocdetection_datamodule.py  | 2 +-
 10 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py
index c713abe107..8dc02ec95e 100644
--- a/pl_bolts/datamodules/binary_mnist_datamodule.py
+++ b/pl_bolts/datamodules/binary_mnist_datamodule.py
@@ -98,7 +98,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self) -> transform_lib.Compose:
+    def default_transforms(self):
         if self.normalize:
             mnist_transforms = transform_lib.Compose(
                 [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py
index 12aea1ec87..b208172ed0 100644
--- a/pl_bolts/datamodules/cifar10_datamodule.py
+++ b/pl_bolts/datamodules/cifar10_datamodule.py
@@ -112,7 +112,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self) -> transform_lib.Compose:
+    def default_transforms(self):
         if self.normalize:
             cf10_transforms = transform_lib.Compose([transform_lib.ToTensor(), cifar10_normalization()])
         else:
diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py
index 17812a0ac5..dc8b866fba 100644
--- a/pl_bolts/datamodules/cityscapes_datamodule.py
+++ b/pl_bolts/datamodules/cityscapes_datamodule.py
@@ -192,7 +192,7 @@ def test_dataloader(self) -> DataLoader:
         )
         return loader
 
-    def _default_transforms(self) -> transform_lib.Compose:
+    def _default_transforms(self):
         cityscapes_transforms = transform_lib.Compose([
             transform_lib.ToTensor(),
             transform_lib.Normalize(
@@ -202,7 +202,7 @@ def _default_transforms(self) -> transform_lib.Compose:
         ])
         return cityscapes_transforms
 
-    def _default_target_transforms(self) -> transform_lib.Compose:
+    def _default_target_transforms(self):
         cityscapes_target_trasnforms = transform_lib.Compose([
             transform_lib.ToTensor(),
             transform_lib.Lambda(lambda t: t.squeeze())
diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py
index 9a73e6a637..a128ddfaab 100644
--- a/pl_bolts/datamodules/fashion_mnist_datamodule.py
+++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py
@@ -93,7 +93,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self) -> transform_lib.Compose:
+    def default_transforms(self):
         if self.normalize:
             mnist_transforms = transform_lib.Compose(
                 [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py
index 829c485aed..f8d8262108 100644
--- a/pl_bolts/datamodules/imagenet_datamodule.py
+++ b/pl_bolts/datamodules/imagenet_datamodule.py
@@ -206,7 +206,7 @@ def test_dataloader(self) -> DataLoader:
         )
         return loader
 
-    def train_transform(self) -> transform_lib.Compose:
+    def train_transform(self):
         """
         The standard imagenet transforms
 
@@ -232,7 +232,7 @@ def train_transform(self) -> transform_lib.Compose:
 
         return preprocessing
 
-    def val_transform(self) -> transform_lib.Compose:
+    def val_transform(self):
         """
         The standard imagenet transforms for validation
 
diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py
index 9a82f0b7ec..06778fdbfc 100644
--- a/pl_bolts/datamodules/kitti_datamodule.py
+++ b/pl_bolts/datamodules/kitti_datamodule.py
@@ -134,7 +134,7 @@ def test_dataloader(self) -> DataLoader:
         )
         return loader
 
-    def _default_transforms(self) -> transforms.Compose:
+    def _default_transforms(self):
         kitti_transforms = transforms.Compose([
             transforms.ToTensor(),
             transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753],
diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py
index 76a0438a0c..c57fe8ca82 100644
--- a/pl_bolts/datamodules/mnist_datamodule.py
+++ b/pl_bolts/datamodules/mnist_datamodule.py
@@ -92,7 +92,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self) -> transform_lib.Compose:
+    def default_transforms(self):
         if self.normalize:
             mnist_transforms = transform_lib.Compose(
                 [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py
index 1584583101..96041fd4d9 100644
--- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py
+++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py
@@ -134,7 +134,7 @@ def test_dataloader(self, num_images_per_class: int, add_normalize: bool = False
         )
         return loader
 
-    def _default_transforms(self) -> transform_lib.Compose:
+    def _default_transforms(self):
         mnist_transforms = transform_lib.Compose([
             transform_lib.ToTensor(),
             imagenet_normalization()
diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py
index 8d46cfd7bf..5cd680a535 100644
--- a/pl_bolts/datamodules/stl10_datamodule.py
+++ b/pl_bolts/datamodules/stl10_datamodule.py
@@ -299,7 +299,7 @@ def val_dataloader_labeled(self) -> DataLoader:
         )
         return loader
 
-    def _default_transforms(self) -> transform_lib.Compose:
+    def _default_transforms(self):
         data_transforms = transform_lib.Compose([
             transform_lib.ToTensor(),
             stl10_normalization()
diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py
index 3f3767b4f8..e0768b42d5 100644
--- a/pl_bolts/datamodules/vocdetection_datamodule.py
+++ b/pl_bolts/datamodules/vocdetection_datamodule.py
@@ -203,7 +203,7 @@ def val_dataloader(self, batch_size: int = 1, transforms=None) -> DataLoader:
         )
         return loader
 
-    def _default_transforms(self) -> T.Compose:
+    def _default_transforms(self):
         if self.normalize:
             return (
                 lambda image, target: (
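
Editor's note: the reason patch 13 strips the `transform_lib.Compose` return annotations is that `transform_lib` is only bound when torchvision is installed; with it missing, a signature referencing it raises NameError as soon as the module is imported. One dependency-free alternative, shown below as an assumption rather than what the patch does, is to annotate such methods with `typing.Callable`, which is always importable.

    from typing import Callable

    try:
        from torchvision import transforms as transform_lib
        _TORCHVISION_AVAILABLE = True
    except ImportError:  # pragma: no cover
        _TORCHVISION_AVAILABLE = False


    def default_transforms() -> Callable:
        # The annotation no longer depends on torchvision; only the body does.
        if not _TORCHVISION_AVAILABLE:
            raise ModuleNotFoundError('torchvision is required for the default transforms')
        return transform_lib.Compose([transform_lib.ToTensor()])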

From 7bec6059691a3b025a37e6e5d37d8d8a218832bf Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Sun, 20 Dec 2020 03:09:03 +0900
Subject: [PATCH 14/75] Remove more torchvision.transforms typing

---
 pl_bolts/datamodules/vocdetection_datamodule.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py
index e0768b42d5..3753e727ad 100644
--- a/pl_bolts/datamodules/vocdetection_datamodule.py
+++ b/pl_bolts/datamodules/vocdetection_datamodule.py
@@ -19,7 +19,7 @@ class Compose(object):
     Like `torchvision.transforms.compose` but works for (image, target)
     """
 
-    def __init__(self, transforms: T.Compose):
+    def __init__(self, transforms):
         self.transforms = transforms
 
     def __call__(self, image, target):

From afbc918f9e53d93aeba5ea7ff0072b073e395846 Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Mon, 21 Dec 2020 10:13:11 +0900
Subject: [PATCH 15/75] Removing return typing

---
 pl_bolts/datamodules/cifar10_datamodule.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py
index b208172ed0..85ba4de6e7 100644
--- a/pl_bolts/datamodules/cifar10_datamodule.py
+++ b/pl_bolts/datamodules/cifar10_datamodule.py
@@ -71,7 +71,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ):
         """
         Args:
             data_dir: Where to save/load the data

From 17ce335c0deb8b6ac2252d03f405a91eff9d3425 Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Mon, 21 Dec 2020 16:00:55 +0900
Subject: [PATCH 16/75] Add `None` for optional arguments

---
 pl_bolts/datamodules/cifar10_datamodule.py | 2 +-
 pl_bolts/datamodules/kitti_datamodule.py   | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py
index 85ba4de6e7..534774684f 100644
--- a/pl_bolts/datamodules/cifar10_datamodule.py
+++ b/pl_bolts/datamodules/cifar10_datamodule.py
@@ -146,7 +146,7 @@ class TinyCIFAR10DataModule(CIFAR10DataModule):
 
     def __init__(
         self,
-        data_dir: Optional[str],
+        data_dir: Optional[str] = None,
         val_split: int = 50,
         num_workers: int = 16,
         num_samples: int = 100,
diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py
index 06778fdbfc..0067a1e53d 100644
--- a/pl_bolts/datamodules/kitti_datamodule.py
+++ b/pl_bolts/datamodules/kitti_datamodule.py
@@ -22,7 +22,7 @@ class KittiDataModule(LightningDataModule):
 
     def __init__(
             self,
-            data_dir: Optional[str],
+            data_dir: Optional[str] = None,
             val_split: float = 0.2,
             test_split: float = 0.1,
             num_workers: int = 16,
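
Editor's note: patch 16 pairs the `Optional[str]` annotation with a `None` default so callers can omit `data_dir` entirely. A tiny sketch of that convention follows; the function `resolve_data_dir` and its fallback to the working directory are hypothetical, not taken from the patch.

    import os
    from typing import Optional


    def resolve_data_dir(data_dir: Optional[str] = None) -> str:
        # Fall back to the current working directory when no path is given.
        return data_dir if data_dir is not None else os.getcwd()


    print(resolve_data_dir())        # uses the fallback
    print(resolve_data_dir('/tmp'))  # an explicit path wins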

From d09f98d708f4cb75d1d8476bed0fe201f3bea9b9 Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Mon, 21 Dec 2020 16:04:36 +0900
Subject: [PATCH 17/75] Remove unnecessary import

---
 pl_bolts/datamodules/vision_datamodule.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py
index cdcefcb2eb..42252c4edf 100644
--- a/pl_bolts/datamodules/vision_datamodule.py
+++ b/pl_bolts/datamodules/vision_datamodule.py
@@ -6,9 +6,6 @@
 from pytorch_lightning import LightningDataModule
 from torch.utils.data import DataLoader, Dataset, random_split
 
-from pl_bolts.utils import _TORCHVISION_AVAILABLE
-from pl_bolts.utils.warnings import warn_missing_pkg
-
 
 class VisionDataModule(LightningDataModule):
 

From b61fdc07f107fc74d1cc0432f29561cc494f7b65 Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Mon, 21 Dec 2020 16:51:36 +0900
Subject: [PATCH 18/75] Add `None` return type

---
 pl_bolts/datamodules/async_dataloader.py         |  4 ++--
 pl_bolts/datamodules/binary_mnist_datamodule.py  |  2 +-
 pl_bolts/datamodules/cifar10_datamodule.py       |  4 ++--
 pl_bolts/datamodules/cityscapes_datamodule.py    |  2 +-
 pl_bolts/datamodules/experience_source.py        |  4 ++--
 pl_bolts/datamodules/fashion_mnist_datamodule.py |  2 +-
 pl_bolts/datamodules/imagenet_datamodule.py      |  6 +++---
 pl_bolts/datamodules/kitti_datamodule.py         |  2 +-
 pl_bolts/datamodules/mnist_datamodule.py         |  2 +-
 pl_bolts/datamodules/sklearn_datamodule.py       | 12 ++++++++----
 pl_bolts/datamodules/ssl_imagenet_datamodule.py  |  6 +++---
 pl_bolts/datamodules/stl10_datamodule.py         |  4 ++--
 pl_bolts/datamodules/vision_datamodule.py        |  6 +++---
 pl_bolts/datamodules/vocdetection_datamodule.py  |  6 +++---
 14 files changed, 33 insertions(+), 29 deletions(-)

diff --git a/pl_bolts/datamodules/async_dataloader.py b/pl_bolts/datamodules/async_dataloader.py
index 38a0b9bb58..224f34d5ee 100644
--- a/pl_bolts/datamodules/async_dataloader.py
+++ b/pl_bolts/datamodules/async_dataloader.py
@@ -34,7 +34,7 @@ def __init__(
             q_size: int = 10,
             num_batches: Optional[int] = None,
             **kwargs: Any
-    ):
+    ) -> None:
         if isinstance(data, torch.utils.data.DataLoader):
             self.dataloader = data
         else:
@@ -57,7 +57,7 @@ def __init__(
 
         self.np_str_obj_array_pattern = re.compile(r'[SaUO]')
 
-    def load_loop(self):  # The loop that will load into the queue in the background
+    def load_loop(self) -> None:  # The loop that will load into the queue in the background
         for i, sample in enumerate(self.dataloader):
             self.queue.put(self.load_instance(sample))
             if i == len(self):
diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py
index 8dc02ec95e..142b3d54ef 100644
--- a/pl_bolts/datamodules/binary_mnist_datamodule.py
+++ b/pl_bolts/datamodules/binary_mnist_datamodule.py
@@ -56,7 +56,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: Where to save/load the data
diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py
index 534774684f..2cb894d749 100644
--- a/pl_bolts/datamodules/cifar10_datamodule.py
+++ b/pl_bolts/datamodules/cifar10_datamodule.py
@@ -71,7 +71,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: Where to save/load the data
@@ -153,7 +153,7 @@ def __init__(
         labels: Optional[Sequence] = (1, 5, 8),
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: where to save/load the data
diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py
index dc8b866fba..61c1ae2bef 100644
--- a/pl_bolts/datamodules/cityscapes_datamodule.py
+++ b/pl_bolts/datamodules/cityscapes_datamodule.py
@@ -73,7 +73,7 @@ def __init__(
             drop_last: bool = False,
             *args: Any,
             **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: where to load the data from path, i.e. where directory leftImg8bit and gtFine or gtCoarse
diff --git a/pl_bolts/datamodules/experience_source.py b/pl_bolts/datamodules/experience_source.py
index 5fe1332dfd..fac4e82f25 100644
--- a/pl_bolts/datamodules/experience_source.py
+++ b/pl_bolts/datamodules/experience_source.py
@@ -30,7 +30,7 @@ class ExperienceSourceDataset(IterableDataset):
     The logic for the experience source and how the batch is generated is defined the Lightning model itself
     """
 
-    def __init__(self, generate_batch: Callable):
+    def __init__(self, generate_batch: Callable) -> None:
         self.generate_batch = generate_batch
 
     def __iter__(self) -> Iterable:
@@ -243,7 +243,7 @@ def pop_rewards_steps(self):
 class DiscountedExperienceSource(ExperienceSource):
     """Outputs experiences with a discounted reward over N steps"""
 
-    def __init__(self, env: Env, agent, n_steps: int = 1, gamma: float = 0.99):
+    def __init__(self, env: Env, agent, n_steps: int = 1, gamma: float = 0.99) -> None:
         super().__init__(env, agent, (n_steps + 1))
         self.gamma = gamma
         self.steps = n_steps
diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py
index a128ddfaab..833c4599a6 100644
--- a/pl_bolts/datamodules/fashion_mnist_datamodule.py
+++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py
@@ -57,7 +57,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: Where to save/load the data
diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py
index f8d8262108..6f06913f9f 100644
--- a/pl_bolts/datamodules/imagenet_datamodule.py
+++ b/pl_bolts/datamodules/imagenet_datamodule.py
@@ -60,7 +60,7 @@ def __init__(
             drop_last: bool = False,
             *args: Any,
             **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: path to the imagenet dataset file
@@ -103,14 +103,14 @@ def num_classes(self) -> int:
         """
         return 1000
 
-    def _verify_splits(self, data_dir: str, split: str):
+    def _verify_splits(self, data_dir: str, split: str) -> None:
         dirs = os.listdir(data_dir)
 
         if split not in dirs:
             raise FileNotFoundError(f'a {split} Imagenet split was not found in {data_dir},'
                                     f' make sure the folder contains a subfolder named {split}')
 
-    def prepare_data(self):
+    def prepare_data(self) -> None:
         """
         This method already assumes you have imagenet2012 downloaded.
         It validates the data using the meta.bin.
diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py
index 0067a1e53d..3cf26dc762 100644
--- a/pl_bolts/datamodules/kitti_datamodule.py
+++ b/pl_bolts/datamodules/kitti_datamodule.py
@@ -33,7 +33,7 @@ def __init__(
             drop_last: bool = False,
             *args: Any,
             **kwargs: Any,
-    ):
+    ) -> None:
         """
         Kitti train, validation and test dataloaders.
 
diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py
index c57fe8ca82..1dd5e927b6 100644
--- a/pl_bolts/datamodules/mnist_datamodule.py
+++ b/pl_bolts/datamodules/mnist_datamodule.py
@@ -56,7 +56,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: Where to save/load the data
diff --git a/pl_bolts/datamodules/sklearn_datamodule.py b/pl_bolts/datamodules/sklearn_datamodule.py
index e80e1dfc9a..d9477acc0b 100644
--- a/pl_bolts/datamodules/sklearn_datamodule.py
+++ b/pl_bolts/datamodules/sklearn_datamodule.py
@@ -28,7 +28,9 @@ class SklearnDataset(Dataset):
         >>> len(dataset)
         506
     """
-    def __init__(self, X: np.ndarray, y: np.ndarray, X_transform: Any = None, y_transform: Any = None):
+    def __init__(
+        self, X: np.ndarray, y: np.ndarray, X_transform: Any = None, y_transform: Any = None
+    ) -> None:
         """
         Args:
             X: Numpy ndarray
@@ -75,7 +77,9 @@ class TensorDataset(Dataset):
         >>> len(dataset)
         10
     """
-    def __init__(self, X: torch.Tensor, y: torch.Tensor, X_transform: Any = None, y_transform: Any = None):
+    def __init__(
+        self, X: torch.Tensor, y: torch.Tensor, X_transform: Any = None, y_transform: Any = None
+    ) -> None:
         """
         Args:
             X: PyTorch tensor
@@ -153,7 +157,7 @@ def __init__(
             drop_last: bool = False,
             *args: Any,
             **kwargs: Any,
-    ):
+    ) -> None:
 
         super().__init__(*args, **kwargs)
         self.num_workers = num_workers
@@ -201,7 +205,7 @@ def _init_datasets(
         y_val: np.ndarray,
         x_test: np.ndarray,
         y_test: np.ndarray
-    ):
+    ) -> None:
         self.train_dataset = SklearnDataset(X, y)
         self.val_dataset = SklearnDataset(x_val, y_val)
         self.test_dataset = SklearnDataset(x_test, y_test)
diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py
index 96041fd4d9..3dbda03527 100644
--- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py
+++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py
@@ -30,7 +30,7 @@ def __init__(
             drop_last: bool = False,
             *args: Any,
             **kwargs: Any,
-    ):
+    ) -> None:
         super().__init__(*args, **kwargs)
 
         if not _TORCHVISION_AVAILABLE:
@@ -50,14 +50,14 @@ def __init__(
     def num_classes(self) -> int:
         return 1000
 
-    def _verify_splits(self, data_dir: str, split: str):
+    def _verify_splits(self, data_dir: str, split: str) -> None:
         dirs = os.listdir(data_dir)
 
         if split not in dirs:
             raise FileNotFoundError(f'a {split} Imagenet split was not found in {data_dir}, make sure the'
                                     f' folder contains a subfolder named {split}')
 
-    def prepare_data(self):
+    def prepare_data(self) -> None:
         # imagenet cannot be downloaded... must provide path to folder with the train/val splits
         self._verify_splits(self.data_dir, 'train')
         self._verify_splits(self.data_dir, 'val')
diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py
index 5cd680a535..79420be149 100644
--- a/pl_bolts/datamodules/stl10_datamodule.py
+++ b/pl_bolts/datamodules/stl10_datamodule.py
@@ -65,7 +65,7 @@ def __init__(
             drop_last: bool = False,
             *args: Any,
             **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: where to save/load the data
@@ -102,7 +102,7 @@ def __init__(
     def num_classes(self) -> int:
         return 10
 
-    def prepare_data(self):
+    def prepare_data(self) -> None:
         """
         Downloads the unlabeled, train and test split
         """
diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py
index 42252c4edf..2144f0f509 100644
--- a/pl_bolts/datamodules/vision_datamodule.py
+++ b/pl_bolts/datamodules/vision_datamodule.py
@@ -29,7 +29,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: Where to save/load the data
@@ -56,14 +56,14 @@ def __init__(
         self.pin_memory = pin_memory
         self.drop_last = drop_last
 
-    def prepare_data(self):
+    def prepare_data(self) -> None:
         """
         Saves files to data_dir
         """
         self.dataset_cls(self.data_dir, train=True, download=True)
         self.dataset_cls(self.data_dir, train=False, download=True)
 
-    def setup(self, stage: Optional[str] = None):
+    def setup(self, stage: Optional[str] = None) -> None:
         """
         Creates train, val, and test dataset
         """
diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py
index 3753e727ad..52d5065d97 100644
--- a/pl_bolts/datamodules/vocdetection_datamodule.py
+++ b/pl_bolts/datamodules/vocdetection_datamodule.py
@@ -19,7 +19,7 @@ class Compose(object):
     Like `torchvision.transforms.compose` but works for (image, target)
     """
 
-    def __init__(self, transforms):
+    def __init__(self, transforms) -> None:
         self.transforms = transforms
 
     def __call__(self, image, target):
@@ -117,7 +117,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         if not _TORCHVISION_AVAILABLE:
             raise ModuleNotFoundError(  # pragma: no-cover
                 'You want to use VOC dataset loaded from `torchvision` which is not installed yet.'
@@ -141,7 +141,7 @@ def num_classes(self) -> int:
         """
         return 21
 
-    def prepare_data(self):
+    def prepare_data(self) -> None:
         """
         Saves VOCDetection files to data_dir
         """

From f2f4305d9b67f18840be494d966bea2870ee0d4b Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Tue, 5 Jan 2021 20:30:43 +0900
Subject: [PATCH 19/75] Add type for torchvision transforms

---
 pl_bolts/datamodules/binary_mnist_datamodule.py  | 4 +++-
 pl_bolts/datamodules/cifar10_datamodule.py       | 4 +++-
 pl_bolts/datamodules/cityscapes_datamodule.py    | 6 ++++--
 pl_bolts/datamodules/fashion_mnist_datamodule.py | 4 +++-
 pl_bolts/datamodules/imagenet_datamodule.py      | 6 ++++--
 pl_bolts/datamodules/kitti_datamodule.py         | 4 +++-
 pl_bolts/datamodules/mnist_datamodule.py         | 4 +++-
 pl_bolts/datamodules/ssl_imagenet_datamodule.py  | 4 +++-
 pl_bolts/datamodules/stl10_datamodule.py         | 4 +++-
 pl_bolts/datamodules/vision_datamodule.py        | 9 ++++++++-
 10 files changed, 37 insertions(+), 12 deletions(-)

diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py
index 142b3d54ef..ad17360d08 100644
--- a/pl_bolts/datamodules/binary_mnist_datamodule.py
+++ b/pl_bolts/datamodules/binary_mnist_datamodule.py
@@ -7,8 +7,10 @@
 
 if _TORCHVISION_AVAILABLE:
     from torchvision import transforms as transform_lib
+    from torchvision.transforms import Compose
 else:  # pragma: no-cover
     warn_missing_pkg('torchvision')
+    Compose = object
 
 
 class BinaryMNISTDataModule(VisionDataModule):
@@ -98,7 +100,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self):
+    def default_transforms(self) -> Compose:
         if self.normalize:
             mnist_transforms = transform_lib.Compose(
                 [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py
index 2cb894d749..9dbf10b670 100644
--- a/pl_bolts/datamodules/cifar10_datamodule.py
+++ b/pl_bolts/datamodules/cifar10_datamodule.py
@@ -9,9 +9,11 @@
 if _TORCHVISION_AVAILABLE:
     from torchvision import transforms as transform_lib
     from torchvision.datasets import CIFAR10
+    from torchvision.transforms import Compose
 else:
     warn_missing_pkg('torchvision')  # pragma: no-cover
     CIFAR10 = None
+    Compose = object
 
 
 class CIFAR10DataModule(VisionDataModule):
@@ -112,7 +114,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self):
+    def default_transforms(self) -> Compose:
         if self.normalize:
             cf10_transforms = transform_lib.Compose([transform_lib.ToTensor(), cifar10_normalization()])
         else:
diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py
index 61c1ae2bef..130e333976 100644
--- a/pl_bolts/datamodules/cityscapes_datamodule.py
+++ b/pl_bolts/datamodules/cityscapes_datamodule.py
@@ -9,8 +9,10 @@
 if _TORCHVISION_AVAILABLE:
     from torchvision import transforms as transform_lib
     from torchvision.datasets import Cityscapes
+    from torchvision.transforms import Compose
 else:
     warn_missing_pkg('torchvision')  # pragma: no-cover
+    Compose = object
 
 
 class CityscapesDataModule(LightningDataModule):
@@ -192,7 +194,7 @@ def test_dataloader(self) -> DataLoader:
         )
         return loader
 
-    def _default_transforms(self):
+    def _default_transforms(self) -> Compose:
         cityscapes_transforms = transform_lib.Compose([
             transform_lib.ToTensor(),
             transform_lib.Normalize(
@@ -202,7 +204,7 @@ def _default_transforms(self):
         ])
         return cityscapes_transforms
 
-    def _default_target_transforms(self):
+    def _default_target_transforms(self) -> Compose:
         cityscapes_target_trasnforms = transform_lib.Compose([
             transform_lib.ToTensor(),
             transform_lib.Lambda(lambda t: t.squeeze())
diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py
index 833c4599a6..b37221bc74 100644
--- a/pl_bolts/datamodules/fashion_mnist_datamodule.py
+++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py
@@ -7,9 +7,11 @@
 if _TORCHVISION_AVAILABLE:
     from torchvision import transforms as transform_lib
     from torchvision.datasets import FashionMNIST
+    from torchvision.transforms import Compose
 else:
     warn_missing_pkg('torchvision')  # pragma: no-cover
     FashionMNIST = None
+    Compose = object
 
 
 class FashionMNISTDataModule(VisionDataModule):
@@ -93,7 +95,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self):
+    def default_transforms(self) -> Compose:
         if self.normalize:
             mnist_transforms = transform_lib.Compose(
                 [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py
index 6f06913f9f..db2fc68c0b 100644
--- a/pl_bolts/datamodules/imagenet_datamodule.py
+++ b/pl_bolts/datamodules/imagenet_datamodule.py
@@ -11,8 +11,10 @@
 
 if _TORCHVISION_AVAILABLE:
     from torchvision import transforms as transform_lib
+    from torchvision.transforms import Compose
 else:
     warn_missing_pkg('torchvision')  # pragma: no-cover
+    Compose = object
 
 
 class ImagenetDataModule(LightningDataModule):
@@ -206,7 +208,7 @@ def test_dataloader(self) -> DataLoader:
         )
         return loader
 
-    def train_transform(self):
+    def train_transform(self) -> Compose:
         """
         The standard imagenet transforms
 
@@ -232,7 +234,7 @@ def train_transform(self):
 
         return preprocessing
 
-    def val_transform(self):
+    def val_transform(self) -> Compose:
         """
         The standard imagenet transforms for validation
 
diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py
index 3cf26dc762..b63040f3bf 100644
--- a/pl_bolts/datamodules/kitti_datamodule.py
+++ b/pl_bolts/datamodules/kitti_datamodule.py
@@ -12,8 +12,10 @@
 
 if _TORCHVISION_AVAILABLE:
     import torchvision.transforms as transforms
+    from torchvision.transforms import Compose
 else:  # pragma: no cover
     warn_missing_pkg('torchvision')
+    Compose = object
 
 
 class KittiDataModule(LightningDataModule):
@@ -134,7 +136,7 @@ def test_dataloader(self) -> DataLoader:
         )
         return loader
 
-    def _default_transforms(self):
+    def _default_transforms(self) -> Compose:
         kitti_transforms = transforms.Compose([
             transforms.ToTensor(),
             transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753],
diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py
index 1dd5e927b6..711460023c 100644
--- a/pl_bolts/datamodules/mnist_datamodule.py
+++ b/pl_bolts/datamodules/mnist_datamodule.py
@@ -7,9 +7,11 @@
 if _TORCHVISION_AVAILABLE:
     from torchvision import transforms as transform_lib
     from torchvision.datasets import MNIST
+    from torchvision.transforms import Compose
 else:
     warn_missing_pkg('torchvision')  # pragma: no-cover
     MNIST = None
+    Compose = object
 
 
 class MNISTDataModule(VisionDataModule):
@@ -92,7 +94,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self):
+    def default_transforms(self) -> Compose:
         if self.normalize:
             mnist_transforms = transform_lib.Compose(
                 [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py
index 3dbda03527..d575eb2d01 100644
--- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py
+++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py
@@ -11,8 +11,10 @@
 
 if _TORCHVISION_AVAILABLE:
     from torchvision import transforms as transform_lib
+    from torchvision.transforms import Compose
 else:
     warn_missing_pkg('torchvision')  # pragma: no-cover
+    Compose = object
 
 
 class SSLImagenetDataModule(LightningDataModule):  # pragma: no cover
@@ -134,7 +136,7 @@ def test_dataloader(self, num_images_per_class: int, add_normalize: bool = False
         )
         return loader
 
-    def _default_transforms(self):
+    def _default_transforms(self) -> Compose:
         mnist_transforms = transform_lib.Compose([
             transform_lib.ToTensor(),
             imagenet_normalization()
diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py
index 79420be149..f9ac77e140 100644
--- a/pl_bolts/datamodules/stl10_datamodule.py
+++ b/pl_bolts/datamodules/stl10_datamodule.py
@@ -13,8 +13,10 @@
 if _TORCHVISION_AVAILABLE:
     from torchvision import transforms as transform_lib
     from torchvision.datasets import STL10
+    from torchvision.transforms import Compose
 else:
     warn_missing_pkg('torchvision')  # pragma: no-cover
+    Compose = object
 
 
 class STL10DataModule(LightningDataModule):  # pragma: no cover
@@ -299,7 +301,7 @@ def val_dataloader_labeled(self) -> DataLoader:
         )
         return loader
 
-    def _default_transforms(self):
+    def _default_transforms(self) -> Compose:
         data_transforms = transform_lib.Compose([
             transform_lib.ToTensor(),
             stl10_normalization()
diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py
index 2144f0f509..92e6723968 100644
--- a/pl_bolts/datamodules/vision_datamodule.py
+++ b/pl_bolts/datamodules/vision_datamodule.py
@@ -6,6 +6,13 @@
 from pytorch_lightning import LightningDataModule
 from torch.utils.data import DataLoader, Dataset, random_split
 
+from pl_bolts.utils import _TORCHVISION_AVAILABLE
+
+if _TORCHVISION_AVAILABLE:
+    from torchvision.transforms import Compose
+else:
+    Compose = object
+
 
 class VisionDataModule(LightningDataModule):
 
@@ -115,7 +122,7 @@ def _get_splits(self, len_dataset: int) -> List[int]:
         return splits
 
     @abstractmethod
-    def default_transforms(self):
+    def default_transforms(self) -> Compose:
         """ Default transform for the dataset """
 
     def train_dataloader(self) -> DataLoader:
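For reference, a condensed sketch of the guarded-import pattern these hunks repeat across the datamodules, so the `Compose` return annotation still resolves when torchvision is not installed (names taken from the hunks above; a sketch, not the full module):

    from pl_bolts.utils import _TORCHVISION_AVAILABLE
    from pl_bolts.utils.warnings import warn_missing_pkg

    if _TORCHVISION_AVAILABLE:
        from torchvision.transforms import Compose
    else:  # pragma: no-cover
        warn_missing_pkg('torchvision')
        Compose = object  # placeholder so the annotation below still resolves

    def default_transforms() -> Compose:
        ...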

From cd09554cc8b4b2e5d75d0cf3db2c237e2d10e1d0 Mon Sep 17 00:00:00 2001
From: Jirka Borovec <jirka.borovec@seznam.cz>
Date: Tue, 5 Jan 2021 14:04:40 +0100
Subject: [PATCH 20/75] enable check

---
 setup.cfg | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/setup.cfg b/setup.cfg
index dcd35979f9..bda41d20f4 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -61,9 +61,6 @@ show_error_codes = True
 disallow_untyped_defs = True
 ignore_missing_imports = True
 
-[mypy-pl_bolts.datamodules.*]
-ignore_errors = True
-
 [mypy-pl_bolts.datasets.*]
 ignore_errors = True
 

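With the `ignore_errors` override for `pl_bolts.datamodules.*` gone, the existing `disallow_untyped_defs = True` setting now applies to that package, which is what the surrounding typing patches prepare for. Roughly what that flag enforces, as an illustrative snippet (not from the patch):

    def untyped(x):
        # flagged by mypy: error: Function is missing a type annotation  [no-untyped-def]
        return x

    def typed(x: int) -> int:  # fully annotated, accepted
        return x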
From 0fcd1862c8c1e8d0fcb33d215cc47d8d58dfb432 Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Fri, 18 Dec 2020 19:58:16 +0900
Subject: [PATCH 21/75] Adding types to datamodules

---
 .../datamodules/binary_mnist_datamodule.py    |  4 +-
 pl_bolts/datamodules/cifar10_datamodule.py    |  6 +--
 pl_bolts/datamodules/cityscapes_datamodule.py | 16 ++++----
 .../datamodules/fashion_mnist_datamodule.py   |  4 +-
 pl_bolts/datamodules/imagenet_datamodule.py   | 18 ++++-----
 pl_bolts/datamodules/kitti_datamodule.py      | 14 +++----
 pl_bolts/datamodules/mnist_datamodule.py      |  4 +-
 pl_bolts/datamodules/sklearn_datamodule.py    | 38 +++++++++++--------
 .../datamodules/ssl_imagenet_datamodule.py    | 22 +++++------
 pl_bolts/datamodules/stl10_datamodule.py      | 22 +++++------
 pl_bolts/datamodules/vision_datamodule.py     | 16 ++++++--
 11 files changed, 90 insertions(+), 74 deletions(-)

diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py
index 142b3d54ef..c713abe107 100644
--- a/pl_bolts/datamodules/binary_mnist_datamodule.py
+++ b/pl_bolts/datamodules/binary_mnist_datamodule.py
@@ -56,7 +56,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ):
         """
         Args:
             data_dir: Where to save/load the data
@@ -98,7 +98,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self):
+    def default_transforms(self) -> transform_lib.Compose:
         if self.normalize:
             mnist_transforms = transform_lib.Compose(
                 [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py
index afb2df8c9a..12aea1ec87 100644
--- a/pl_bolts/datamodules/cifar10_datamodule.py
+++ b/pl_bolts/datamodules/cifar10_datamodule.py
@@ -112,7 +112,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self):
+    def default_transforms(self) -> transform_lib.Compose:
         if self.normalize:
             cf10_transforms = transform_lib.Compose([transform_lib.ToTensor(), cifar10_normalization()])
         else:
@@ -146,14 +146,14 @@ class TinyCIFAR10DataModule(CIFAR10DataModule):
 
     def __init__(
         self,
-        data_dir: str,
+        data_dir: Optional[str],
         val_split: int = 50,
         num_workers: int = 16,
         num_samples: int = 100,
         labels: Optional[Sequence] = (1, 5, 8),
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ):
         """
         Args:
             data_dir: where to save/load the data
diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py
index b851617225..d27bfc3196 100644
--- a/pl_bolts/datamodules/cityscapes_datamodule.py
+++ b/pl_bolts/datamodules/cityscapes_datamodule.py
@@ -69,8 +69,8 @@ def __init__(
             shuffle: bool = False,
             pin_memory: bool = False,
             drop_last: bool = False,
-            *args,
-            **kwargs,
+            *args: Any,
+            **kwargs: Any,
     ):
         """
         Args:
@@ -109,14 +109,14 @@ def __init__(
         self.target_transforms = None
 
     @property
-    def num_classes(self):
+    def num_classes(self) -> int:
         """
         Return:
             30
         """
         return 30
 
-    def train_dataloader(self):
+    def train_dataloader(self) -> DataLoader:
         """
         Cityscapes train set
         """
@@ -141,7 +141,7 @@ def train_dataloader(self):
         )
         return loader
 
-    def val_dataloader(self):
+    def val_dataloader(self) -> DataLoader:
         """
         Cityscapes val set
         """
@@ -166,7 +166,7 @@ def val_dataloader(self):
         )
         return loader
 
-    def test_dataloader(self):
+    def test_dataloader(self) -> DataLoader:
         """
         Cityscapes test set
         """
@@ -190,7 +190,7 @@ def test_dataloader(self):
         )
         return loader
 
-    def _default_transforms(self):
+    def _default_transforms(self) -> transform_lib.Compose:
         cityscapes_transforms = transform_lib.Compose([
             transform_lib.ToTensor(),
             transform_lib.Normalize(
@@ -200,7 +200,7 @@ def _default_transforms(self):
         ])
         return cityscapes_transforms
 
-    def _default_target_transforms(self):
+    def _default_target_transforms(self) -> transform_lib.Compose:
         cityscapes_target_trasnforms = transform_lib.Compose([
             transform_lib.ToTensor(),
             transform_lib.Lambda(lambda t: t.squeeze())
diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py
index 833c4599a6..9a73e6a637 100644
--- a/pl_bolts/datamodules/fashion_mnist_datamodule.py
+++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py
@@ -57,7 +57,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ):
         """
         Args:
             data_dir: Where to save/load the data
@@ -93,7 +93,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self):
+    def default_transforms(self) -> transform_lib.Compose:
         if self.normalize:
             mnist_transforms = transform_lib.Compose(
                 [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py
index 38546e29ee..a31b637ba9 100644
--- a/pl_bolts/datamodules/imagenet_datamodule.py
+++ b/pl_bolts/datamodules/imagenet_datamodule.py
@@ -58,8 +58,8 @@ def __init__(
             shuffle: bool = False,
             pin_memory: bool = False,
             drop_last: bool = False,
-            *args,
-            **kwargs,
+            *args: Any,
+            **kwargs: Any,
     ):
         """
         Args:
@@ -94,7 +94,7 @@ def __init__(
         self.num_samples = 1281167 - self.num_imgs_per_val_class * self.num_classes
 
     @property
-    def num_classes(self):
+    def num_classes(self) -> int:
         """
         Return:
 
@@ -103,7 +103,7 @@ def num_classes(self):
         """
         return 1000
 
-    def _verify_splits(self, data_dir, split):
+    def _verify_splits(self, data_dir: str, split: str):
         dirs = os.listdir(data_dir)
 
         if split not in dirs:
@@ -138,7 +138,7 @@ def prepare_data(self):
                 UnlabeledImagenet.generate_meta_bins(path)
                 """)
 
-    def train_dataloader(self):
+    def train_dataloader(self) -> DataLoader:
         """
         Uses the train split of imagenet2012 and puts away a portion of it for the validation split
         """
@@ -160,7 +160,7 @@ def train_dataloader(self):
         )
         return loader
 
-    def val_dataloader(self):
+    def val_dataloader(self) -> DataLoader:
         """
         Uses the part of the train split of imagenet2012  that was not used for training via `num_imgs_per_val_class`
 
@@ -185,7 +185,7 @@ def val_dataloader(self):
         )
         return loader
 
-    def test_dataloader(self):
+    def test_dataloader(self) -> DataLoader:
         """
         Uses the validation split of imagenet2012 for testing
         """
@@ -206,7 +206,7 @@ def test_dataloader(self):
         )
         return loader
 
-    def train_transform(self):
+    def train_transform(self) -> transform_lib.Compose:
         """
         The standard imagenet transforms
 
@@ -232,7 +232,7 @@ def train_transform(self):
 
         return preprocessing
 
-    def val_transform(self):
+    def val_transform(self) -> transform_lib.Compose:
         """
         The standard imagenet transforms for validation
 
diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py
index 433e7fffed..a50528028b 100644
--- a/pl_bolts/datamodules/kitti_datamodule.py
+++ b/pl_bolts/datamodules/kitti_datamodule.py
@@ -21,7 +21,7 @@ class KittiDataModule(LightningDataModule):
 
     def __init__(
             self,
-            data_dir: str,
+            data_dir: Optional[str],
             val_split: float = 0.2,
             test_split: float = 0.1,
             num_workers: int = 16,
@@ -30,8 +30,8 @@ def __init__(
             shuffle: bool = False,
             pin_memory: bool = False,
             drop_last: bool = False,
-            *args,
-            **kwargs,
+            *args: Any,
+            **kwargs: Any,
     ):
         """
         Kitti train, validation and test dataloaders.
@@ -100,7 +100,7 @@ def __init__(
                                                                 lengths=[train_len, val_len, test_len],
                                                                 generator=torch.Generator().manual_seed(self.seed))
 
-    def train_dataloader(self):
+    def train_dataloader(self) -> DataLoader:
         loader = DataLoader(
             self.trainset,
             batch_size=self.batch_size,
@@ -111,7 +111,7 @@ def train_dataloader(self):
         )
         return loader
 
-    def val_dataloader(self):
+    def val_dataloader(self) -> DataLoader:
         loader = DataLoader(
             self.valset,
             batch_size=self.batch_size,
@@ -122,7 +122,7 @@ def val_dataloader(self):
         )
         return loader
 
-    def test_dataloader(self):
+    def test_dataloader(self) -> DataLoader:
         loader = DataLoader(
             self.testset,
             batch_size=self.batch_size,
@@ -133,7 +133,7 @@ def test_dataloader(self):
         )
         return loader
 
-    def _default_transforms(self):
+    def _default_transforms(self) -> transform_lib.Compose:
         kitti_transforms = transforms.Compose([
             transforms.ToTensor(),
             transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753],
diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py
index 1dd5e927b6..76a0438a0c 100644
--- a/pl_bolts/datamodules/mnist_datamodule.py
+++ b/pl_bolts/datamodules/mnist_datamodule.py
@@ -56,7 +56,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ):
         """
         Args:
             data_dir: Where to save/load the data
@@ -92,7 +92,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self):
+    def default_transforms(self) -> transform_lib.Compose:
         if self.normalize:
             mnist_transforms = transform_lib.Compose(
                 [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
diff --git a/pl_bolts/datamodules/sklearn_datamodule.py b/pl_bolts/datamodules/sklearn_datamodule.py
index ed262b10c8..dcfd559441 100644
--- a/pl_bolts/datamodules/sklearn_datamodule.py
+++ b/pl_bolts/datamodules/sklearn_datamodule.py
@@ -42,10 +42,10 @@ def __init__(self, X: np.ndarray, y: np.ndarray, X_transform: Any = None, y_tran
         self.X_transform = X_transform
         self.y_transform = y_transform
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.X)
 
-    def __getitem__(self, idx):
+    def __getitem__(self, idx) -> Tuple[np.ndarray, np.ndarray]:
         x = self.X[idx].astype(np.float32)
         y = self.Y[idx]
 
@@ -89,10 +89,10 @@ def __init__(self, X: torch.Tensor, y: torch.Tensor, X_transform: Any = None, y_
         self.X_transform = X_transform
         self.y_transform = y_transform
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.X)
 
-    def __getitem__(self, idx):
+    def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor]:
         x = self.X[idx].float()
         y = self.Y[idx]
 
@@ -145,14 +145,14 @@ def __init__(
             x_val=None, y_val=None,
             x_test=None, y_test=None,
             val_split=0.2, test_split=0.1,
-            num_workers=2,
-            random_state=1234,
-            shuffle=True,
+            num_workers:int = 2,
+            random_state: int = 1234,
+            shuffle: bool = True,
             batch_size: int = 16,
-            pin_memory=False,
-            drop_last=False,
-            *args,
-            **kwargs,
+            pin_memory: bool = False,
+            drop_last: bool = False,
+            *args: Any,
+            **kwargs: Any,
     ):
 
         super().__init__(*args, **kwargs)
@@ -193,12 +193,20 @@ def __init__(
 
         self._init_datasets(X, y, x_val, y_val, x_test, y_test)
 
-    def _init_datasets(self, X, y, x_val, y_val, x_test, y_test):
+    def _init_datasets(
+        self,
+        X: np.ndarray,
+        y: np.ndarray,
+        x_val: np.ndarray,
+        y_val: np.ndarray,
+        x_test: np.ndarray,
+        y_test: np.ndarray
+    ):
         self.train_dataset = SklearnDataset(X, y)
         self.val_dataset = SklearnDataset(x_val, y_val)
         self.test_dataset = SklearnDataset(x_test, y_test)
 
-    def train_dataloader(self):
+    def train_dataloader(self) -> DataLoader:
         loader = DataLoader(
             self.train_dataset,
             batch_size=self.batch_size,
@@ -209,7 +217,7 @@ def train_dataloader(self):
         )
         return loader
 
-    def val_dataloader(self):
+    def val_dataloader(self) -> DataLoader:
         loader = DataLoader(
             self.val_dataset,
             batch_size=self.batch_size,
@@ -220,7 +228,7 @@ def val_dataloader(self):
         )
         return loader
 
-    def test_dataloader(self):
+    def test_dataloader(self) -> DataLoader:
         loader = DataLoader(
             self.test_dataset,
             batch_size=self.batch_size,
diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py
index 03e459fd5e..50315245af 100644
--- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py
+++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py
@@ -20,15 +20,15 @@ class SSLImagenetDataModule(LightningDataModule):  # pragma: no cover
 
     def __init__(
             self,
-            data_dir,
-            meta_dir=None,
-            num_workers=16,
+            data_dir: str,
+            meta_dir: Optional[str] = None,
+            num_workers: int = 16,
             batch_size: int = 32,
             shuffle: bool = False,
             pin_memory: bool = False,
             drop_last: bool = False,
-            *args,
-            **kwargs,
+            *args: Any,
+            **kwargs: Any,
     ):
         super().__init__(*args, **kwargs)
 
@@ -46,10 +46,10 @@ def __init__(
         self.drop_last = drop_last
 
     @property
-    def num_classes(self):
+    def num_classes(self) -> int:
         return 1000
 
-    def _verify_splits(self, data_dir, split):
+    def _verify_splits(self, data_dir: str, split: str):
         dirs = os.listdir(data_dir)
 
         if split not in dirs:
@@ -79,7 +79,7 @@ def prepare_data(self):
                 UnlabeledImagenet.generate_meta_bins(path)
                 """)
 
-    def train_dataloader(self, num_images_per_class=-1, add_normalize=False):
+    def train_dataloader(self, num_images_per_class: int = -1, add_normalize: bool = False) -> DataLoader:
         transforms = self._default_transforms() if self.train_transforms is None else self.train_transforms
 
         dataset = UnlabeledImagenet(self.data_dir,
@@ -97,7 +97,7 @@ def train_dataloader(self, num_images_per_class=-1, add_normalize=False):
         )
         return loader
 
-    def val_dataloader(self, num_images_per_class=50, add_normalize=False):
+    def val_dataloader(self, num_images_per_class: int = 50, add_normalize: bool = False) -> DataLoader:
         transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms
 
         dataset = UnlabeledImagenet(self.data_dir,
@@ -115,7 +115,7 @@ def val_dataloader(self, num_images_per_class=50, add_normalize=False):
         )
         return loader
 
-    def test_dataloader(self, num_images_per_class, add_normalize=False):
+    def test_dataloader(self, num_images_per_class: int, add_normalize: bool = False) -> DataLoader:
         transforms = self._default_transforms() if self.test_transforms is None else self.test_transforms
 
         dataset = UnlabeledImagenet(self.data_dir,
@@ -133,7 +133,7 @@ def test_dataloader(self, num_images_per_class, add_normalize=False):
         )
         return loader
 
-    def _default_transforms(self):
+    def _default_transforms(self) -> transform_lib.Compose:
         mnist_transforms = transform_lib.Compose([
             transform_lib.ToTensor(),
             imagenet_normalization()
diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py
index c666db9b9b..3b29995a1b 100644
--- a/pl_bolts/datamodules/stl10_datamodule.py
+++ b/pl_bolts/datamodules/stl10_datamodule.py
@@ -63,8 +63,8 @@ def __init__(
             shuffle: bool = False,
             pin_memory: bool = False,
             drop_last: bool = False,
-            *args,
-            **kwargs,
+            *args: Any,
+            **kwargs: Any,
     ):
         """
         Args:
@@ -99,7 +99,7 @@ def __init__(
         self.num_unlabeled_samples = 100000 - unlabeled_val_split
 
     @property
-    def num_classes(self):
+    def num_classes(self) -> int:
         return 10
 
     def prepare_data(self):
@@ -110,7 +110,7 @@ def prepare_data(self):
         STL10(self.data_dir, split='train', download=True, transform=transform_lib.ToTensor())
         STL10(self.data_dir, split='test', download=True, transform=transform_lib.ToTensor())
 
-    def train_dataloader(self):
+    def train_dataloader(self) -> DataLoader:
         """
         Loads the 'unlabeled' split minus a portion set aside for validation via `unlabeled_val_split`.
         """
@@ -131,7 +131,7 @@ def train_dataloader(self):
         )
         return loader
 
-    def train_dataloader_mixed(self):
+    def train_dataloader_mixed(self) -> DataLoader:
         """
         Loads a portion of the 'unlabeled' training data and 'train' (labeled) data.
         Both portions have a subset removed for validation via `unlabeled_val_split` and `train_val_split`
@@ -169,7 +169,7 @@ def train_dataloader_mixed(self):
         )
         return loader
 
-    def val_dataloader(self):
+    def val_dataloader(self) -> DataLoader:
         """
         Loads a portion of the 'unlabeled' training data set aside for validation
         The val dataset = (unlabeled - train_val_split)
@@ -196,7 +196,7 @@ def val_dataloader(self):
         )
         return loader
 
-    def val_dataloader_mixed(self):
+    def val_dataloader_mixed(self) -> DataLoader:
         """
         Loads a portion of the 'unlabeled' training data set aside for validation along with
         the portion of the 'train' dataset to be used for validation
@@ -239,7 +239,7 @@ def val_dataloader_mixed(self):
         )
         return loader
 
-    def test_dataloader(self):
+    def test_dataloader(self) -> DataLoader:
         """
         Loads the test split of STL10
 
@@ -260,7 +260,7 @@ def test_dataloader(self):
         )
         return loader
 
-    def train_dataloader_labeled(self):
+    def train_dataloader_labeled(self) -> DataLoader:
         transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms
 
         dataset = STL10(self.data_dir, split='train', download=False, transform=transforms)
@@ -278,7 +278,7 @@ def train_dataloader_labeled(self):
         )
         return loader
 
-    def val_dataloader_labeled(self):
+    def val_dataloader_labeled(self) -> DataLoader:
         transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms
         dataset = STL10(self.data_dir,
                         split='train',
@@ -299,7 +299,7 @@ def val_dataloader_labeled(self):
         )
         return loader
 
-    def _default_transforms(self):
+    def _default_transforms(self) -> transform_lib.Compose:
         data_transforms = transform_lib.Compose([
             transform_lib.ToTensor(),
             stl10_normalization()
diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py
index 2144f0f509..5b8c508904 100644
--- a/pl_bolts/datamodules/vision_datamodule.py
+++ b/pl_bolts/datamodules/vision_datamodule.py
@@ -6,6 +6,14 @@
 from pytorch_lightning import LightningDataModule
 from torch.utils.data import DataLoader, Dataset, random_split
 
+from pl_bolts.utils import _TORCHVISION_AVAILABLE
+from pl_bolts.utils.warnings import warn_missing_pkg
+
+if _TORCHVISION_AVAILABLE:
+    from torchvision import transforms as transform_lib
+else:
+    warn_missing_pkg('torchvision')  # pragma: no-cover
+
 
 class VisionDataModule(LightningDataModule):
 
@@ -29,7 +37,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ):
         """
         Args:
             data_dir: Where to save/load the data
@@ -56,14 +64,14 @@ def __init__(
         self.pin_memory = pin_memory
         self.drop_last = drop_last
 
-    def prepare_data(self) -> None:
+    def prepare_data(self):
         """
         Saves files to data_dir
         """
         self.dataset_cls(self.data_dir, train=True, download=True)
         self.dataset_cls(self.data_dir, train=False, download=True)
 
-    def setup(self, stage: Optional[str] = None) -> None:
+    def setup(self, stage: Optional[str] = None):
         """
         Creates train, val, and test dataset
         """
@@ -115,7 +123,7 @@ def _get_splits(self, len_dataset: int) -> List[int]:
         return splits
 
     @abstractmethod
-    def default_transforms(self):
+    def default_transforms(self) -> transform_lib.Compose:
         """ Default transform for the dataset """
 
     def train_dataloader(self) -> DataLoader:

From a4306969a9e3fbdf517a2da7889107a599c922e6 Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Fri, 18 Dec 2020 20:20:20 +0900
Subject: [PATCH 22/75] Fixing typing imports

---
 pl_bolts/datamodules/async_dataloader.py       | 14 +++++++++++---
 pl_bolts/datamodules/cityscapes_datamodule.py  |  2 ++
 pl_bolts/datamodules/imagenet_datamodule.py    |  2 +-
 pl_bolts/datamodules/kitti_datamodule.py       |  3 ++-
 pl_bolts/datamodules/sklearn_datamodule.py     |  4 ++--
 .../datamodules/ssl_imagenet_datamodule.py     |  1 +
 pl_bolts/datamodules/stl10_datamodule.py       |  2 +-
 pl_bolts/datamodules/vision_datamodule.py      |  7 +------
 .../datamodules/vocdetection_datamodule.py     | 18 ++++++++++--------
 9 files changed, 31 insertions(+), 22 deletions(-)

diff --git a/pl_bolts/datamodules/async_dataloader.py b/pl_bolts/datamodules/async_dataloader.py
index 7ded9d9ef1..38a0b9bb58 100644
--- a/pl_bolts/datamodules/async_dataloader.py
+++ b/pl_bolts/datamodules/async_dataloader.py
@@ -1,10 +1,11 @@
 import re
 from queue import Queue
 from threading import Thread
+from typing import Any, Optional, Union
 
 import torch
 from torch._six import container_abcs, string_classes
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Dataset
 
 
 class AsynchronousLoader(object):
@@ -26,7 +27,14 @@ class AsynchronousLoader(object):
             constructing one here
     """
 
-    def __init__(self, data, device=torch.device('cuda', 0), q_size=10, num_batches=None, **kwargs):
+    def __init__(
+            self,
+            data: Union[DataLoader, Dataset],
+            device: torch.device = torch.device('cuda', 0),
+            q_size: int = 10,
+            num_batches: Optional[int] = None,
+            **kwargs: Any
+    ):
         if isinstance(data, torch.utils.data.DataLoader):
             self.dataloader = data
         else:
@@ -105,5 +113,5 @@ def __next__(self):
         self.idx += 1
         return out
 
-    def __len__(self):
+    def __len__(self) -> int:
         return self.num_batches
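A hedged usage sketch for the newly annotated `AsynchronousLoader` constructor; the tensors and device index below are placeholders, and the keyword defaults mirror the signature in the hunk above:

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from pl_bolts.datamodules.async_dataloader import AsynchronousLoader

    dataset = TensorDataset(torch.randn(8, 3), torch.zeros(8))  # placeholder data
    loader = AsynchronousLoader(
        DataLoader(dataset, batch_size=4),  # a Dataset could be passed directly instead
        device=torch.device('cuda', 0),
        q_size=10,
        num_batches=None,
    )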
diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py
index d27bfc3196..17812a0ac5 100644
--- a/pl_bolts/datamodules/cityscapes_datamodule.py
+++ b/pl_bolts/datamodules/cityscapes_datamodule.py
@@ -1,3 +1,5 @@
+from typing import Any
+
 from pytorch_lightning import LightningDataModule
 from torch.utils.data import DataLoader
 
diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py
index a31b637ba9..829c485aed 100644
--- a/pl_bolts/datamodules/imagenet_datamodule.py
+++ b/pl_bolts/datamodules/imagenet_datamodule.py
@@ -1,5 +1,5 @@
 import os
-from typing import Optional
+from typing import Any, Optional
 
 from pytorch_lightning import LightningDataModule
 from torch.utils.data import DataLoader
diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py
index a50528028b..9a82f0b7ec 100644
--- a/pl_bolts/datamodules/kitti_datamodule.py
+++ b/pl_bolts/datamodules/kitti_datamodule.py
@@ -1,4 +1,5 @@
 import os
+from typing import Any, Optional
 
 import torch
 from pytorch_lightning import LightningDataModule
@@ -133,7 +134,7 @@ def test_dataloader(self) -> DataLoader:
         )
         return loader
 
-    def _default_transforms(self) -> transform_lib.Compose:
+    def _default_transforms(self) -> transforms.Compose:
         kitti_transforms = transforms.Compose([
             transforms.ToTensor(),
             transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753],
diff --git a/pl_bolts/datamodules/sklearn_datamodule.py b/pl_bolts/datamodules/sklearn_datamodule.py
index dcfd559441..e80e1dfc9a 100644
--- a/pl_bolts/datamodules/sklearn_datamodule.py
+++ b/pl_bolts/datamodules/sklearn_datamodule.py
@@ -1,5 +1,5 @@
 import math
-from typing import Any
+from typing import Any, Tuple
 
 import numpy as np
 import torch
@@ -145,7 +145,7 @@ def __init__(
             x_val=None, y_val=None,
             x_test=None, y_test=None,
             val_split=0.2, test_split=0.1,
-            num_workers:int = 2,
+            num_workers: int = 2,
             random_state: int = 1234,
             shuffle: bool = True,
             batch_size: int = 16,
diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py
index 50315245af..1584583101 100644
--- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py
+++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py
@@ -1,4 +1,5 @@
 import os
+from typing import Any, Optional
 
 from pytorch_lightning import LightningDataModule
 from torch.utils.data import DataLoader
diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py
index 3b29995a1b..8d46cfd7bf 100644
--- a/pl_bolts/datamodules/stl10_datamodule.py
+++ b/pl_bolts/datamodules/stl10_datamodule.py
@@ -1,5 +1,5 @@
 import os
-from typing import Optional
+from typing import Any, Optional
 
 import torch
 from pytorch_lightning import LightningDataModule
diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py
index 5b8c508904..cdcefcb2eb 100644
--- a/pl_bolts/datamodules/vision_datamodule.py
+++ b/pl_bolts/datamodules/vision_datamodule.py
@@ -9,11 +9,6 @@
 from pl_bolts.utils import _TORCHVISION_AVAILABLE
 from pl_bolts.utils.warnings import warn_missing_pkg
 
-if _TORCHVISION_AVAILABLE:
-    from torchvision import transforms as transform_lib
-else:
-    warn_missing_pkg('torchvision')  # pragma: no-cover
-
 
 class VisionDataModule(LightningDataModule):
 
@@ -123,7 +118,7 @@ def _get_splits(self, len_dataset: int) -> List[int]:
         return splits
 
     @abstractmethod
-    def default_transforms(self) -> transform_lib.Compose:
+    def default_transforms(self):
         """ Default transform for the dataset """
 
     def train_dataloader(self) -> DataLoader:
diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py
index a2087f9448..3f3767b4f8 100644
--- a/pl_bolts/datamodules/vocdetection_datamodule.py
+++ b/pl_bolts/datamodules/vocdetection_datamodule.py
@@ -1,3 +1,5 @@
+from typing import Any, Dict
+
 import torch
 from pytorch_lightning import LightningDataModule
 from torch.utils.data import DataLoader
@@ -17,7 +19,7 @@ class Compose(object):
     Like `torchvision.transforms.compose` but works for (image, target)
     """
 
-    def __init__(self, transforms):
+    def __init__(self, transforms: T.Compose):
         self.transforms = transforms
 
     def __call__(self, image, target):
@@ -55,7 +57,7 @@ def _collate_fn(batch):
 )
 
 
-def _prepare_voc_instance(image, target):
+def _prepare_voc_instance(image, target: Dict[str, Any]):
     """
     Prepares VOC dataset into appropriate target for fasterrcnn
 
@@ -113,8 +115,8 @@ def __init__(
         shuffle: bool = False,
         pin_memory: bool = False,
         drop_last: bool = False,
-        *args,
-        **kwargs,
+        *args: Any,
+        **kwargs: Any,
     ):
         if not _TORCHVISION_AVAILABLE:
             raise ModuleNotFoundError(  # pragma: no-cover
@@ -132,7 +134,7 @@ def __init__(
         self.drop_last = drop_last
 
     @property
-    def num_classes(self):
+    def num_classes(self) -> int:
         """
         Return:
             21
@@ -146,7 +148,7 @@ def prepare_data(self):
         VOCDetection(self.data_dir, year=self.year, image_set="train", download=True)
         VOCDetection(self.data_dir, year=self.year, image_set="val", download=True)
 
-    def train_dataloader(self, batch_size=1, transforms=None):
+    def train_dataloader(self, batch_size: int = 1, transforms=None) -> DataLoader:
         """
         VOCDetection train set uses the `train` subset
 
@@ -174,7 +176,7 @@ def train_dataloader(self, batch_size=1, transforms=None):
         )
         return loader
 
-    def val_dataloader(self, batch_size=1, transforms=None):
+    def val_dataloader(self, batch_size: int = 1, transforms=None) -> DataLoader:
         """
         VOCDetection val set uses the `val` subset
 
@@ -201,7 +203,7 @@ def val_dataloader(self, batch_size=1, transforms=None):
         )
         return loader
 
-    def _default_transforms(self):
+    def _default_transforms(self) -> T.Compose:
         if self.normalize:
             return (
                 lambda image, target: (

From a84551e0c69853b31c3eb41c49b2483c3a012133 Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Sun, 20 Dec 2020 03:03:00 +0900
Subject: [PATCH 23/75] Removing torchvision.transforms from return typing

---
 pl_bolts/datamodules/binary_mnist_datamodule.py  | 2 +-
 pl_bolts/datamodules/cifar10_datamodule.py       | 2 +-
 pl_bolts/datamodules/cityscapes_datamodule.py    | 4 ++--
 pl_bolts/datamodules/fashion_mnist_datamodule.py | 2 +-
 pl_bolts/datamodules/imagenet_datamodule.py      | 4 ++--
 pl_bolts/datamodules/kitti_datamodule.py         | 2 +-
 pl_bolts/datamodules/mnist_datamodule.py         | 2 +-
 pl_bolts/datamodules/ssl_imagenet_datamodule.py  | 2 +-
 pl_bolts/datamodules/stl10_datamodule.py         | 2 +-
 pl_bolts/datamodules/vocdetection_datamodule.py  | 2 +-
 10 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py
index c713abe107..8dc02ec95e 100644
--- a/pl_bolts/datamodules/binary_mnist_datamodule.py
+++ b/pl_bolts/datamodules/binary_mnist_datamodule.py
@@ -98,7 +98,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self) -> transform_lib.Compose:
+    def default_transforms(self):
         if self.normalize:
             mnist_transforms = transform_lib.Compose(
                 [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py
index 12aea1ec87..b208172ed0 100644
--- a/pl_bolts/datamodules/cifar10_datamodule.py
+++ b/pl_bolts/datamodules/cifar10_datamodule.py
@@ -112,7 +112,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self) -> transform_lib.Compose:
+    def default_transforms(self):
         if self.normalize:
             cf10_transforms = transform_lib.Compose([transform_lib.ToTensor(), cifar10_normalization()])
         else:
diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py
index 17812a0ac5..dc8b866fba 100644
--- a/pl_bolts/datamodules/cityscapes_datamodule.py
+++ b/pl_bolts/datamodules/cityscapes_datamodule.py
@@ -192,7 +192,7 @@ def test_dataloader(self) -> DataLoader:
         )
         return loader
 
-    def _default_transforms(self) -> transform_lib.Compose:
+    def _default_transforms(self):
         cityscapes_transforms = transform_lib.Compose([
             transform_lib.ToTensor(),
             transform_lib.Normalize(
@@ -202,7 +202,7 @@ def _default_transforms(self) -> transform_lib.Compose:
         ])
         return cityscapes_transforms
 
-    def _default_target_transforms(self) -> transform_lib.Compose:
+    def _default_target_transforms(self):
         cityscapes_target_trasnforms = transform_lib.Compose([
             transform_lib.ToTensor(),
             transform_lib.Lambda(lambda t: t.squeeze())
diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py
index 9a73e6a637..a128ddfaab 100644
--- a/pl_bolts/datamodules/fashion_mnist_datamodule.py
+++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py
@@ -93,7 +93,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self) -> transform_lib.Compose:
+    def default_transforms(self):
         if self.normalize:
             mnist_transforms = transform_lib.Compose(
                 [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py
index 829c485aed..f8d8262108 100644
--- a/pl_bolts/datamodules/imagenet_datamodule.py
+++ b/pl_bolts/datamodules/imagenet_datamodule.py
@@ -206,7 +206,7 @@ def test_dataloader(self) -> DataLoader:
         )
         return loader
 
-    def train_transform(self) -> transform_lib.Compose:
+    def train_transform(self):
         """
         The standard imagenet transforms
 
@@ -232,7 +232,7 @@ def train_transform(self) -> transform_lib.Compose:
 
         return preprocessing
 
-    def val_transform(self) -> transform_lib.Compose:
+    def val_transform(self):
         """
         The standard imagenet transforms for validation
 
diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py
index 9a82f0b7ec..06778fdbfc 100644
--- a/pl_bolts/datamodules/kitti_datamodule.py
+++ b/pl_bolts/datamodules/kitti_datamodule.py
@@ -134,7 +134,7 @@ def test_dataloader(self) -> DataLoader:
         )
         return loader
 
-    def _default_transforms(self) -> transforms.Compose:
+    def _default_transforms(self):
         kitti_transforms = transforms.Compose([
             transforms.ToTensor(),
             transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753],
diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py
index 76a0438a0c..c57fe8ca82 100644
--- a/pl_bolts/datamodules/mnist_datamodule.py
+++ b/pl_bolts/datamodules/mnist_datamodule.py
@@ -92,7 +92,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self) -> transform_lib.Compose:
+    def default_transforms(self):
         if self.normalize:
             mnist_transforms = transform_lib.Compose(
                 [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py
index 1584583101..96041fd4d9 100644
--- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py
+++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py
@@ -134,7 +134,7 @@ def test_dataloader(self, num_images_per_class: int, add_normalize: bool = False
         )
         return loader
 
-    def _default_transforms(self) -> transform_lib.Compose:
+    def _default_transforms(self):
         mnist_transforms = transform_lib.Compose([
             transform_lib.ToTensor(),
             imagenet_normalization()
diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py
index 8d46cfd7bf..5cd680a535 100644
--- a/pl_bolts/datamodules/stl10_datamodule.py
+++ b/pl_bolts/datamodules/stl10_datamodule.py
@@ -299,7 +299,7 @@ def val_dataloader_labeled(self) -> DataLoader:
         )
         return loader
 
-    def _default_transforms(self) -> transform_lib.Compose:
+    def _default_transforms(self):
         data_transforms = transform_lib.Compose([
             transform_lib.ToTensor(),
             stl10_normalization()
diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py
index 3f3767b4f8..e0768b42d5 100644
--- a/pl_bolts/datamodules/vocdetection_datamodule.py
+++ b/pl_bolts/datamodules/vocdetection_datamodule.py
@@ -203,7 +203,7 @@ def val_dataloader(self, batch_size: int = 1, transforms=None) -> DataLoader:
         )
         return loader
 
-    def _default_transforms(self) -> T.Compose:
+    def _default_transforms(self):
         if self.normalize:
             return (
                 lambda image, target: (

From 685162c7226a7cfe7bba6e948436dd8d40bc8c27 Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Sun, 20 Dec 2020 03:09:03 +0900
Subject: [PATCH 24/75] Remove more torchvision.transforms typing

---
 pl_bolts/datamodules/vocdetection_datamodule.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py
index e0768b42d5..3753e727ad 100644
--- a/pl_bolts/datamodules/vocdetection_datamodule.py
+++ b/pl_bolts/datamodules/vocdetection_datamodule.py
@@ -19,7 +19,7 @@ class Compose(object):
     Like `torchvision.transforms.compose` but works for (image, target)
     """
 
-    def __init__(self, transforms: T.Compose):
+    def __init__(self, transforms):
         self.transforms = transforms
 
     def __call__(self, image, target):

From 01408372c75cb8c22fbff7c4062509e66ed64335 Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Mon, 21 Dec 2020 10:13:11 +0900
Subject: [PATCH 25/75] Removing return typing

---
 pl_bolts/datamodules/cifar10_datamodule.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py
index b208172ed0..85ba4de6e7 100644
--- a/pl_bolts/datamodules/cifar10_datamodule.py
+++ b/pl_bolts/datamodules/cifar10_datamodule.py
@@ -71,7 +71,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ):
         """
         Args:
             data_dir: Where to save/load the data

From a6b8d4af95c2710ad47e81d25b5f866dd79eca37 Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Mon, 21 Dec 2020 16:00:55 +0900
Subject: [PATCH 26/75] Add `None` for optional arguments

---
 pl_bolts/datamodules/cifar10_datamodule.py | 2 +-
 pl_bolts/datamodules/kitti_datamodule.py   | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py
index 85ba4de6e7..534774684f 100644
--- a/pl_bolts/datamodules/cifar10_datamodule.py
+++ b/pl_bolts/datamodules/cifar10_datamodule.py
@@ -146,7 +146,7 @@ class TinyCIFAR10DataModule(CIFAR10DataModule):
 
     def __init__(
         self,
-        data_dir: Optional[str],
+        data_dir: Optional[str] = None,
         val_split: int = 50,
         num_workers: int = 16,
         num_samples: int = 100,
diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py
index 06778fdbfc..0067a1e53d 100644
--- a/pl_bolts/datamodules/kitti_datamodule.py
+++ b/pl_bolts/datamodules/kitti_datamodule.py
@@ -22,7 +22,7 @@ class KittiDataModule(LightningDataModule):
 
     def __init__(
             self,
-            data_dir: Optional[str],
+            data_dir: Optional[str] = None,
             val_split: float = 0.2,
             test_split: float = 0.1,
             num_workers: int = 16,

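`Optional[str]` on its own only widens the accepted type to include `None`; without a default the caller still has to pass something. Adding `= None`, as this patch does, is what makes `data_dir` genuinely optional. A minimal sketch with an illustrative class (the `os.getcwd()` fallback is an assumption for the example, not taken from the patch):

    import os
    from typing import Optional

    class ExampleDataModule:
        def __init__(self, data_dir: Optional[str] = None) -> None:
            # fall back to the working directory when no path is supplied
            self.data_dir = data_dir if data_dir is not None else os.getcwd()

    ExampleDataModule()           # argument can now be omitted
    ExampleDataModule('/data')    # or given explicitly
    ExampleDataModule(None)       # or passed as an explicit None
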
From de35a5514fc39ccbe74564f72b245d676f4154ac Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Mon, 21 Dec 2020 16:04:36 +0900
Subject: [PATCH 27/75] Remove unnecessary import

---
 pl_bolts/datamodules/vision_datamodule.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py
index cdcefcb2eb..42252c4edf 100644
--- a/pl_bolts/datamodules/vision_datamodule.py
+++ b/pl_bolts/datamodules/vision_datamodule.py
@@ -6,9 +6,6 @@
 from pytorch_lightning import LightningDataModule
 from torch.utils.data import DataLoader, Dataset, random_split
 
-from pl_bolts.utils import _TORCHVISION_AVAILABLE
-from pl_bolts.utils.warnings import warn_missing_pkg
-
 
 class VisionDataModule(LightningDataModule):
 

From f521b793cd6d488531179e457b022c25e1ab9e8f Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Mon, 21 Dec 2020 16:51:36 +0900
Subject: [PATCH 28/75] Add `None` return type

---
 pl_bolts/datamodules/async_dataloader.py         |  4 ++--
 pl_bolts/datamodules/binary_mnist_datamodule.py  |  2 +-
 pl_bolts/datamodules/cifar10_datamodule.py       |  4 ++--
 pl_bolts/datamodules/cityscapes_datamodule.py    |  2 +-
 pl_bolts/datamodules/experience_source.py        |  4 ++--
 pl_bolts/datamodules/fashion_mnist_datamodule.py |  2 +-
 pl_bolts/datamodules/imagenet_datamodule.py      |  6 +++---
 pl_bolts/datamodules/kitti_datamodule.py         |  2 +-
 pl_bolts/datamodules/mnist_datamodule.py         |  2 +-
 pl_bolts/datamodules/sklearn_datamodule.py       | 12 ++++++++----
 pl_bolts/datamodules/ssl_imagenet_datamodule.py  |  6 +++---
 pl_bolts/datamodules/stl10_datamodule.py         |  4 ++--
 pl_bolts/datamodules/vision_datamodule.py        |  6 +++---
 pl_bolts/datamodules/vocdetection_datamodule.py  |  6 +++---
 14 files changed, 33 insertions(+), 29 deletions(-)

diff --git a/pl_bolts/datamodules/async_dataloader.py b/pl_bolts/datamodules/async_dataloader.py
index 38a0b9bb58..224f34d5ee 100644
--- a/pl_bolts/datamodules/async_dataloader.py
+++ b/pl_bolts/datamodules/async_dataloader.py
@@ -34,7 +34,7 @@ def __init__(
             q_size: int = 10,
             num_batches: Optional[int] = None,
             **kwargs: Any
-    ):
+    ) -> None:
         if isinstance(data, torch.utils.data.DataLoader):
             self.dataloader = data
         else:
@@ -57,7 +57,7 @@ def __init__(
 
         self.np_str_obj_array_pattern = re.compile(r'[SaUO]')
 
-    def load_loop(self):  # The loop that will load into the queue in the background
+    def load_loop(self) -> None:  # The loop that will load into the queue in the background
         for i, sample in enumerate(self.dataloader):
             self.queue.put(self.load_instance(sample))
             if i == len(self):
diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py
index 8dc02ec95e..142b3d54ef 100644
--- a/pl_bolts/datamodules/binary_mnist_datamodule.py
+++ b/pl_bolts/datamodules/binary_mnist_datamodule.py
@@ -56,7 +56,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: Where to save/load the data
diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py
index 534774684f..2cb894d749 100644
--- a/pl_bolts/datamodules/cifar10_datamodule.py
+++ b/pl_bolts/datamodules/cifar10_datamodule.py
@@ -71,7 +71,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: Where to save/load the data
@@ -153,7 +153,7 @@ def __init__(
         labels: Optional[Sequence] = (1, 5, 8),
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: where to save/load the data
diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py
index dc8b866fba..61c1ae2bef 100644
--- a/pl_bolts/datamodules/cityscapes_datamodule.py
+++ b/pl_bolts/datamodules/cityscapes_datamodule.py
@@ -73,7 +73,7 @@ def __init__(
             drop_last: bool = False,
             *args: Any,
             **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: where to load the data from path, i.e. where directory leftImg8bit and gtFine or gtCoarse
diff --git a/pl_bolts/datamodules/experience_source.py b/pl_bolts/datamodules/experience_source.py
index 5fe1332dfd..fac4e82f25 100644
--- a/pl_bolts/datamodules/experience_source.py
+++ b/pl_bolts/datamodules/experience_source.py
@@ -30,7 +30,7 @@ class ExperienceSourceDataset(IterableDataset):
     The logic for the experience source and how the batch is generated is defined the Lightning model itself
     """
 
-    def __init__(self, generate_batch: Callable):
+    def __init__(self, generate_batch: Callable) -> None:
         self.generate_batch = generate_batch
 
     def __iter__(self) -> Iterable:
@@ -243,7 +243,7 @@ def pop_rewards_steps(self):
 class DiscountedExperienceSource(ExperienceSource):
     """Outputs experiences with a discounted reward over N steps"""
 
-    def __init__(self, env: Env, agent, n_steps: int = 1, gamma: float = 0.99):
+    def __init__(self, env: Env, agent, n_steps: int = 1, gamma: float = 0.99) -> None:
         super().__init__(env, agent, (n_steps + 1))
         self.gamma = gamma
         self.steps = n_steps
diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py
index a128ddfaab..833c4599a6 100644
--- a/pl_bolts/datamodules/fashion_mnist_datamodule.py
+++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py
@@ -57,7 +57,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: Where to save/load the data
diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py
index f8d8262108..6f06913f9f 100644
--- a/pl_bolts/datamodules/imagenet_datamodule.py
+++ b/pl_bolts/datamodules/imagenet_datamodule.py
@@ -60,7 +60,7 @@ def __init__(
             drop_last: bool = False,
             *args: Any,
             **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: path to the imagenet dataset file
@@ -103,14 +103,14 @@ def num_classes(self) -> int:
         """
         return 1000
 
-    def _verify_splits(self, data_dir: str, split: str):
+    def _verify_splits(self, data_dir: str, split: str) -> None:
         dirs = os.listdir(data_dir)
 
         if split not in dirs:
             raise FileNotFoundError(f'a {split} Imagenet split was not found in {data_dir},'
                                     f' make sure the folder contains a subfolder named {split}')
 
-    def prepare_data(self):
+    def prepare_data(self) -> None:
         """
         This method already assumes you have imagenet2012 downloaded.
         It validates the data using the meta.bin.
diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py
index 0067a1e53d..3cf26dc762 100644
--- a/pl_bolts/datamodules/kitti_datamodule.py
+++ b/pl_bolts/datamodules/kitti_datamodule.py
@@ -33,7 +33,7 @@ def __init__(
             drop_last: bool = False,
             *args: Any,
             **kwargs: Any,
-    ):
+    ) -> None:
         """
         Kitti train, validation and test dataloaders.
 
diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py
index c57fe8ca82..1dd5e927b6 100644
--- a/pl_bolts/datamodules/mnist_datamodule.py
+++ b/pl_bolts/datamodules/mnist_datamodule.py
@@ -56,7 +56,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: Where to save/load the data
diff --git a/pl_bolts/datamodules/sklearn_datamodule.py b/pl_bolts/datamodules/sklearn_datamodule.py
index e80e1dfc9a..d9477acc0b 100644
--- a/pl_bolts/datamodules/sklearn_datamodule.py
+++ b/pl_bolts/datamodules/sklearn_datamodule.py
@@ -28,7 +28,9 @@ class SklearnDataset(Dataset):
         >>> len(dataset)
         506
     """
-    def __init__(self, X: np.ndarray, y: np.ndarray, X_transform: Any = None, y_transform: Any = None):
+    def __init__(
+        self, X: np.ndarray, y: np.ndarray, X_transform: Any = None, y_transform: Any = None
+    ) -> None:
         """
         Args:
             X: Numpy ndarray
@@ -75,7 +77,9 @@ class TensorDataset(Dataset):
         >>> len(dataset)
         10
     """
-    def __init__(self, X: torch.Tensor, y: torch.Tensor, X_transform: Any = None, y_transform: Any = None):
+    def __init__(
+        self, X: torch.Tensor, y: torch.Tensor, X_transform: Any = None, y_transform: Any = None
+    ) -> None:
         """
         Args:
             X: PyTorch tensor
@@ -153,7 +157,7 @@ def __init__(
             drop_last: bool = False,
             *args: Any,
             **kwargs: Any,
-    ):
+    ) -> None:
 
         super().__init__(*args, **kwargs)
         self.num_workers = num_workers
@@ -201,7 +205,7 @@ def _init_datasets(
         y_val: np.ndarray,
         x_test: np.ndarray,
         y_test: np.ndarray
-    ):
+    ) -> None:
         self.train_dataset = SklearnDataset(X, y)
         self.val_dataset = SklearnDataset(x_val, y_val)
         self.test_dataset = SklearnDataset(x_test, y_test)
diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py
index 96041fd4d9..3dbda03527 100644
--- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py
+++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py
@@ -30,7 +30,7 @@ def __init__(
             drop_last: bool = False,
             *args: Any,
             **kwargs: Any,
-    ):
+    ) -> None:
         super().__init__(*args, **kwargs)
 
         if not _TORCHVISION_AVAILABLE:
@@ -50,14 +50,14 @@ def __init__(
     def num_classes(self) -> int:
         return 1000
 
-    def _verify_splits(self, data_dir: str, split: str):
+    def _verify_splits(self, data_dir: str, split: str) -> None:
         dirs = os.listdir(data_dir)
 
         if split not in dirs:
             raise FileNotFoundError(f'a {split} Imagenet split was not found in {data_dir}, make sure the'
                                     f' folder contains a subfolder named {split}')
 
-    def prepare_data(self):
+    def prepare_data(self) -> None:
         # imagenet cannot be downloaded... must provide path to folder with the train/val splits
         self._verify_splits(self.data_dir, 'train')
         self._verify_splits(self.data_dir, 'val')
diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py
index 5cd680a535..79420be149 100644
--- a/pl_bolts/datamodules/stl10_datamodule.py
+++ b/pl_bolts/datamodules/stl10_datamodule.py
@@ -65,7 +65,7 @@ def __init__(
             drop_last: bool = False,
             *args: Any,
             **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: where to save/load the data
@@ -102,7 +102,7 @@ def __init__(
     def num_classes(self) -> int:
         return 10
 
-    def prepare_data(self):
+    def prepare_data(self) -> None:
         """
         Downloads the unlabeled, train and test split
         """
diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py
index 42252c4edf..2144f0f509 100644
--- a/pl_bolts/datamodules/vision_datamodule.py
+++ b/pl_bolts/datamodules/vision_datamodule.py
@@ -29,7 +29,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: Where to save/load the data
@@ -56,14 +56,14 @@ def __init__(
         self.pin_memory = pin_memory
         self.drop_last = drop_last
 
-    def prepare_data(self):
+    def prepare_data(self) -> None:
         """
         Saves files to data_dir
         """
         self.dataset_cls(self.data_dir, train=True, download=True)
         self.dataset_cls(self.data_dir, train=False, download=True)
 
-    def setup(self, stage: Optional[str] = None):
+    def setup(self, stage: Optional[str] = None) -> None:
         """
         Creates train, val, and test dataset
         """
diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py
index 3753e727ad..52d5065d97 100644
--- a/pl_bolts/datamodules/vocdetection_datamodule.py
+++ b/pl_bolts/datamodules/vocdetection_datamodule.py
@@ -19,7 +19,7 @@ class Compose(object):
     Like `torchvision.transforms.compose` but works for (image, target)
     """
 
-    def __init__(self, transforms):
+    def __init__(self, transforms) -> None:
         self.transforms = transforms
 
     def __call__(self, image, target):
@@ -117,7 +117,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         if not _TORCHVISION_AVAILABLE:
             raise ModuleNotFoundError(  # pragma: no-cover
                 'You want to use VOC dataset loaded from `torchvision` which is not installed yet.'
@@ -141,7 +141,7 @@ def num_classes(self) -> int:
         """
         return 21
 
-    def prepare_data(self):
+    def prepare_data(self) -> None:
         """
         Saves VOCDetection files to data_dir
         """

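Beyond satisfying `disallow_untyped_defs`, the `-> None` annotations added above let mypy flag any caller that tries to use the result of a side-effect-only method. A minimal stand-in (not one of the pl_bolts classes) showing the effect:

    import os

    class TinyDataModule:
        def __init__(self, data_dir: str) -> None:
            self.data_dir = data_dir

        def prepare_data(self) -> None:
            # download/unpack the data; there is intentionally nothing to return
            os.makedirs(self.data_dir, exist_ok=True)

    dm = TinyDataModule('./data')
    dm.prepare_data()
    # path = dm.prepare_data()   # mypy error: "prepare_data" does not return a value
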
From fa0d271011399d5006541dc6b8fc9cc1efaf180c Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Tue, 5 Jan 2021 20:30:43 +0900
Subject: [PATCH 29/75] Add type for torchvision transforms

---
 pl_bolts/datamodules/binary_mnist_datamodule.py  | 4 +++-
 pl_bolts/datamodules/cifar10_datamodule.py       | 4 +++-
 pl_bolts/datamodules/cityscapes_datamodule.py    | 6 ++++--
 pl_bolts/datamodules/fashion_mnist_datamodule.py | 4 +++-
 pl_bolts/datamodules/imagenet_datamodule.py      | 6 ++++--
 pl_bolts/datamodules/kitti_datamodule.py         | 4 +++-
 pl_bolts/datamodules/mnist_datamodule.py         | 4 +++-
 pl_bolts/datamodules/ssl_imagenet_datamodule.py  | 4 +++-
 pl_bolts/datamodules/stl10_datamodule.py         | 4 +++-
 pl_bolts/datamodules/vision_datamodule.py        | 9 ++++++++-
 10 files changed, 37 insertions(+), 12 deletions(-)

diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py
index 142b3d54ef..ad17360d08 100644
--- a/pl_bolts/datamodules/binary_mnist_datamodule.py
+++ b/pl_bolts/datamodules/binary_mnist_datamodule.py
@@ -7,8 +7,10 @@
 
 if _TORCHVISION_AVAILABLE:
     from torchvision import transforms as transform_lib
+    from torchvision.transforms import Compose
 else:  # pragma: no-cover
     warn_missing_pkg('torchvision')
+    Compose = object
 
 
 class BinaryMNISTDataModule(VisionDataModule):
@@ -98,7 +100,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self):
+    def default_transforms(self) -> Compose:
         if self.normalize:
             mnist_transforms = transform_lib.Compose(
                 [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py
index 2cb894d749..9dbf10b670 100644
--- a/pl_bolts/datamodules/cifar10_datamodule.py
+++ b/pl_bolts/datamodules/cifar10_datamodule.py
@@ -9,9 +9,11 @@
 if _TORCHVISION_AVAILABLE:
     from torchvision import transforms as transform_lib
     from torchvision.datasets import CIFAR10
+    from torchvision.transforms import Compose
 else:
     warn_missing_pkg('torchvision')  # pragma: no-cover
     CIFAR10 = None
+    Compose = object
 
 
 class CIFAR10DataModule(VisionDataModule):
@@ -112,7 +114,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self):
+    def default_transforms(self) -> Compose:
         if self.normalize:
             cf10_transforms = transform_lib.Compose([transform_lib.ToTensor(), cifar10_normalization()])
         else:
diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py
index 61c1ae2bef..130e333976 100644
--- a/pl_bolts/datamodules/cityscapes_datamodule.py
+++ b/pl_bolts/datamodules/cityscapes_datamodule.py
@@ -9,8 +9,10 @@
 if _TORCHVISION_AVAILABLE:
     from torchvision import transforms as transform_lib
     from torchvision.datasets import Cityscapes
+    from torchvision.transforms import Compose
 else:
     warn_missing_pkg('torchvision')  # pragma: no-cover
+    Compose = object
 
 
 class CityscapesDataModule(LightningDataModule):
@@ -192,7 +194,7 @@ def test_dataloader(self) -> DataLoader:
         )
         return loader
 
-    def _default_transforms(self):
+    def _default_transforms(self) -> Compose:
         cityscapes_transforms = transform_lib.Compose([
             transform_lib.ToTensor(),
             transform_lib.Normalize(
@@ -202,7 +204,7 @@ def _default_transforms(self):
         ])
         return cityscapes_transforms
 
-    def _default_target_transforms(self):
+    def _default_target_transforms(self) -> Compose:
         cityscapes_target_trasnforms = transform_lib.Compose([
             transform_lib.ToTensor(),
             transform_lib.Lambda(lambda t: t.squeeze())
diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py
index 833c4599a6..b37221bc74 100644
--- a/pl_bolts/datamodules/fashion_mnist_datamodule.py
+++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py
@@ -7,9 +7,11 @@
 if _TORCHVISION_AVAILABLE:
     from torchvision import transforms as transform_lib
     from torchvision.datasets import FashionMNIST
+    from torchvision.transforms import Compose
 else:
     warn_missing_pkg('torchvision')  # pragma: no-cover
     FashionMNIST = None
+    Compose = object
 
 
 class FashionMNISTDataModule(VisionDataModule):
@@ -93,7 +95,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self):
+    def default_transforms(self) -> Compose:
         if self.normalize:
             mnist_transforms = transform_lib.Compose(
                 [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py
index 6f06913f9f..db2fc68c0b 100644
--- a/pl_bolts/datamodules/imagenet_datamodule.py
+++ b/pl_bolts/datamodules/imagenet_datamodule.py
@@ -11,8 +11,10 @@
 
 if _TORCHVISION_AVAILABLE:
     from torchvision import transforms as transform_lib
+    from torchvision.transforms import Compose
 else:
     warn_missing_pkg('torchvision')  # pragma: no-cover
+    Compose = object
 
 
 class ImagenetDataModule(LightningDataModule):
@@ -206,7 +208,7 @@ def test_dataloader(self) -> DataLoader:
         )
         return loader
 
-    def train_transform(self):
+    def train_transform(self) -> Compose:
         """
         The standard imagenet transforms
 
@@ -232,7 +234,7 @@ def train_transform(self):
 
         return preprocessing
 
-    def val_transform(self):
+    def val_transform(self) -> Compose:
         """
         The standard imagenet transforms for validation
 
diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py
index 3cf26dc762..b63040f3bf 100644
--- a/pl_bolts/datamodules/kitti_datamodule.py
+++ b/pl_bolts/datamodules/kitti_datamodule.py
@@ -12,8 +12,10 @@
 
 if _TORCHVISION_AVAILABLE:
     import torchvision.transforms as transforms
+    from torchvision.transforms import Compose
 else:  # pragma: no cover
     warn_missing_pkg('torchvision')
+    Compose = object
 
 
 class KittiDataModule(LightningDataModule):
@@ -134,7 +136,7 @@ def test_dataloader(self) -> DataLoader:
         )
         return loader
 
-    def _default_transforms(self):
+    def _default_transforms(self) -> Compose:
         kitti_transforms = transforms.Compose([
             transforms.ToTensor(),
             transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753],
diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py
index 1dd5e927b6..711460023c 100644
--- a/pl_bolts/datamodules/mnist_datamodule.py
+++ b/pl_bolts/datamodules/mnist_datamodule.py
@@ -7,9 +7,11 @@
 if _TORCHVISION_AVAILABLE:
     from torchvision import transforms as transform_lib
     from torchvision.datasets import MNIST
+    from torchvision.transforms import Compose
 else:
     warn_missing_pkg('torchvision')  # pragma: no-cover
     MNIST = None
+    Compose = object
 
 
 class MNISTDataModule(VisionDataModule):
@@ -92,7 +94,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self):
+    def default_transforms(self) -> Compose:
         if self.normalize:
             mnist_transforms = transform_lib.Compose(
                 [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py
index 3dbda03527..d575eb2d01 100644
--- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py
+++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py
@@ -11,8 +11,10 @@
 
 if _TORCHVISION_AVAILABLE:
     from torchvision import transforms as transform_lib
+    from torchvision.transforms import Compose
 else:
     warn_missing_pkg('torchvision')  # pragma: no-cover
+    Compose = object
 
 
 class SSLImagenetDataModule(LightningDataModule):  # pragma: no cover
@@ -134,7 +136,7 @@ def test_dataloader(self, num_images_per_class: int, add_normalize: bool = False
         )
         return loader
 
-    def _default_transforms(self):
+    def _default_transforms(self) -> Compose:
         mnist_transforms = transform_lib.Compose([
             transform_lib.ToTensor(),
             imagenet_normalization()
diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py
index 79420be149..f9ac77e140 100644
--- a/pl_bolts/datamodules/stl10_datamodule.py
+++ b/pl_bolts/datamodules/stl10_datamodule.py
@@ -13,8 +13,10 @@
 if _TORCHVISION_AVAILABLE:
     from torchvision import transforms as transform_lib
     from torchvision.datasets import STL10
+    from torchvision.transforms import Compose
 else:
     warn_missing_pkg('torchvision')  # pragma: no-cover
+    Compose = object
 
 
 class STL10DataModule(LightningDataModule):  # pragma: no cover
@@ -299,7 +301,7 @@ def val_dataloader_labeled(self) -> DataLoader:
         )
         return loader
 
-    def _default_transforms(self):
+    def _default_transforms(self) -> Compose:
         data_transforms = transform_lib.Compose([
             transform_lib.ToTensor(),
             stl10_normalization()
diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py
index 2144f0f509..92e6723968 100644
--- a/pl_bolts/datamodules/vision_datamodule.py
+++ b/pl_bolts/datamodules/vision_datamodule.py
@@ -6,6 +6,13 @@
 from pytorch_lightning import LightningDataModule
 from torch.utils.data import DataLoader, Dataset, random_split
 
+from pl_bolts.utils import _TORCHVISION_AVAILABLE
+
+if _TORCHVISION_AVAILABLE:
+    from torchvision.transforms import Compose
+else:
+    Compose = object
+
 
 class VisionDataModule(LightningDataModule):
 
@@ -115,7 +122,7 @@ def _get_splits(self, len_dataset: int) -> List[int]:
         return splits
 
     @abstractmethod
-    def default_transforms(self):
+    def default_transforms(self) -> Compose:
         """ Default transform for the dataset """
 
     def train_dataloader(self) -> DataLoader:

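The pattern introduced in this patch keeps the `Compose` name bound whether or not torchvision is installed: the real class when it is, a bare `object` placeholder when it is not, so `-> Compose` annotations never raise `NameError` and the modules stay importable. A self-contained sketch of the same idea, using a try/except in place of pl_bolts' `_TORCHVISION_AVAILABLE` / `warn_missing_pkg` helpers:

    try:
        from torchvision import transforms as transform_lib
        from torchvision.transforms import Compose
        _TORCHVISION_AVAILABLE = True
    except ImportError:
        Compose = object  # placeholder so the annotation below still resolves
        _TORCHVISION_AVAILABLE = False

    class ExampleDataModule:
        def default_transforms(self) -> Compose:
            if not _TORCHVISION_AVAILABLE:
                raise ModuleNotFoundError('torchvision is required for default_transforms')
            return transform_lib.Compose([transform_lib.ToTensor()])

The trade-off is that the annotation degrades to `object` when torchvision is missing, which is acceptable here because those code paths raise before any transform is built.
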
From 0bc9f7b99646e73696bbe75e2311463a607231c5 Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Fri, 18 Dec 2020 19:58:16 +0900
Subject: [PATCH 30/75] Adding types to datamodules

---
 pl_bolts/datamodules/binary_mnist_datamodule.py  | 2 +-
 pl_bolts/datamodules/cifar10_datamodule.py       | 6 +++++-
 pl_bolts/datamodules/fashion_mnist_datamodule.py | 2 +-
 pl_bolts/datamodules/mnist_datamodule.py         | 2 +-
 pl_bolts/datamodules/sklearn_datamodule.py       | 2 +-
 pl_bolts/datamodules/vision_datamodule.py        | 6 +++---
 6 files changed, 12 insertions(+), 8 deletions(-)

diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py
index ad17360d08..28e44ce72c 100644
--- a/pl_bolts/datamodules/binary_mnist_datamodule.py
+++ b/pl_bolts/datamodules/binary_mnist_datamodule.py
@@ -58,7 +58,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ):
         """
         Args:
             data_dir: Where to save/load the data
diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py
index 9dbf10b670..c5e880bf1e 100644
--- a/pl_bolts/datamodules/cifar10_datamodule.py
+++ b/pl_bolts/datamodules/cifar10_datamodule.py
@@ -114,7 +114,11 @@ def num_classes(self) -> int:
         """
         return 10
 
+<<<<<<< HEAD
     def default_transforms(self) -> Compose:
+=======
+    def default_transforms(self) -> transform_lib.Compose:
+>>>>>>> Adding types to datamodules
         if self.normalize:
             cf10_transforms = transform_lib.Compose([transform_lib.ToTensor(), cifar10_normalization()])
         else:
@@ -155,7 +159,7 @@ def __init__(
         labels: Optional[Sequence] = (1, 5, 8),
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ):
         """
         Args:
             data_dir: where to save/load the data
diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py
index b37221bc74..97cb8de0cb 100644
--- a/pl_bolts/datamodules/fashion_mnist_datamodule.py
+++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py
@@ -59,7 +59,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ):
         """
         Args:
             data_dir: Where to save/load the data
diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py
index 711460023c..2ab81f6422 100644
--- a/pl_bolts/datamodules/mnist_datamodule.py
+++ b/pl_bolts/datamodules/mnist_datamodule.py
@@ -58,7 +58,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ):
         """
         Args:
             data_dir: Where to save/load the data
diff --git a/pl_bolts/datamodules/sklearn_datamodule.py b/pl_bolts/datamodules/sklearn_datamodule.py
index d9477acc0b..5c05ea3fce 100644
--- a/pl_bolts/datamodules/sklearn_datamodule.py
+++ b/pl_bolts/datamodules/sklearn_datamodule.py
@@ -204,7 +204,7 @@ def _init_datasets(
         x_val: np.ndarray,
         y_val: np.ndarray,
         x_test: np.ndarray,
-        y_test: np.ndarray
+<<<<<<< HEAD
     ) -> None:
         self.train_dataset = SklearnDataset(X, y)
         self.val_dataset = SklearnDataset(x_val, y_val)
diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py
index 92e6723968..5346713f5b 100644
--- a/pl_bolts/datamodules/vision_datamodule.py
+++ b/pl_bolts/datamodules/vision_datamodule.py
@@ -36,7 +36,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ):
         """
         Args:
             data_dir: Where to save/load the data
@@ -63,14 +63,14 @@ def __init__(
         self.pin_memory = pin_memory
         self.drop_last = drop_last
 
-    def prepare_data(self) -> None:
+    def prepare_data(self):
         """
         Saves files to data_dir
         """
         self.dataset_cls(self.data_dir, train=True, download=True)
         self.dataset_cls(self.data_dir, train=False, download=True)
 
-    def setup(self, stage: Optional[str] = None) -> None:
+    def setup(self, stage: Optional[str] = None):
         """
         Creates train, val, and test dataset
         """

From 05fcef2aebae6d90d64cc964e5c3a460d4f51a25 Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Fri, 18 Dec 2020 20:20:20 +0900
Subject: [PATCH 31/75] Fixing typing imports

---
 pl_bolts/datamodules/vocdetection_datamodule.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py
index 52d5065d97..2ebcad6cc5 100644
--- a/pl_bolts/datamodules/vocdetection_datamodule.py
+++ b/pl_bolts/datamodules/vocdetection_datamodule.py
@@ -203,7 +203,7 @@ def val_dataloader(self, batch_size: int = 1, transforms=None) -> DataLoader:
         )
         return loader
 
-    def _default_transforms(self):
+    def _default_transforms(self) -> T.Compose:
         if self.normalize:
             return (
                 lambda image, target: (

From d92604b10545309db49a81b2a4cbb1aff66654e6 Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Sun, 20 Dec 2020 03:03:00 +0900
Subject: [PATCH 32/75] Removing torchvision.transforms from return typing

---
 pl_bolts/datamodules/cifar10_datamodule.py      | 4 ----
 pl_bolts/datamodules/vocdetection_datamodule.py | 2 +-
 2 files changed, 1 insertion(+), 5 deletions(-)

diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py
index c5e880bf1e..cbe0333050 100644
--- a/pl_bolts/datamodules/cifar10_datamodule.py
+++ b/pl_bolts/datamodules/cifar10_datamodule.py
@@ -114,11 +114,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-<<<<<<< HEAD
     def default_transforms(self) -> Compose:
-=======
-    def default_transforms(self) -> transform_lib.Compose:
->>>>>>> Adding types to datamodules
         if self.normalize:
             cf10_transforms = transform_lib.Compose([transform_lib.ToTensor(), cifar10_normalization()])
         else:
diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py
index 2ebcad6cc5..52d5065d97 100644
--- a/pl_bolts/datamodules/vocdetection_datamodule.py
+++ b/pl_bolts/datamodules/vocdetection_datamodule.py
@@ -203,7 +203,7 @@ def val_dataloader(self, batch_size: int = 1, transforms=None) -> DataLoader:
         )
         return loader
 
-    def _default_transforms(self) -> T.Compose:
+    def _default_transforms(self):
         if self.normalize:
             return (
                 lambda image, target: (

From a9c641b817f6276720aa6abf11698d086dc8f339 Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Mon, 21 Dec 2020 10:13:11 +0900
Subject: [PATCH 33/75] Removing return typing

---
 pl_bolts/datamodules/cifar10_datamodule.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py
index cbe0333050..cdca1b61a6 100644
--- a/pl_bolts/datamodules/cifar10_datamodule.py
+++ b/pl_bolts/datamodules/cifar10_datamodule.py
@@ -73,7 +73,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ):
         """
         Args:
             data_dir: Where to save/load the data

From 314329c571d58ef2918b9da3776f5347bc4d7a92 Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Mon, 21 Dec 2020 16:51:36 +0900
Subject: [PATCH 34/75] Add `None` return type

---
 pl_bolts/datamodules/binary_mnist_datamodule.py  | 2 +-
 pl_bolts/datamodules/cifar10_datamodule.py       | 4 ++--
 pl_bolts/datamodules/fashion_mnist_datamodule.py | 2 +-
 pl_bolts/datamodules/mnist_datamodule.py         | 2 +-
 pl_bolts/datamodules/sklearn_datamodule.py       | 2 +-
 pl_bolts/datamodules/vision_datamodule.py        | 6 +++---
 6 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py
index 28e44ce72c..ad17360d08 100644
--- a/pl_bolts/datamodules/binary_mnist_datamodule.py
+++ b/pl_bolts/datamodules/binary_mnist_datamodule.py
@@ -58,7 +58,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: Where to save/load the data
diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py
index cdca1b61a6..9dbf10b670 100644
--- a/pl_bolts/datamodules/cifar10_datamodule.py
+++ b/pl_bolts/datamodules/cifar10_datamodule.py
@@ -73,7 +73,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: Where to save/load the data
@@ -155,7 +155,7 @@ def __init__(
         labels: Optional[Sequence] = (1, 5, 8),
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: where to save/load the data
diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py
index 97cb8de0cb..b37221bc74 100644
--- a/pl_bolts/datamodules/fashion_mnist_datamodule.py
+++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py
@@ -59,7 +59,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: Where to save/load the data
diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py
index 2ab81f6422..711460023c 100644
--- a/pl_bolts/datamodules/mnist_datamodule.py
+++ b/pl_bolts/datamodules/mnist_datamodule.py
@@ -58,7 +58,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: Where to save/load the data
diff --git a/pl_bolts/datamodules/sklearn_datamodule.py b/pl_bolts/datamodules/sklearn_datamodule.py
index 5c05ea3fce..d9477acc0b 100644
--- a/pl_bolts/datamodules/sklearn_datamodule.py
+++ b/pl_bolts/datamodules/sklearn_datamodule.py
@@ -204,7 +204,7 @@ def _init_datasets(
         x_val: np.ndarray,
         y_val: np.ndarray,
         x_test: np.ndarray,
-<<<<<<< HEAD
+        y_test: np.ndarray
     ) -> None:
         self.train_dataset = SklearnDataset(X, y)
         self.val_dataset = SklearnDataset(x_val, y_val)
diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py
index 5346713f5b..92e6723968 100644
--- a/pl_bolts/datamodules/vision_datamodule.py
+++ b/pl_bolts/datamodules/vision_datamodule.py
@@ -36,7 +36,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: Where to save/load the data
@@ -63,14 +63,14 @@ def __init__(
         self.pin_memory = pin_memory
         self.drop_last = drop_last
 
-    def prepare_data(self):
+    def prepare_data(self) -> None:
         """
         Saves files to data_dir
         """
         self.dataset_cls(self.data_dir, train=True, download=True)
         self.dataset_cls(self.data_dir, train=False, download=True)
 
-    def setup(self, stage: Optional[str] = None):
+    def setup(self, stage: Optional[str] = None) -> None:
         """
         Creates train, val, and test dataset
         """

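The plain `np.ndarray` hints restored here say nothing about dtype or shape. If stricter element types were wanted later, newer numpy releases (1.20+) expose `numpy.typing.NDArray`, which is generic over the scalar type; that is only an alternative worth noting, not something this series adopts:

    import numpy as np
    from numpy.typing import NDArray

    def keep_nonnegative(y: NDArray[np.int64]) -> NDArray[np.int64]:
        # same runtime behaviour as an np.ndarray hint, but mypy also sees the dtype
        return y[y >= 0]

    keep_nonnegative(np.array([1, -2, 3]))
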
From 14ea6b7da5ea2a15f44760cbfe6a131235ce2497 Mon Sep 17 00:00:00 2001
From: Jirka Borovec <jirka.borovec@seznam.cz>
Date: Tue, 5 Jan 2021 14:04:40 +0100
Subject: [PATCH 35/75] enable check

---
 setup.cfg | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/setup.cfg b/setup.cfg
index dcd35979f9..bda41d20f4 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -61,9 +61,6 @@ show_error_codes = True
 disallow_untyped_defs = True
 ignore_missing_imports = True
 
-[mypy-pl_bolts.datamodules.*]
-ignore_errors = True
-
 [mypy-pl_bolts.datasets.*]
 ignore_errors = True
 

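With the `[mypy-pl_bolts.datamodules.*]` override gone, the global `[mypy]` options in setup.cfg, including the `disallow_untyped_defs = True` shown as context in the hunk, now apply to the datamodules package, which is what makes all of the annotation work above enforceable. Roughly, the resulting configuration shape (keys taken from the hunk; the datasets override remains in place, and mypy reads these sections from setup.cfg when no dedicated mypy.ini is present):

    [mypy]
    disallow_untyped_defs = True
    ignore_missing_imports = True

    # datasets are still silenced until they get the same treatment
    [mypy-pl_bolts.datasets.*]
    ignore_errors = True
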
From 8a7c6f128a33dab0230734c30b0da7bfdb39e390 Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Fri, 18 Dec 2020 19:58:16 +0900
Subject: [PATCH 36/75] Adding types to datamodules

---
 .../datamodules/binary_mnist_datamodule.py    |  4 ++--
 pl_bolts/datamodules/cifar10_datamodule.py    |  6 ++---
 pl_bolts/datamodules/cityscapes_datamodule.py | 12 +++++-----
 .../datamodules/fashion_mnist_datamodule.py   |  4 ++--
 pl_bolts/datamodules/imagenet_datamodule.py   | 14 +++++------
 pl_bolts/datamodules/kitti_datamodule.py      |  8 +++----
 pl_bolts/datamodules/mnist_datamodule.py      |  4 ++--
 pl_bolts/datamodules/sklearn_datamodule.py    | 24 ++++++++++++-------
 .../datamodules/ssl_imagenet_datamodule.py    | 10 ++++----
 pl_bolts/datamodules/stl10_datamodule.py      | 16 ++++++-------
 pl_bolts/datamodules/vision_datamodule.py     | 16 +++++++++----
 11 files changed, 67 insertions(+), 51 deletions(-)

diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py
index 4dea946bf9..6e47040f90 100644
--- a/pl_bolts/datamodules/binary_mnist_datamodule.py
+++ b/pl_bolts/datamodules/binary_mnist_datamodule.py
@@ -56,7 +56,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ):
         """
         Args:
             data_dir: Where to save/load the data
@@ -98,7 +98,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self):
+    def default_transforms(self) -> transform_lib.Compose:
         if self.normalize:
             mnist_transforms = transform_lib.Compose([
                 transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5, ), std=(0.5, ))
diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py
index afb2df8c9a..12aea1ec87 100644
--- a/pl_bolts/datamodules/cifar10_datamodule.py
+++ b/pl_bolts/datamodules/cifar10_datamodule.py
@@ -112,7 +112,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self):
+    def default_transforms(self) -> transform_lib.Compose:
         if self.normalize:
             cf10_transforms = transform_lib.Compose([transform_lib.ToTensor(), cifar10_normalization()])
         else:
@@ -146,14 +146,14 @@ class TinyCIFAR10DataModule(CIFAR10DataModule):
 
     def __init__(
         self,
-        data_dir: str,
+        data_dir: Optional[str],
         val_split: int = 50,
         num_workers: int = 16,
         num_samples: int = 100,
         labels: Optional[Sequence] = (1, 5, 8),
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ):
         """
         Args:
             data_dir: where to save/load the data
diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py
index ba6acb947d..4e1bbe0699 100644
--- a/pl_bolts/datamodules/cityscapes_datamodule.py
+++ b/pl_bolts/datamodules/cityscapes_datamodule.py
@@ -109,14 +109,14 @@ def __init__(
         self.target_transforms = None
 
     @property
-    def num_classes(self):
+    def num_classes(self) -> int:
         """
         Return:
             30
         """
         return 30
 
-    def train_dataloader(self):
+    def train_dataloader(self) -> DataLoader:
         """
         Cityscapes train set
         """
@@ -143,7 +143,7 @@ def train_dataloader(self):
         )
         return loader
 
-    def val_dataloader(self):
+    def val_dataloader(self) -> DataLoader:
         """
         Cityscapes val set
         """
@@ -170,7 +170,7 @@ def val_dataloader(self):
         )
         return loader
 
-    def test_dataloader(self):
+    def test_dataloader(self) -> DataLoader:
         """
         Cityscapes test set
         """
@@ -196,7 +196,7 @@ def test_dataloader(self):
         )
         return loader
 
-    def _default_transforms(self):
+    def _default_transforms(self) -> transform_lib.Compose:
         cityscapes_transforms = transform_lib.Compose([
             transform_lib.ToTensor(),
             transform_lib.Normalize(
@@ -205,7 +205,7 @@ def _default_transforms(self):
         ])
         return cityscapes_transforms
 
-    def _default_target_transforms(self):
+    def _default_target_transforms(self) -> transform_lib.Compose:
         cityscapes_target_trasnforms = transform_lib.Compose([
             transform_lib.ToTensor(), transform_lib.Lambda(lambda t: t.squeeze())
         ])
diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py
index 8455d1f315..9e4022e218 100644
--- a/pl_bolts/datamodules/fashion_mnist_datamodule.py
+++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py
@@ -57,7 +57,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ):
         """
         Args:
             data_dir: Where to save/load the data
@@ -93,7 +93,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self):
+    def default_transforms(self) -> transform_lib.Compose:
         if self.normalize:
             mnist_transforms = transform_lib.Compose([
                 transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5, ), std=(0.5, ))
diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py
index 60b6c32578..4c3dc2c41f 100644
--- a/pl_bolts/datamodules/imagenet_datamodule.py
+++ b/pl_bolts/datamodules/imagenet_datamodule.py
@@ -94,7 +94,7 @@ def __init__(
         self.num_samples = 1281167 - self.num_imgs_per_val_class * self.num_classes
 
     @property
-    def num_classes(self):
+    def num_classes(self) -> int:
         """
         Return:
 
@@ -103,7 +103,7 @@ def num_classes(self):
         """
         return 1000
 
-    def _verify_splits(self, data_dir, split):
+    def _verify_splits(self, data_dir: str, split: str):
         dirs = os.listdir(data_dir)
 
         if split not in dirs:
@@ -142,7 +142,7 @@ def prepare_data(self):
                 """
                 )
 
-    def train_dataloader(self):
+    def train_dataloader(self) -> DataLoader:
         """
         Uses the train split of imagenet2012 and puts away a portion of it for the validation split
         """
@@ -166,7 +166,7 @@ def train_dataloader(self):
         )
         return loader
 
-    def val_dataloader(self):
+    def val_dataloader(self) -> DataLoader:
         """
         Uses the part of the train split of imagenet2012  that was not used for training via `num_imgs_per_val_class`
 
@@ -193,7 +193,7 @@ def val_dataloader(self):
         )
         return loader
 
-    def test_dataloader(self):
+    def test_dataloader(self) -> DataLoader:
         """
         Uses the validation split of imagenet2012 for testing
         """
@@ -212,7 +212,7 @@ def test_dataloader(self):
         )
         return loader
 
-    def train_transform(self):
+    def train_transform(self) -> transform_lib.Compose:
         """
         The standard imagenet transforms
 
@@ -238,7 +238,7 @@ def train_transform(self):
 
         return preprocessing
 
-    def val_transform(self):
+    def val_transform(self) -> transform_lib.Compose:
         """
         The standard imagenet transforms for validation
 
diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py
index e2cb6fa828..dec13b4514 100644
--- a/pl_bolts/datamodules/kitti_datamodule.py
+++ b/pl_bolts/datamodules/kitti_datamodule.py
@@ -97,7 +97,7 @@ def __init__(
             kitti_dataset, lengths=[train_len, val_len, test_len], generator=torch.Generator().manual_seed(self.seed)
         )
 
-    def train_dataloader(self):
+    def train_dataloader(self) -> DataLoader:
         loader = DataLoader(
             self.trainset,
             batch_size=self.batch_size,
@@ -108,7 +108,7 @@ def train_dataloader(self):
         )
         return loader
 
-    def val_dataloader(self):
+    def val_dataloader(self) -> DataLoader:
         loader = DataLoader(
             self.valset,
             batch_size=self.batch_size,
@@ -119,7 +119,7 @@ def val_dataloader(self):
         )
         return loader
 
-    def test_dataloader(self):
+    def test_dataloader(self) -> DataLoader:
         loader = DataLoader(
             self.testset,
             batch_size=self.batch_size,
@@ -130,7 +130,7 @@ def test_dataloader(self):
         )
         return loader
 
-    def _default_transforms(self):
+    def _default_transforms(self) -> transform_lib.Compose:
         kitti_transforms = transforms.Compose([
             transforms.ToTensor(),
             transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753], std=[0.32064945, 0.32098866, 0.32325324])
diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py
index b700b23123..87d1d72418 100644
--- a/pl_bolts/datamodules/mnist_datamodule.py
+++ b/pl_bolts/datamodules/mnist_datamodule.py
@@ -56,7 +56,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ):
         """
         Args:
             data_dir: Where to save/load the data
@@ -92,7 +92,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self):
+    def default_transforms(self) -> transform_lib.Compose:
         if self.normalize:
             mnist_transforms = transform_lib.Compose([
                 transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5, ), std=(0.5, ))
diff --git a/pl_bolts/datamodules/sklearn_datamodule.py b/pl_bolts/datamodules/sklearn_datamodule.py
index 00e333fd30..ef983a609f 100644
--- a/pl_bolts/datamodules/sklearn_datamodule.py
+++ b/pl_bolts/datamodules/sklearn_datamodule.py
@@ -43,10 +43,10 @@ def __init__(self, X: np.ndarray, y: np.ndarray, X_transform: Any = None, y_tran
         self.X_transform = X_transform
         self.y_transform = y_transform
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.X)
 
-    def __getitem__(self, idx):
+    def __getitem__(self, idx) -> Tuple[np.ndarray, np.ndarray]:
         x = self.X[idx].astype(np.float32)
         y = self.Y[idx]
 
@@ -91,10 +91,10 @@ def __init__(self, X: torch.Tensor, y: torch.Tensor, X_transform: Any = None, y_
         self.X_transform = X_transform
         self.y_transform = y_transform
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.X)
 
-    def __getitem__(self, idx):
+    def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor]:
         x = self.X[idx].float()
         y = self.Y[idx]
 
@@ -200,12 +200,20 @@ def __init__(
 
         self._init_datasets(X, y, x_val, y_val, x_test, y_test)
 
-    def _init_datasets(self, X, y, x_val, y_val, x_test, y_test):
+    def _init_datasets(
+        self,
+        X: np.ndarray,
+        y: np.ndarray,
+        x_val: np.ndarray,
+        y_val: np.ndarray,
+        x_test: np.ndarray,
+        y_test: np.ndarray
+    ):
         self.train_dataset = SklearnDataset(X, y)
         self.val_dataset = SklearnDataset(x_val, y_val)
         self.test_dataset = SklearnDataset(x_test, y_test)
 
-    def train_dataloader(self):
+    def train_dataloader(self) -> DataLoader:
         loader = DataLoader(
             self.train_dataset,
             batch_size=self.batch_size,
@@ -216,7 +224,7 @@ def train_dataloader(self):
         )
         return loader
 
-    def val_dataloader(self):
+    def val_dataloader(self) -> DataLoader:
         loader = DataLoader(
             self.val_dataset,
             batch_size=self.batch_size,
@@ -227,7 +235,7 @@ def val_dataloader(self):
         )
         return loader
 
-    def test_dataloader(self):
+    def test_dataloader(self) -> DataLoader:
         loader = DataLoader(
             self.test_dataset,
             batch_size=self.batch_size,
diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py
index 7949d218e5..4ede93041a 100644
--- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py
+++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py
@@ -46,10 +46,10 @@ def __init__(
         self.drop_last = drop_last
 
     @property
-    def num_classes(self):
+    def num_classes(self) -> int:
         return 1000
 
-    def _verify_splits(self, data_dir, split):
+    def _verify_splits(self, data_dir: str, split: str):
         dirs = os.listdir(data_dir)
 
         if split not in dirs:
@@ -83,7 +83,7 @@ def prepare_data(self):
                 """
                 )
 
-    def train_dataloader(self, num_images_per_class=-1, add_normalize=False):
+    def train_dataloader(self, num_images_per_class: int = -1, add_normalize: bool = False) -> DataLoader:
         transforms = self._default_transforms() if self.train_transforms is None else self.train_transforms
 
         dataset = UnlabeledImagenet(
@@ -103,7 +103,7 @@ def train_dataloader(self, num_images_per_class=-1, add_normalize=False):
         )
         return loader
 
-    def val_dataloader(self, num_images_per_class=50, add_normalize=False):
+    def val_dataloader(self, num_images_per_class: int = 50, add_normalize: bool = False) -> DataLoader:
         transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms
 
         dataset = UnlabeledImagenet(
@@ -123,7 +123,7 @@ def val_dataloader(self, num_images_per_class=50, add_normalize=False):
         )
         return loader
 
-    def test_dataloader(self, num_images_per_class, add_normalize=False):
+    def test_dataloader(self, num_images_per_class: int, add_normalize: bool = False) -> DataLoader:
         transforms = self._default_transforms() if self.test_transforms is None else self.test_transforms
 
         dataset = UnlabeledImagenet(
diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py
index 90f3434aa1..30411a5ea2 100644
--- a/pl_bolts/datamodules/stl10_datamodule.py
+++ b/pl_bolts/datamodules/stl10_datamodule.py
@@ -99,7 +99,7 @@ def __init__(
         self.num_unlabeled_samples = 100000 - unlabeled_val_split
 
     @property
-    def num_classes(self):
+    def num_classes(self) -> int:
         return 10
 
     def prepare_data(self):
@@ -110,7 +110,7 @@ def prepare_data(self):
         STL10(self.data_dir, split='train', download=True, transform=transform_lib.ToTensor())
         STL10(self.data_dir, split='test', download=True, transform=transform_lib.ToTensor())
 
-    def train_dataloader(self):
+    def train_dataloader(self) -> DataLoader:
         """
         Loads the 'unlabeled' split minus a portion set aside for validation via `unlabeled_val_split`.
         """
@@ -132,7 +132,7 @@ def train_dataloader(self):
         )
         return loader
 
-    def train_dataloader_mixed(self):
+    def train_dataloader_mixed(self) -> DataLoader:
         """
         Loads a portion of the 'unlabeled' training data and 'train' (labeled) data.
         both portions have a subset removed for validation via `unlabeled_val_split` and `train_val_split`
@@ -169,7 +169,7 @@ def train_dataloader_mixed(self):
         )
         return loader
 
-    def val_dataloader(self):
+    def val_dataloader(self) -> DataLoader:
         """
         Loads a portion of the 'unlabeled' training data set aside for validation
         The val dataset = (unlabeled - train_val_split)
@@ -197,7 +197,7 @@ def val_dataloader(self):
         )
         return loader
 
-    def val_dataloader_mixed(self):
+    def val_dataloader_mixed(self) -> DataLoader:
         """
         Loads a portion of the 'unlabeled' training data set aside for validation along with
         the portion of the 'train' dataset to be used for validation
@@ -239,7 +239,7 @@ def val_dataloader_mixed(self):
         )
         return loader
 
-    def test_dataloader(self):
+    def test_dataloader(self) -> DataLoader:
         """
         Loads the test split of STL10
 
@@ -260,7 +260,7 @@ def test_dataloader(self):
         )
         return loader
 
-    def train_dataloader_labeled(self):
+    def train_dataloader_labeled(self) -> DataLoader:
         transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms
 
         dataset = STL10(self.data_dir, split='train', download=False, transform=transforms)
@@ -279,7 +279,7 @@ def train_dataloader_labeled(self):
         )
         return loader
 
-    def val_dataloader_labeled(self):
+    def val_dataloader_labeled(self) -> DataLoader:
         transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms
         dataset = STL10(self.data_dir, split='train', download=False, transform=transforms)
         labeled_length = len(dataset)
diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py
index 5a6f4af4c2..7ab8f1cccb 100644
--- a/pl_bolts/datamodules/vision_datamodule.py
+++ b/pl_bolts/datamodules/vision_datamodule.py
@@ -6,6 +6,14 @@
 from pytorch_lightning import LightningDataModule
 from torch.utils.data import DataLoader, Dataset, random_split
 
+from pl_bolts.utils import _TORCHVISION_AVAILABLE
+from pl_bolts.utils.warnings import warn_missing_pkg
+
+if _TORCHVISION_AVAILABLE:
+    from torchvision import transforms as transform_lib
+else:
+    warn_missing_pkg('torchvision')  # pragma: no-cover
+
 
 class VisionDataModule(LightningDataModule):
 
@@ -29,7 +37,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ):
         """
         Args:
             data_dir: Where to save/load the data
@@ -56,14 +64,14 @@ def __init__(
         self.pin_memory = pin_memory
         self.drop_last = drop_last
 
-    def prepare_data(self) -> None:
+    def prepare_data(self):
         """
         Saves files to data_dir
         """
         self.dataset_cls(self.data_dir, train=True, download=True)
         self.dataset_cls(self.data_dir, train=False, download=True)
 
-    def setup(self, stage: Optional[str] = None) -> None:
+    def setup(self, stage: Optional[str] = None):
         """
         Creates train, val, and test dataset
         """
@@ -113,7 +121,7 @@ def _get_splits(self, len_dataset: int) -> List[int]:
         return splits
 
     @abstractmethod
-    def default_transforms(self):
+    def default_transforms(self) -> transform_lib.Compose:
         """ Default transform for the dataset """
 
     def train_dataloader(self) -> DataLoader:

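The `__len__` / `__getitem__` annotations in this patch pin down what one sample of `SklearnDataset` and `TensorDataset` looks like, and the dataloader methods now advertise `DataLoader` as their return type. A minimal, self-contained dataset in the same style; note that the `idx: int` parameter annotation is an addition for the example, while the patch leaves `idx` untyped:

    from typing import Tuple

    import numpy as np
    from torch.utils.data import DataLoader, Dataset

    class TypedNumpyDataset(Dataset):
        def __init__(self, X: np.ndarray, y: np.ndarray) -> None:
            self.X = X
            self.y = y

        def __len__(self) -> int:
            return len(self.X)

        def __getitem__(self, idx: int) -> Tuple[np.ndarray, np.ndarray]:
            # one (features, label) sample, cast to float32 as the original does
            return self.X[idx].astype(np.float32), self.y[idx]

    def train_dataloader(dataset: TypedNumpyDataset, batch_size: int = 4) -> DataLoader:
        return DataLoader(dataset, batch_size=batch_size, shuffle=True)

    ds = TypedNumpyDataset(np.random.rand(10, 3), np.arange(10).reshape(-1, 1))
    loader = train_dataloader(ds)
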
From 3b0ee3cf129f8a72e6b0cf41f41a6d1c9c19cba8 Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Fri, 18 Dec 2020 20:20:20 +0900
Subject: [PATCH 37/75] Fixing typing imports

---
 pl_bolts/datamodules/async_dataloader.py       | 14 +++++++++++---
 pl_bolts/datamodules/cityscapes_datamodule.py  |  2 ++
 pl_bolts/datamodules/imagenet_datamodule.py    |  2 +-
 pl_bolts/datamodules/kitti_datamodule.py       |  3 ++-
 pl_bolts/datamodules/sklearn_datamodule.py     |  2 +-
 .../datamodules/ssl_imagenet_datamodule.py     |  1 +
 pl_bolts/datamodules/stl10_datamodule.py       |  2 +-
 pl_bolts/datamodules/vision_datamodule.py      |  7 +------
 .../datamodules/vocdetection_datamodule.py     | 18 ++++++++++--------
 9 files changed, 30 insertions(+), 21 deletions(-)

diff --git a/pl_bolts/datamodules/async_dataloader.py b/pl_bolts/datamodules/async_dataloader.py
index 7ded9d9ef1..38a0b9bb58 100644
--- a/pl_bolts/datamodules/async_dataloader.py
+++ b/pl_bolts/datamodules/async_dataloader.py
@@ -1,10 +1,11 @@
 import re
 from queue import Queue
 from threading import Thread
+from typing import Any, Optional, Union
 
 import torch
 from torch._six import container_abcs, string_classes
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Dataset
 
 
 class AsynchronousLoader(object):
@@ -26,7 +27,14 @@ class AsynchronousLoader(object):
             constructing one here
     """
 
-    def __init__(self, data, device=torch.device('cuda', 0), q_size=10, num_batches=None, **kwargs):
+    def __init__(
+            self,
+            data: Union[DataLoader, Dataset],
+            device: torch.device = torch.device('cuda', 0),
+            q_size: int = 10,
+            num_batches: Optional[int] = None,
+            **kwargs: Any
+    ):
         if isinstance(data, torch.utils.data.DataLoader):
             self.dataloader = data
         else:
@@ -105,5 +113,5 @@ def __next__(self):
         self.idx += 1
         return out
 
-    def __len__(self):
+    def __len__(self) -> int:
         return self.num_batches
diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py
index 4e1bbe0699..ce5d4d52ee 100644
--- a/pl_bolts/datamodules/cityscapes_datamodule.py
+++ b/pl_bolts/datamodules/cityscapes_datamodule.py
@@ -1,3 +1,5 @@
+from typing import Any
+
 from pytorch_lightning import LightningDataModule
 from torch.utils.data import DataLoader
 
diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py
index 4c3dc2c41f..df4094d353 100644
--- a/pl_bolts/datamodules/imagenet_datamodule.py
+++ b/pl_bolts/datamodules/imagenet_datamodule.py
@@ -1,5 +1,5 @@
 import os
-from typing import Optional
+from typing import Any, Optional
 
 from pytorch_lightning import LightningDataModule
 from torch.utils.data import DataLoader
diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py
index dec13b4514..856f54a39e 100644
--- a/pl_bolts/datamodules/kitti_datamodule.py
+++ b/pl_bolts/datamodules/kitti_datamodule.py
@@ -1,4 +1,5 @@
 import os
+from typing import Any, Optional
 
 import torch
 from pytorch_lightning import LightningDataModule
@@ -130,7 +131,7 @@ def test_dataloader(self) -> DataLoader:
         )
         return loader
 
-    def _default_transforms(self) -> transform_lib.Compose:
+    def _default_transforms(self) -> transforms.Compose:
         kitti_transforms = transforms.Compose([
             transforms.ToTensor(),
             transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753], std=[0.32064945, 0.32098866, 0.32325324])
diff --git a/pl_bolts/datamodules/sklearn_datamodule.py b/pl_bolts/datamodules/sklearn_datamodule.py
index ef983a609f..a2ffca5ee8 100644
--- a/pl_bolts/datamodules/sklearn_datamodule.py
+++ b/pl_bolts/datamodules/sklearn_datamodule.py
@@ -1,5 +1,5 @@
 import math
-from typing import Any
+from typing import Any, Tuple
 
 import numpy as np
 import torch
diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py
index 4ede93041a..354cb4f02b 100644
--- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py
+++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py
@@ -1,4 +1,5 @@
 import os
+from typing import Any, Optional
 
 from pytorch_lightning import LightningDataModule
 from torch.utils.data import DataLoader
diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py
index 30411a5ea2..8f5ac120eb 100644
--- a/pl_bolts/datamodules/stl10_datamodule.py
+++ b/pl_bolts/datamodules/stl10_datamodule.py
@@ -1,5 +1,5 @@
 import os
-from typing import Optional
+from typing import Any, Optional
 
 import torch
 from pytorch_lightning import LightningDataModule
diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py
index 7ab8f1cccb..faac1663da 100644
--- a/pl_bolts/datamodules/vision_datamodule.py
+++ b/pl_bolts/datamodules/vision_datamodule.py
@@ -9,11 +9,6 @@
 from pl_bolts.utils import _TORCHVISION_AVAILABLE
 from pl_bolts.utils.warnings import warn_missing_pkg
 
-if _TORCHVISION_AVAILABLE:
-    from torchvision import transforms as transform_lib
-else:
-    warn_missing_pkg('torchvision')  # pragma: no-cover
-
 
 class VisionDataModule(LightningDataModule):
 
@@ -121,7 +116,7 @@ def _get_splits(self, len_dataset: int) -> List[int]:
         return splits
 
     @abstractmethod
-    def default_transforms(self) -> transform_lib.Compose:
+    def default_transforms(self):
         """ Default transform for the dataset """
 
     def train_dataloader(self) -> DataLoader:
diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py
index 6065dcf076..cb54d75d2e 100644
--- a/pl_bolts/datamodules/vocdetection_datamodule.py
+++ b/pl_bolts/datamodules/vocdetection_datamodule.py
@@ -1,3 +1,5 @@
+from typing import Any, Dict
+
 import torch
 from pytorch_lightning import LightningDataModule
 from torch.utils.data import DataLoader
@@ -17,7 +19,7 @@ class Compose(object):
     Like `torchvision.transforms.compose` but works for (image, target)
     """
 
-    def __init__(self, transforms):
+    def __init__(self, transforms: T.Compose):
         self.transforms = transforms
 
     def __call__(self, image, target):
@@ -55,7 +57,7 @@ def _collate_fn(batch):
 )
 
 
-def _prepare_voc_instance(image, target):
+def _prepare_voc_instance(image, target: Dict[str, Any]):
     """
     Prepares VOC dataset into appropriate target for fasterrcnn
 
@@ -113,8 +115,8 @@ def __init__(
         shuffle: bool = False,
         pin_memory: bool = False,
         drop_last: bool = False,
-        *args,
-        **kwargs,
+        *args: Any,
+        **kwargs: Any,
     ):
         if not _TORCHVISION_AVAILABLE:
             raise ModuleNotFoundError(  # pragma: no-cover
@@ -132,7 +134,7 @@ def __init__(
         self.drop_last = drop_last
 
     @property
-    def num_classes(self):
+    def num_classes(self) -> int:
         """
         Return:
             21
@@ -146,7 +148,7 @@ def prepare_data(self):
         VOCDetection(self.data_dir, year=self.year, image_set="train", download=True)
         VOCDetection(self.data_dir, year=self.year, image_set="val", download=True)
 
-    def train_dataloader(self, batch_size=1, transforms=None):
+    def train_dataloader(self, batch_size: int = 1, transforms=None) -> DataLoader:
         """
         VOCDetection train set uses the `train` subset
 
@@ -172,7 +174,7 @@ def train_dataloader(self, batch_size=1, transforms=None):
         )
         return loader
 
-    def val_dataloader(self, batch_size=1, transforms=None):
+    def val_dataloader(self, batch_size: int = 1, transforms=None) -> DataLoader:
         """
         VOCDetection val set uses the `val` subset
 
@@ -197,7 +199,7 @@ def val_dataloader(self, batch_size=1, transforms=None):
         )
         return loader
 
-    def _default_transforms(self):
+    def _default_transforms(self) -> T.Compose:
         if self.normalize:
             return (
                 lambda image, target: (

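For context, a minimal standalone sketch, not taken from the series above, of the annotation style these typing-import fixes enable: `Union`, `Optional`, and `Any` come from `typing`, and `DataLoader`/`Dataset` from `torch.utils.data`, so a constructor can accept either a prebuilt loader or a bare dataset. The class and parameter names below are illustrative only.

    from typing import Any, Optional, Union

    from torch.utils.data import DataLoader, Dataset


    class LoaderWrapper:
        """Illustrative wrapper that accepts a Dataset or an already-built DataLoader."""

        def __init__(
            self,
            data: Union[DataLoader, Dataset],
            batch_size: int = 32,
            num_batches: Optional[int] = None,
            **kwargs: Any,
        ) -> None:
            # build a DataLoader only when a bare Dataset is passed in
            if isinstance(data, DataLoader):
                self.dataloader = data
            else:
                self.dataloader = DataLoader(data, batch_size=batch_size, **kwargs)
            # default to iterating over the whole loader once
            self.num_batches = num_batches if num_batches is not None else len(self.dataloader)

        def __len__(self) -> int:
            return self.num_batches
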
From 3d1c9a11c23b8eb4bd727b7e0efc4e3a3347bc01 Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Sun, 20 Dec 2020 03:03:00 +0900
Subject: [PATCH 38/75] Removing torchvision.transforms from return typing

---
 pl_bolts/datamodules/binary_mnist_datamodule.py  | 2 +-
 pl_bolts/datamodules/cifar10_datamodule.py       | 2 +-
 pl_bolts/datamodules/cityscapes_datamodule.py    | 4 ++--
 pl_bolts/datamodules/fashion_mnist_datamodule.py | 2 +-
 pl_bolts/datamodules/imagenet_datamodule.py      | 4 ++--
 pl_bolts/datamodules/kitti_datamodule.py         | 2 +-
 pl_bolts/datamodules/mnist_datamodule.py         | 2 +-
 pl_bolts/datamodules/vocdetection_datamodule.py  | 2 +-
 8 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py
index 6e47040f90..85de4f0ef6 100644
--- a/pl_bolts/datamodules/binary_mnist_datamodule.py
+++ b/pl_bolts/datamodules/binary_mnist_datamodule.py
@@ -98,7 +98,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self) -> transform_lib.Compose:
+    def default_transforms(self):
         if self.normalize:
             mnist_transforms = transform_lib.Compose([
                 transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5, ), std=(0.5, ))
diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py
index 12aea1ec87..b208172ed0 100644
--- a/pl_bolts/datamodules/cifar10_datamodule.py
+++ b/pl_bolts/datamodules/cifar10_datamodule.py
@@ -112,7 +112,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self) -> transform_lib.Compose:
+    def default_transforms(self):
         if self.normalize:
             cf10_transforms = transform_lib.Compose([transform_lib.ToTensor(), cifar10_normalization()])
         else:
diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py
index ce5d4d52ee..a2d2aa7950 100644
--- a/pl_bolts/datamodules/cityscapes_datamodule.py
+++ b/pl_bolts/datamodules/cityscapes_datamodule.py
@@ -198,7 +198,7 @@ def test_dataloader(self) -> DataLoader:
         )
         return loader
 
-    def _default_transforms(self) -> transform_lib.Compose:
+    def _default_transforms(self):
         cityscapes_transforms = transform_lib.Compose([
             transform_lib.ToTensor(),
             transform_lib.Normalize(
@@ -207,7 +207,7 @@ def _default_transforms(self) -> transform_lib.Compose:
         ])
         return cityscapes_transforms
 
-    def _default_target_transforms(self) -> transform_lib.Compose:
+    def _default_target_transforms(self):
         cityscapes_target_trasnforms = transform_lib.Compose([
             transform_lib.ToTensor(), transform_lib.Lambda(lambda t: t.squeeze())
         ])
diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py
index 9e4022e218..c8fd7232f8 100644
--- a/pl_bolts/datamodules/fashion_mnist_datamodule.py
+++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py
@@ -93,7 +93,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self) -> transform_lib.Compose:
+    def default_transforms(self):
         if self.normalize:
             mnist_transforms = transform_lib.Compose([
                 transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5, ), std=(0.5, ))
diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py
index df4094d353..80b6e6976a 100644
--- a/pl_bolts/datamodules/imagenet_datamodule.py
+++ b/pl_bolts/datamodules/imagenet_datamodule.py
@@ -212,7 +212,7 @@ def test_dataloader(self) -> DataLoader:
         )
         return loader
 
-    def train_transform(self) -> transform_lib.Compose:
+    def train_transform(self):
         """
         The standard imagenet transforms
 
@@ -238,7 +238,7 @@ def train_transform(self) -> transform_lib.Compose:
 
         return preprocessing
 
-    def val_transform(self) -> transform_lib.Compose:
+    def val_transform(self):
         """
         The standard imagenet transforms for validation
 
diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py
index 856f54a39e..6a052bdf56 100644
--- a/pl_bolts/datamodules/kitti_datamodule.py
+++ b/pl_bolts/datamodules/kitti_datamodule.py
@@ -131,7 +131,7 @@ def test_dataloader(self) -> DataLoader:
         )
         return loader
 
-    def _default_transforms(self) -> transforms.Compose:
+    def _default_transforms(self):
         kitti_transforms = transforms.Compose([
             transforms.ToTensor(),
             transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753], std=[0.32064945, 0.32098866, 0.32325324])
diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py
index 87d1d72418..5c8388facc 100644
--- a/pl_bolts/datamodules/mnist_datamodule.py
+++ b/pl_bolts/datamodules/mnist_datamodule.py
@@ -92,7 +92,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self) -> transform_lib.Compose:
+    def default_transforms(self):
         if self.normalize:
             mnist_transforms = transform_lib.Compose([
                 transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5, ), std=(0.5, ))
diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py
index cb54d75d2e..4011b6226a 100644
--- a/pl_bolts/datamodules/vocdetection_datamodule.py
+++ b/pl_bolts/datamodules/vocdetection_datamodule.py
@@ -199,7 +199,7 @@ def val_dataloader(self, batch_size: int = 1, transforms=None) -> DataLoader:
         )
         return loader
 
-    def _default_transforms(self) -> T.Compose:
+    def _default_transforms(self):
         if self.normalize:
             return (
                 lambda image, target: (

From 6cac90995387107381f90b269b336016441a2730 Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Sun, 20 Dec 2020 03:09:03 +0900
Subject: [PATCH 39/75] Remove more torchvision.transforms typing

---
 pl_bolts/datamodules/vocdetection_datamodule.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py
index 4011b6226a..1bd6706650 100644
--- a/pl_bolts/datamodules/vocdetection_datamodule.py
+++ b/pl_bolts/datamodules/vocdetection_datamodule.py
@@ -19,7 +19,7 @@ class Compose(object):
     Like `torchvision.transforms.compose` but works for (image, target)
     """
 
-    def __init__(self, transforms: T.Compose):
+    def __init__(self, transforms):
         self.transforms = transforms
 
     def __call__(self, image, target):

From c1ea0fb2bc34f0c5b2a083e848533b978a268408 Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Mon, 21 Dec 2020 10:13:11 +0900
Subject: [PATCH 40/75] Removing return typing

---
 pl_bolts/datamodules/cifar10_datamodule.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py
index b208172ed0..85ba4de6e7 100644
--- a/pl_bolts/datamodules/cifar10_datamodule.py
+++ b/pl_bolts/datamodules/cifar10_datamodule.py
@@ -71,7 +71,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ):
         """
         Args:
             data_dir: Where to save/load the data

From c04caabd7b074a471dfb5608d3ddd51558482769 Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Mon, 21 Dec 2020 16:00:55 +0900
Subject: [PATCH 41/75] Add `None` for optional arguments

---
 pl_bolts/datamodules/cifar10_datamodule.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py
index 85ba4de6e7..534774684f 100644
--- a/pl_bolts/datamodules/cifar10_datamodule.py
+++ b/pl_bolts/datamodules/cifar10_datamodule.py
@@ -146,7 +146,7 @@ class TinyCIFAR10DataModule(CIFAR10DataModule):
 
     def __init__(
         self,
-        data_dir: Optional[str],
+        data_dir: Optional[str] = None,
         val_split: int = 50,
         num_workers: int = 16,
         num_samples: int = 100,

From 5b6bf64b319b08a2c0651958226027194985d164 Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Mon, 21 Dec 2020 16:04:36 +0900
Subject: [PATCH 42/75] Remove unnecessary import

---
 pl_bolts/datamodules/vision_datamodule.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py
index faac1663da..15648467e8 100644
--- a/pl_bolts/datamodules/vision_datamodule.py
+++ b/pl_bolts/datamodules/vision_datamodule.py
@@ -6,9 +6,6 @@
 from pytorch_lightning import LightningDataModule
 from torch.utils.data import DataLoader, Dataset, random_split
 
-from pl_bolts.utils import _TORCHVISION_AVAILABLE
-from pl_bolts.utils.warnings import warn_missing_pkg
-
 
 class VisionDataModule(LightningDataModule):
 

From 7ce736dfdb1fe0d5167e6f62234878d922f13251 Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Mon, 21 Dec 2020 16:51:36 +0900
Subject: [PATCH 43/75] Add `None` return type

---
 pl_bolts/datamodules/async_dataloader.py         |  4 ++--
 pl_bolts/datamodules/binary_mnist_datamodule.py  |  2 +-
 pl_bolts/datamodules/cifar10_datamodule.py       |  4 ++--
 pl_bolts/datamodules/cityscapes_datamodule.py    |  2 +-
 pl_bolts/datamodules/experience_source.py        |  4 ++--
 pl_bolts/datamodules/fashion_mnist_datamodule.py |  2 +-
 pl_bolts/datamodules/imagenet_datamodule.py      |  6 +++---
 pl_bolts/datamodules/kitti_datamodule.py         |  2 +-
 pl_bolts/datamodules/mnist_datamodule.py         |  2 +-
 pl_bolts/datamodules/sklearn_datamodule.py       | 14 ++++++++------
 pl_bolts/datamodules/ssl_imagenet_datamodule.py  |  6 +++---
 pl_bolts/datamodules/stl10_datamodule.py         |  4 ++--
 pl_bolts/datamodules/vision_datamodule.py        |  6 +++---
 pl_bolts/datamodules/vocdetection_datamodule.py  |  6 +++---
 14 files changed, 33 insertions(+), 31 deletions(-)

diff --git a/pl_bolts/datamodules/async_dataloader.py b/pl_bolts/datamodules/async_dataloader.py
index 38a0b9bb58..224f34d5ee 100644
--- a/pl_bolts/datamodules/async_dataloader.py
+++ b/pl_bolts/datamodules/async_dataloader.py
@@ -34,7 +34,7 @@ def __init__(
             q_size: int = 10,
             num_batches: Optional[int] = None,
             **kwargs: Any
-    ):
+    ) -> None:
         if isinstance(data, torch.utils.data.DataLoader):
             self.dataloader = data
         else:
@@ -57,7 +57,7 @@ def __init__(
 
         self.np_str_obj_array_pattern = re.compile(r'[SaUO]')
 
-    def load_loop(self):  # The loop that will load into the queue in the background
+    def load_loop(self) -> None:  # The loop that will load into the queue in the background
         for i, sample in enumerate(self.dataloader):
             self.queue.put(self.load_instance(sample))
             if i == len(self):
diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py
index 85de4f0ef6..4dea946bf9 100644
--- a/pl_bolts/datamodules/binary_mnist_datamodule.py
+++ b/pl_bolts/datamodules/binary_mnist_datamodule.py
@@ -56,7 +56,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: Where to save/load the data
diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py
index 534774684f..2cb894d749 100644
--- a/pl_bolts/datamodules/cifar10_datamodule.py
+++ b/pl_bolts/datamodules/cifar10_datamodule.py
@@ -71,7 +71,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: Where to save/load the data
@@ -153,7 +153,7 @@ def __init__(
         labels: Optional[Sequence] = (1, 5, 8),
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: where to save/load the data
diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py
index a2d2aa7950..7816236a27 100644
--- a/pl_bolts/datamodules/cityscapes_datamodule.py
+++ b/pl_bolts/datamodules/cityscapes_datamodule.py
@@ -73,7 +73,7 @@ def __init__(
         drop_last: bool = False,
         *args,
         **kwargs,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: where to load the data from path, i.e. where directory leftImg8bit and gtFine or gtCoarse
diff --git a/pl_bolts/datamodules/experience_source.py b/pl_bolts/datamodules/experience_source.py
index 6c85f76fd2..1bc7b0f8a8 100644
--- a/pl_bolts/datamodules/experience_source.py
+++ b/pl_bolts/datamodules/experience_source.py
@@ -27,7 +27,7 @@ class ExperienceSourceDataset(IterableDataset):
     The logic for the experience source and how the batch is generated is defined the Lightning model itself
     """
 
-    def __init__(self, generate_batch: Callable):
+    def __init__(self, generate_batch: Callable) -> None:
         self.generate_batch = generate_batch
 
     def __iter__(self) -> Iterable:
@@ -240,7 +240,7 @@ def pop_rewards_steps(self):
 class DiscountedExperienceSource(ExperienceSource):
     """Outputs experiences with a discounted reward over N steps"""
 
-    def __init__(self, env: Env, agent, n_steps: int = 1, gamma: float = 0.99):
+    def __init__(self, env: Env, agent, n_steps: int = 1, gamma: float = 0.99) -> None:
         super().__init__(env, agent, (n_steps + 1))
         self.gamma = gamma
         self.steps = n_steps
diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py
index c8fd7232f8..8455d1f315 100644
--- a/pl_bolts/datamodules/fashion_mnist_datamodule.py
+++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py
@@ -57,7 +57,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: Where to save/load the data
diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py
index 80b6e6976a..61ff477e82 100644
--- a/pl_bolts/datamodules/imagenet_datamodule.py
+++ b/pl_bolts/datamodules/imagenet_datamodule.py
@@ -60,7 +60,7 @@ def __init__(
         drop_last: bool = False,
         *args,
         **kwargs,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: path to the imagenet dataset file
@@ -103,7 +103,7 @@ def num_classes(self) -> int:
         """
         return 1000
 
-    def _verify_splits(self, data_dir: str, split: str):
+    def _verify_splits(self, data_dir: str, split: str) -> None:
         dirs = os.listdir(data_dir)
 
         if split not in dirs:
@@ -112,7 +112,7 @@ def _verify_splits(self, data_dir: str, split: str):
                 f' make sure the folder contains a subfolder named {split}'
             )
 
-    def prepare_data(self):
+    def prepare_data(self) -> None:
         """
         This method already assumes you have imagenet2012 downloaded.
         It validates the data using the meta.bin.
diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py
index 6a052bdf56..ee66d5c1dc 100644
--- a/pl_bolts/datamodules/kitti_datamodule.py
+++ b/pl_bolts/datamodules/kitti_datamodule.py
@@ -33,7 +33,7 @@ def __init__(
         drop_last: bool = False,
         *args,
         **kwargs,
-    ):
+    ) -> None:
         """
         Kitti train, validation and test dataloaders.
 
diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py
index 5c8388facc..b700b23123 100644
--- a/pl_bolts/datamodules/mnist_datamodule.py
+++ b/pl_bolts/datamodules/mnist_datamodule.py
@@ -56,7 +56,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: Where to save/load the data
diff --git a/pl_bolts/datamodules/sklearn_datamodule.py b/pl_bolts/datamodules/sklearn_datamodule.py
index a2ffca5ee8..be517f47bc 100644
--- a/pl_bolts/datamodules/sklearn_datamodule.py
+++ b/pl_bolts/datamodules/sklearn_datamodule.py
@@ -28,8 +28,9 @@ class SklearnDataset(Dataset):
         >>> len(dataset)
         506
     """
-
-    def __init__(self, X: np.ndarray, y: np.ndarray, X_transform: Any = None, y_transform: Any = None):
+    def __init__(
+        self, X: np.ndarray, y: np.ndarray, X_transform: Any = None, y_transform: Any = None
+    ) -> None:
         """
         Args:
             X: Numpy ndarray
@@ -76,8 +77,9 @@ class TensorDataset(Dataset):
         >>> len(dataset)
         10
     """
-
-    def __init__(self, X: torch.Tensor, y: torch.Tensor, X_transform: Any = None, y_transform: Any = None):
+    def __init__(
+        self, X: torch.Tensor, y: torch.Tensor, X_transform: Any = None, y_transform: Any = None
+    ) -> None:
         """
         Args:
             X: PyTorch tensor
@@ -160,7 +162,7 @@ def __init__(
         drop_last=False,
         *args,
         **kwargs,
-    ):
+    ) -> None:
 
         super().__init__(*args, **kwargs)
         self.num_workers = num_workers
@@ -208,7 +210,7 @@ def _init_datasets(
         y_val: np.ndarray,
         x_test: np.ndarray,
         y_test: np.ndarray
-    ):
+    ) -> None:
         self.train_dataset = SklearnDataset(X, y)
         self.val_dataset = SklearnDataset(x_val, y_val)
         self.test_dataset = SklearnDataset(x_test, y_test)
diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py
index 354cb4f02b..14656280ac 100644
--- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py
+++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py
@@ -30,7 +30,7 @@ def __init__(
         drop_last: bool = False,
         *args,
         **kwargs,
-    ):
+    ) -> None:
         super().__init__(*args, **kwargs)
 
         if not _TORCHVISION_AVAILABLE:
@@ -50,7 +50,7 @@ def __init__(
     def num_classes(self) -> int:
         return 1000
 
-    def _verify_splits(self, data_dir: str, split: str):
+    def _verify_splits(self, data_dir: str, split: str) -> None:
         dirs = os.listdir(data_dir)
 
         if split not in dirs:
@@ -59,7 +59,7 @@ def _verify_splits(self, data_dir: str, split: str):
                 f' folder contains a subfolder named {split}'
             )
 
-    def prepare_data(self):
+    def prepare_data(self) -> None:
         # imagenet cannot be downloaded... must provide path to folder with the train/val splits
         self._verify_splits(self.data_dir, 'train')
         self._verify_splits(self.data_dir, 'val')
diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py
index 8f5ac120eb..5bf9a9b084 100644
--- a/pl_bolts/datamodules/stl10_datamodule.py
+++ b/pl_bolts/datamodules/stl10_datamodule.py
@@ -65,7 +65,7 @@ def __init__(
         drop_last: bool = False,
         *args,
         **kwargs,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: where to save/load the data
@@ -102,7 +102,7 @@ def __init__(
     def num_classes(self) -> int:
         return 10
 
-    def prepare_data(self):
+    def prepare_data(self) -> None:
         """
         Downloads the unlabeled, train and test split
         """
diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py
index 15648467e8..5a6f4af4c2 100644
--- a/pl_bolts/datamodules/vision_datamodule.py
+++ b/pl_bolts/datamodules/vision_datamodule.py
@@ -29,7 +29,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: Where to save/load the data
@@ -56,14 +56,14 @@ def __init__(
         self.pin_memory = pin_memory
         self.drop_last = drop_last
 
-    def prepare_data(self):
+    def prepare_data(self) -> None:
         """
         Saves files to data_dir
         """
         self.dataset_cls(self.data_dir, train=True, download=True)
         self.dataset_cls(self.data_dir, train=False, download=True)
 
-    def setup(self, stage: Optional[str] = None):
+    def setup(self, stage: Optional[str] = None) -> None:
         """
         Creates train, val, and test dataset
         """
diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py
index 1bd6706650..be91eac99b 100644
--- a/pl_bolts/datamodules/vocdetection_datamodule.py
+++ b/pl_bolts/datamodules/vocdetection_datamodule.py
@@ -19,7 +19,7 @@ class Compose(object):
     Like `torchvision.transforms.compose` but works for (image, target)
     """
 
-    def __init__(self, transforms):
+    def __init__(self, transforms) -> None:
         self.transforms = transforms
 
     def __call__(self, image, target):
@@ -117,7 +117,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         if not _TORCHVISION_AVAILABLE:
             raise ModuleNotFoundError(  # pragma: no-cover
                 'You want to use VOC dataset loaded from `torchvision` which is not installed yet.'
@@ -141,7 +141,7 @@ def num_classes(self) -> int:
         """
         return 21
 
-    def prepare_data(self):
+    def prepare_data(self) -> None:
         """
         Saves VOCDetection files to data_dir
         """

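As a reference point, here is a minimal sketch, separate from the files touched above, of the convention this patch applies: methods that return nothing, including `__init__`, `prepare_data`, and `setup`, get an explicit `-> None` so they count as fully typed under mypy's `disallow_untyped_defs`. The datamodule name and method bodies are illustrative only.

    from typing import Any, Optional

    from pytorch_lightning import LightningDataModule


    class ToyDataModule(LightningDataModule):
        """Illustrative datamodule with explicit `-> None` annotations."""

        def __init__(self, data_dir: str = '.', *args: Any, **kwargs: Any) -> None:
            super().__init__(*args, **kwargs)
            self.data_dir = data_dir

        def prepare_data(self) -> None:
            # download or write files under self.data_dir; nothing is returned
            ...

        def setup(self, stage: Optional[str] = None) -> None:
            # assign self.train_dataset / self.val_dataset / self.test_dataset here
            ...
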
From 7309adefb9b92aa1122d187a36e951e94681cf5d Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Tue, 5 Jan 2021 20:30:43 +0900
Subject: [PATCH 44/75] Add type for torchvision transforms

---
 pl_bolts/datamodules/binary_mnist_datamodule.py  | 4 +++-
 pl_bolts/datamodules/cifar10_datamodule.py       | 4 +++-
 pl_bolts/datamodules/cityscapes_datamodule.py    | 6 ++++--
 pl_bolts/datamodules/fashion_mnist_datamodule.py | 4 +++-
 pl_bolts/datamodules/imagenet_datamodule.py      | 6 ++++--
 pl_bolts/datamodules/kitti_datamodule.py         | 4 +++-
 pl_bolts/datamodules/mnist_datamodule.py         | 4 +++-
 pl_bolts/datamodules/ssl_imagenet_datamodule.py  | 9 +++++++--
 pl_bolts/datamodules/stl10_datamodule.py         | 9 +++++++--
 pl_bolts/datamodules/vision_datamodule.py        | 9 ++++++++-
 10 files changed, 45 insertions(+), 14 deletions(-)

diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py
index 4dea946bf9..cdd07ff1e2 100644
--- a/pl_bolts/datamodules/binary_mnist_datamodule.py
+++ b/pl_bolts/datamodules/binary_mnist_datamodule.py
@@ -7,8 +7,10 @@
 
 if _TORCHVISION_AVAILABLE:
     from torchvision import transforms as transform_lib
+    from torchvision.transforms import Compose
 else:  # pragma: no-cover
     warn_missing_pkg('torchvision')
+    Compose = object
 
 
 class BinaryMNISTDataModule(VisionDataModule):
@@ -98,7 +100,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self):
+    def default_transforms(self) -> Compose:
         if self.normalize:
             mnist_transforms = transform_lib.Compose([
                 transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5, ), std=(0.5, ))
diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py
index 2cb894d749..9dbf10b670 100644
--- a/pl_bolts/datamodules/cifar10_datamodule.py
+++ b/pl_bolts/datamodules/cifar10_datamodule.py
@@ -9,9 +9,11 @@
 if _TORCHVISION_AVAILABLE:
     from torchvision import transforms as transform_lib
     from torchvision.datasets import CIFAR10
+    from torchvision.transforms import Compose
 else:
     warn_missing_pkg('torchvision')  # pragma: no-cover
     CIFAR10 = None
+    Compose = object
 
 
 class CIFAR10DataModule(VisionDataModule):
@@ -112,7 +114,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self):
+    def default_transforms(self) -> Compose:
         if self.normalize:
             cf10_transforms = transform_lib.Compose([transform_lib.ToTensor(), cifar10_normalization()])
         else:
diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py
index 7816236a27..3721f897af 100644
--- a/pl_bolts/datamodules/cityscapes_datamodule.py
+++ b/pl_bolts/datamodules/cityscapes_datamodule.py
@@ -9,8 +9,10 @@
 if _TORCHVISION_AVAILABLE:
     from torchvision import transforms as transform_lib
     from torchvision.datasets import Cityscapes
+    from torchvision.transforms import Compose
 else:
     warn_missing_pkg('torchvision')  # pragma: no-cover
+    Compose = object
 
 
 class CityscapesDataModule(LightningDataModule):
@@ -198,7 +200,7 @@ def test_dataloader(self) -> DataLoader:
         )
         return loader
 
-    def _default_transforms(self):
+    def _default_transforms(self) -> Compose:
         cityscapes_transforms = transform_lib.Compose([
             transform_lib.ToTensor(),
             transform_lib.Normalize(
@@ -207,7 +209,7 @@ def _default_transforms(self):
         ])
         return cityscapes_transforms
 
-    def _default_target_transforms(self):
+    def _default_target_transforms(self) -> Compose:
         cityscapes_target_trasnforms = transform_lib.Compose([
             transform_lib.ToTensor(), transform_lib.Lambda(lambda t: t.squeeze())
         ])
diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py
index 8455d1f315..b31a5aa792 100644
--- a/pl_bolts/datamodules/fashion_mnist_datamodule.py
+++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py
@@ -7,9 +7,11 @@
 if _TORCHVISION_AVAILABLE:
     from torchvision import transforms as transform_lib
     from torchvision.datasets import FashionMNIST
+    from torchvision.transforms import Compose
 else:
     warn_missing_pkg('torchvision')  # pragma: no-cover
     FashionMNIST = None
+    Compose = object
 
 
 class FashionMNISTDataModule(VisionDataModule):
@@ -93,7 +95,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self):
+    def default_transforms(self) -> Compose:
         if self.normalize:
             mnist_transforms = transform_lib.Compose([
                 transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5, ), std=(0.5, ))
diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py
index 61ff477e82..b63611b060 100644
--- a/pl_bolts/datamodules/imagenet_datamodule.py
+++ b/pl_bolts/datamodules/imagenet_datamodule.py
@@ -11,8 +11,10 @@
 
 if _TORCHVISION_AVAILABLE:
     from torchvision import transforms as transform_lib
+    from torchvision.transforms import Compose
 else:
     warn_missing_pkg('torchvision')  # pragma: no-cover
+    Compose = object
 
 
 class ImagenetDataModule(LightningDataModule):
@@ -212,7 +214,7 @@ def test_dataloader(self) -> DataLoader:
         )
         return loader
 
-    def train_transform(self):
+    def train_transform(self) -> Compose:
         """
         The standard imagenet transforms
 
@@ -238,7 +240,7 @@ def train_transform(self):
 
         return preprocessing
 
-    def val_transform(self):
+    def val_transform(self) -> Compose:
         """
         The standard imagenet transforms for validation
 
diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py
index ee66d5c1dc..cab7529e91 100644
--- a/pl_bolts/datamodules/kitti_datamodule.py
+++ b/pl_bolts/datamodules/kitti_datamodule.py
@@ -12,8 +12,10 @@
 
 if _TORCHVISION_AVAILABLE:
     import torchvision.transforms as transforms
+    from torchvision.transforms import Compose
 else:  # pragma: no cover
     warn_missing_pkg('torchvision')
+    Compose = object
 
 
 class KittiDataModule(LightningDataModule):
@@ -131,7 +133,7 @@ def test_dataloader(self) -> DataLoader:
         )
         return loader
 
-    def _default_transforms(self):
+    def _default_transforms(self) -> Compose:
         kitti_transforms = transforms.Compose([
             transforms.ToTensor(),
             transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753], std=[0.32064945, 0.32098866, 0.32325324])
diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py
index b700b23123..d52315f41c 100644
--- a/pl_bolts/datamodules/mnist_datamodule.py
+++ b/pl_bolts/datamodules/mnist_datamodule.py
@@ -7,9 +7,11 @@
 if _TORCHVISION_AVAILABLE:
     from torchvision import transforms as transform_lib
     from torchvision.datasets import MNIST
+    from torchvision.transforms import Compose
 else:
     warn_missing_pkg('torchvision')  # pragma: no-cover
     MNIST = None
+    Compose = object
 
 
 class MNISTDataModule(VisionDataModule):
@@ -92,7 +94,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self):
+    def default_transforms(self) -> Compose:
         if self.normalize:
             mnist_transforms = transform_lib.Compose([
                 transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5, ), std=(0.5, ))
diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py
index 14656280ac..6ee430428b 100644
--- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py
+++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py
@@ -11,8 +11,10 @@
 
 if _TORCHVISION_AVAILABLE:
     from torchvision import transforms as transform_lib
+    from torchvision.transforms import Compose
 else:
     warn_missing_pkg('torchvision')  # pragma: no-cover
+    Compose = object
 
 
 class SSLImagenetDataModule(LightningDataModule):  # pragma: no cover
@@ -144,6 +146,9 @@ def test_dataloader(self, num_images_per_class: int, add_normalize: bool = False
         )
         return loader
 
-    def _default_transforms(self):
-        mnist_transforms = transform_lib.Compose([transform_lib.ToTensor(), imagenet_normalization()])
+    def _default_transforms(self) -> Compose:
+        mnist_transforms = transform_lib.Compose([
+            transform_lib.ToTensor(),
+            imagenet_normalization()
+        ])
         return mnist_transforms
diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py
index 5bf9a9b084..45c0c040ea 100644
--- a/pl_bolts/datamodules/stl10_datamodule.py
+++ b/pl_bolts/datamodules/stl10_datamodule.py
@@ -13,8 +13,10 @@
 if _TORCHVISION_AVAILABLE:
     from torchvision import transforms as transform_lib
     from torchvision.datasets import STL10
+    from torchvision.transforms import Compose
 else:
     warn_missing_pkg('torchvision')  # pragma: no-cover
+    Compose = object
 
 
 class STL10DataModule(LightningDataModule):  # pragma: no cover
@@ -298,6 +300,9 @@ def val_dataloader_labeled(self) -> DataLoader:
         )
         return loader
 
-    def _default_transforms(self):
-        data_transforms = transform_lib.Compose([transform_lib.ToTensor(), stl10_normalization()])
+    def _default_transforms(self) -> Compose:
+        data_transforms = transform_lib.Compose([
+            transform_lib.ToTensor(),
+            stl10_normalization()
+        ])
         return data_transforms
diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py
index 5a6f4af4c2..bab4c722dc 100644
--- a/pl_bolts/datamodules/vision_datamodule.py
+++ b/pl_bolts/datamodules/vision_datamodule.py
@@ -6,6 +6,13 @@
 from pytorch_lightning import LightningDataModule
 from torch.utils.data import DataLoader, Dataset, random_split
 
+from pl_bolts.utils import _TORCHVISION_AVAILABLE
+
+if _TORCHVISION_AVAILABLE:
+    from torchvision.transforms import Compose
+else:
+    Compose = object
+
 
 class VisionDataModule(LightningDataModule):
 
@@ -113,7 +120,7 @@ def _get_splits(self, len_dataset: int) -> List[int]:
         return splits
 
     @abstractmethod
-    def default_transforms(self):
+    def default_transforms(self) -> Compose:
         """ Default transform for the dataset """
 
     def train_dataloader(self) -> DataLoader:

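For reference, a condensed standalone sketch of the guarded-import pattern these hunks add to each datamodule: when torchvision is available, `Compose` is imported for use in return annotations; otherwise `Compose` falls back to `object` so the annotations still evaluate without the optional dependency. `_TORCHVISION_AVAILABLE` and `warn_missing_pkg` are the pl_bolts helpers already used above; the module-level function is illustrative only.

    from pl_bolts.utils import _TORCHVISION_AVAILABLE
    from pl_bolts.utils.warnings import warn_missing_pkg

    if _TORCHVISION_AVAILABLE:
        from torchvision import transforms as transform_lib
        from torchvision.transforms import Compose
    else:  # pragma: no-cover
        warn_missing_pkg('torchvision')
        Compose = object  # typing fallback; keeps `-> Compose` valid without torchvision


    def default_transforms() -> Compose:
        # illustrative pipeline; only callable when torchvision is installed
        return transform_lib.Compose([transform_lib.ToTensor()])
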
From cc154a7242215fd2feb0907727f856d8d59a2067 Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Fri, 18 Dec 2020 19:58:16 +0900
Subject: [PATCH 45/75] Adding types to datamodules

---
 pl_bolts/datamodules/binary_mnist_datamodule.py  | 2 +-
 pl_bolts/datamodules/cifar10_datamodule.py       | 6 +++++-
 pl_bolts/datamodules/fashion_mnist_datamodule.py | 2 +-
 pl_bolts/datamodules/mnist_datamodule.py         | 2 +-
 pl_bolts/datamodules/sklearn_datamodule.py       | 2 +-
 pl_bolts/datamodules/vision_datamodule.py        | 6 +++---
 6 files changed, 12 insertions(+), 8 deletions(-)

diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py
index cdd07ff1e2..88ff0b359d 100644
--- a/pl_bolts/datamodules/binary_mnist_datamodule.py
+++ b/pl_bolts/datamodules/binary_mnist_datamodule.py
@@ -58,7 +58,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ):
         """
         Args:
             data_dir: Where to save/load the data
diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py
index 9dbf10b670..c5e880bf1e 100644
--- a/pl_bolts/datamodules/cifar10_datamodule.py
+++ b/pl_bolts/datamodules/cifar10_datamodule.py
@@ -114,7 +114,11 @@ def num_classes(self) -> int:
         """
         return 10
 
+<<<<<<< HEAD
     def default_transforms(self) -> Compose:
+=======
+    def default_transforms(self) -> transform_lib.Compose:
+>>>>>>> Adding types to datamodules
         if self.normalize:
             cf10_transforms = transform_lib.Compose([transform_lib.ToTensor(), cifar10_normalization()])
         else:
@@ -155,7 +159,7 @@ def __init__(
         labels: Optional[Sequence] = (1, 5, 8),
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ):
         """
         Args:
             data_dir: where to save/load the data
diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py
index b31a5aa792..8de6c99bd7 100644
--- a/pl_bolts/datamodules/fashion_mnist_datamodule.py
+++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py
@@ -59,7 +59,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ):
         """
         Args:
             data_dir: Where to save/load the data
diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py
index d52315f41c..5013b8b6b5 100644
--- a/pl_bolts/datamodules/mnist_datamodule.py
+++ b/pl_bolts/datamodules/mnist_datamodule.py
@@ -58,7 +58,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ):
         """
         Args:
             data_dir: Where to save/load the data
diff --git a/pl_bolts/datamodules/sklearn_datamodule.py b/pl_bolts/datamodules/sklearn_datamodule.py
index be517f47bc..ae8ad70d52 100644
--- a/pl_bolts/datamodules/sklearn_datamodule.py
+++ b/pl_bolts/datamodules/sklearn_datamodule.py
@@ -209,7 +209,7 @@ def _init_datasets(
         x_val: np.ndarray,
         y_val: np.ndarray,
         x_test: np.ndarray,
-        y_test: np.ndarray
+<<<<<<< HEAD
     ) -> None:
         self.train_dataset = SklearnDataset(X, y)
         self.val_dataset = SklearnDataset(x_val, y_val)
diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py
index bab4c722dc..ba95c26a20 100644
--- a/pl_bolts/datamodules/vision_datamodule.py
+++ b/pl_bolts/datamodules/vision_datamodule.py
@@ -36,7 +36,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ):
         """
         Args:
             data_dir: Where to save/load the data
@@ -63,14 +63,14 @@ def __init__(
         self.pin_memory = pin_memory
         self.drop_last = drop_last
 
-    def prepare_data(self) -> None:
+    def prepare_data(self):
         """
         Saves files to data_dir
         """
         self.dataset_cls(self.data_dir, train=True, download=True)
         self.dataset_cls(self.data_dir, train=False, download=True)
 
-    def setup(self, stage: Optional[str] = None) -> None:
+    def setup(self, stage: Optional[str] = None):
         """
         Creates train, val, and test dataset
         """

From cff73307ca58a4df401a982475ae5534daf7a682 Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Fri, 18 Dec 2020 20:20:20 +0900
Subject: [PATCH 46/75] Fixing typing imports

---
 pl_bolts/datamodules/vocdetection_datamodule.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py
index be91eac99b..9c6e4b501b 100644
--- a/pl_bolts/datamodules/vocdetection_datamodule.py
+++ b/pl_bolts/datamodules/vocdetection_datamodule.py
@@ -199,7 +199,7 @@ def val_dataloader(self, batch_size: int = 1, transforms=None) -> DataLoader:
         )
         return loader
 
-    def _default_transforms(self):
+    def _default_transforms(self) -> T.Compose:
         if self.normalize:
             return (
                 lambda image, target: (

From 3bbc1894237bcd616c00882d0f19a9c607fe93d0 Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Sun, 20 Dec 2020 03:03:00 +0900
Subject: [PATCH 47/75] Removing torchvision.transforms from return typing

---
 pl_bolts/datamodules/cifar10_datamodule.py      | 4 ----
 pl_bolts/datamodules/vocdetection_datamodule.py | 2 +-
 2 files changed, 1 insertion(+), 5 deletions(-)

diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py
index c5e880bf1e..cbe0333050 100644
--- a/pl_bolts/datamodules/cifar10_datamodule.py
+++ b/pl_bolts/datamodules/cifar10_datamodule.py
@@ -114,11 +114,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-<<<<<<< HEAD
     def default_transforms(self) -> Compose:
-=======
-    def default_transforms(self) -> transform_lib.Compose:
->>>>>>> Adding types to datamodules
         if self.normalize:
             cf10_transforms = transform_lib.Compose([transform_lib.ToTensor(), cifar10_normalization()])
         else:
diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py
index 9c6e4b501b..be91eac99b 100644
--- a/pl_bolts/datamodules/vocdetection_datamodule.py
+++ b/pl_bolts/datamodules/vocdetection_datamodule.py
@@ -199,7 +199,7 @@ def val_dataloader(self, batch_size: int = 1, transforms=None) -> DataLoader:
         )
         return loader
 
-    def _default_transforms(self) -> T.Compose:
+    def _default_transforms(self):
         if self.normalize:
             return (
                 lambda image, target: (

From 5b2401b6f8422d34593d9b518ad9da69706e4395 Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Mon, 21 Dec 2020 10:13:11 +0900
Subject: [PATCH 48/75] Removing return typing

---
 pl_bolts/datamodules/cifar10_datamodule.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py
index cbe0333050..cdca1b61a6 100644
--- a/pl_bolts/datamodules/cifar10_datamodule.py
+++ b/pl_bolts/datamodules/cifar10_datamodule.py
@@ -73,7 +73,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ):
         """
         Args:
             data_dir: Where to save/load the data

From 7eb32b8e84c547e85feb33d808e98256de9fe450 Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Mon, 21 Dec 2020 16:51:36 +0900
Subject: [PATCH 49/75] Add `None` return type

---
 pl_bolts/datamodules/binary_mnist_datamodule.py  | 2 +-
 pl_bolts/datamodules/cifar10_datamodule.py       | 4 ++--
 pl_bolts/datamodules/fashion_mnist_datamodule.py | 2 +-
 pl_bolts/datamodules/mnist_datamodule.py         | 2 +-
 pl_bolts/datamodules/sklearn_datamodule.py       | 2 +-
 pl_bolts/datamodules/vision_datamodule.py        | 6 +++---
 6 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py
index 88ff0b359d..cdd07ff1e2 100644
--- a/pl_bolts/datamodules/binary_mnist_datamodule.py
+++ b/pl_bolts/datamodules/binary_mnist_datamodule.py
@@ -58,7 +58,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: Where to save/load the data
diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py
index cdca1b61a6..9dbf10b670 100644
--- a/pl_bolts/datamodules/cifar10_datamodule.py
+++ b/pl_bolts/datamodules/cifar10_datamodule.py
@@ -73,7 +73,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: Where to save/load the data
@@ -155,7 +155,7 @@ def __init__(
         labels: Optional[Sequence] = (1, 5, 8),
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: where to save/load the data
diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py
index 8de6c99bd7..b31a5aa792 100644
--- a/pl_bolts/datamodules/fashion_mnist_datamodule.py
+++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py
@@ -59,7 +59,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: Where to save/load the data
diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py
index 5013b8b6b5..d52315f41c 100644
--- a/pl_bolts/datamodules/mnist_datamodule.py
+++ b/pl_bolts/datamodules/mnist_datamodule.py
@@ -58,7 +58,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: Where to save/load the data
diff --git a/pl_bolts/datamodules/sklearn_datamodule.py b/pl_bolts/datamodules/sklearn_datamodule.py
index ae8ad70d52..be517f47bc 100644
--- a/pl_bolts/datamodules/sklearn_datamodule.py
+++ b/pl_bolts/datamodules/sklearn_datamodule.py
@@ -209,7 +209,7 @@ def _init_datasets(
         x_val: np.ndarray,
         y_val: np.ndarray,
         x_test: np.ndarray,
-<<<<<<< HEAD
+        y_test: np.ndarray
     ) -> None:
         self.train_dataset = SklearnDataset(X, y)
         self.val_dataset = SklearnDataset(x_val, y_val)
diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py
index ba95c26a20..bab4c722dc 100644
--- a/pl_bolts/datamodules/vision_datamodule.py
+++ b/pl_bolts/datamodules/vision_datamodule.py
@@ -36,7 +36,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: Where to save/load the data
@@ -63,14 +63,14 @@ def __init__(
         self.pin_memory = pin_memory
         self.drop_last = drop_last
 
-    def prepare_data(self):
+    def prepare_data(self) -> None:
         """
         Saves files to data_dir
         """
         self.dataset_cls(self.data_dir, train=True, download=True)
         self.dataset_cls(self.data_dir, train=False, download=True)
 
-    def setup(self, stage: Optional[str] = None):
+    def setup(self, stage: Optional[str] = None) -> None:
         """
         Creates train, val, and test dataset
         """

From 47cca3231091322c155c49369f21a02fba5ecbbb Mon Sep 17 00:00:00 2001
From: Jirka Borovec <jirka.borovec@seznam.cz>
Date: Tue, 5 Jan 2021 14:04:40 +0100
Subject: [PATCH 50/75] enable check

---
 setup.cfg | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/setup.cfg b/setup.cfg
index 080004f375..12525b46ae 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -72,9 +72,6 @@ show_error_codes = True
 disallow_untyped_defs = True
 ignore_missing_imports = True
 
-[mypy-pl_bolts.datamodules.*]
-ignore_errors = True
-
 [mypy-pl_bolts.datasets.*]
 ignore_errors = True
 

From 52e48113aa4fb16d44e2a524869f3c187b51b63f Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Fri, 18 Dec 2020 19:58:16 +0900
Subject: [PATCH 51/75] Adding types to datamodules

---
 pl_bolts/datamodules/binary_mnist_datamodule.py  | 2 +-
 pl_bolts/datamodules/fashion_mnist_datamodule.py | 2 +-
 pl_bolts/datamodules/mnist_datamodule.py         | 2 +-
 pl_bolts/datamodules/vision_datamodule.py        | 6 +++---
 4 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py
index cdd07ff1e2..88ff0b359d 100644
--- a/pl_bolts/datamodules/binary_mnist_datamodule.py
+++ b/pl_bolts/datamodules/binary_mnist_datamodule.py
@@ -58,7 +58,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ):
         """
         Args:
             data_dir: Where to save/load the data
diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py
index b31a5aa792..8de6c99bd7 100644
--- a/pl_bolts/datamodules/fashion_mnist_datamodule.py
+++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py
@@ -59,7 +59,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ):
         """
         Args:
             data_dir: Where to save/load the data
diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py
index d52315f41c..5013b8b6b5 100644
--- a/pl_bolts/datamodules/mnist_datamodule.py
+++ b/pl_bolts/datamodules/mnist_datamodule.py
@@ -58,7 +58,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ):
         """
         Args:
             data_dir: Where to save/load the data
diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py
index bab4c722dc..ba95c26a20 100644
--- a/pl_bolts/datamodules/vision_datamodule.py
+++ b/pl_bolts/datamodules/vision_datamodule.py
@@ -36,7 +36,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ):
         """
         Args:
             data_dir: Where to save/load the data
@@ -63,14 +63,14 @@ def __init__(
         self.pin_memory = pin_memory
         self.drop_last = drop_last
 
-    def prepare_data(self) -> None:
+    def prepare_data(self):
         """
         Saves files to data_dir
         """
         self.dataset_cls(self.data_dir, train=True, download=True)
         self.dataset_cls(self.data_dir, train=False, download=True)
 
-    def setup(self, stage: Optional[str] = None) -> None:
+    def setup(self, stage: Optional[str] = None):
         """
         Creates train, val, and test dataset
         """

From 64f871edf3e27f7c2b048061ed9895c0b038135d Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Fri, 18 Dec 2020 20:20:20 +0900
Subject: [PATCH 52/75] Fixing typing imports

---
 pl_bolts/datamodules/vocdetection_datamodule.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py
index be91eac99b..9c6e4b501b 100644
--- a/pl_bolts/datamodules/vocdetection_datamodule.py
+++ b/pl_bolts/datamodules/vocdetection_datamodule.py
@@ -199,7 +199,7 @@ def val_dataloader(self, batch_size: int = 1, transforms=None) -> DataLoader:
         )
         return loader
 
-    def _default_transforms(self):
+    def _default_transforms(self) -> T.Compose:
         if self.normalize:
             return (
                 lambda image, target: (

From bf6ee1cbadd30fa48b117c9d737483445636818b Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Sun, 20 Dec 2020 03:03:00 +0900
Subject: [PATCH 53/75] Removing torchvision.transforms from return typing

---
 pl_bolts/datamodules/vocdetection_datamodule.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py
index 9c6e4b501b..be91eac99b 100644
--- a/pl_bolts/datamodules/vocdetection_datamodule.py
+++ b/pl_bolts/datamodules/vocdetection_datamodule.py
@@ -199,7 +199,7 @@ def val_dataloader(self, batch_size: int = 1, transforms=None) -> DataLoader:
         )
         return loader
 
-    def _default_transforms(self) -> T.Compose:
+    def _default_transforms(self):
         if self.normalize:
             return (
                 lambda image, target: (

From b5c6dccf886d7b869f08c8a8c24e42b895ab9298 Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Mon, 21 Dec 2020 10:13:11 +0900
Subject: [PATCH 54/75] Removing return typing

---
 pl_bolts/datamodules/cifar10_datamodule.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py
index 9dbf10b670..6c15801aae 100644
--- a/pl_bolts/datamodules/cifar10_datamodule.py
+++ b/pl_bolts/datamodules/cifar10_datamodule.py
@@ -73,7 +73,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ):
         """
         Args:
             data_dir: Where to save/load the data

From 984a9629db081607e65e628c7d9f35d63befb46a Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Mon, 21 Dec 2020 16:51:36 +0900
Subject: [PATCH 55/75] Add `None` return type

---
 pl_bolts/datamodules/binary_mnist_datamodule.py  | 2 +-
 pl_bolts/datamodules/cifar10_datamodule.py       | 2 +-
 pl_bolts/datamodules/fashion_mnist_datamodule.py | 2 +-
 pl_bolts/datamodules/mnist_datamodule.py         | 2 +-
 pl_bolts/datamodules/vision_datamodule.py        | 6 +++---
 5 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py
index 88ff0b359d..cdd07ff1e2 100644
--- a/pl_bolts/datamodules/binary_mnist_datamodule.py
+++ b/pl_bolts/datamodules/binary_mnist_datamodule.py
@@ -58,7 +58,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: Where to save/load the data
diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py
index 6c15801aae..9dbf10b670 100644
--- a/pl_bolts/datamodules/cifar10_datamodule.py
+++ b/pl_bolts/datamodules/cifar10_datamodule.py
@@ -73,7 +73,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: Where to save/load the data
diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py
index 8de6c99bd7..b31a5aa792 100644
--- a/pl_bolts/datamodules/fashion_mnist_datamodule.py
+++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py
@@ -59,7 +59,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: Where to save/load the data
diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py
index 5013b8b6b5..d52315f41c 100644
--- a/pl_bolts/datamodules/mnist_datamodule.py
+++ b/pl_bolts/datamodules/mnist_datamodule.py
@@ -58,7 +58,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: Where to save/load the data
diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py
index ba95c26a20..bab4c722dc 100644
--- a/pl_bolts/datamodules/vision_datamodule.py
+++ b/pl_bolts/datamodules/vision_datamodule.py
@@ -36,7 +36,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: Where to save/load the data
@@ -63,14 +63,14 @@ def __init__(
         self.pin_memory = pin_memory
         self.drop_last = drop_last
 
-    def prepare_data(self):
+    def prepare_data(self) -> None:
         """
         Saves files to data_dir
         """
         self.dataset_cls(self.data_dir, train=True, download=True)
         self.dataset_cls(self.data_dir, train=False, download=True)
 
-    def setup(self, stage: Optional[str] = None):
+    def setup(self, stage: Optional[str] = None) -> None:
         """
         Creates train, val, and test dataset
         """

From 7894471f3038f08df28578c6e92f94d4e6e308f6 Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Fri, 18 Dec 2020 19:58:16 +0900
Subject: [PATCH 56/75] Adding types to datamodules

---
 pl_bolts/datamodules/binary_mnist_datamodule.py  | 2 +-
 pl_bolts/datamodules/cifar10_datamodule.py       | 2 +-
 pl_bolts/datamodules/fashion_mnist_datamodule.py | 2 +-
 pl_bolts/datamodules/mnist_datamodule.py         | 2 +-
 pl_bolts/datamodules/vision_datamodule.py        | 6 +++---
 5 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py
index cdd07ff1e2..88ff0b359d 100644
--- a/pl_bolts/datamodules/binary_mnist_datamodule.py
+++ b/pl_bolts/datamodules/binary_mnist_datamodule.py
@@ -58,7 +58,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ):
         """
         Args:
             data_dir: Where to save/load the data
diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py
index 9dbf10b670..cbe0333050 100644
--- a/pl_bolts/datamodules/cifar10_datamodule.py
+++ b/pl_bolts/datamodules/cifar10_datamodule.py
@@ -155,7 +155,7 @@ def __init__(
         labels: Optional[Sequence] = (1, 5, 8),
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ):
         """
         Args:
             data_dir: where to save/load the data
diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py
index b31a5aa792..8de6c99bd7 100644
--- a/pl_bolts/datamodules/fashion_mnist_datamodule.py
+++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py
@@ -59,7 +59,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ):
         """
         Args:
             data_dir: Where to save/load the data
diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py
index d52315f41c..5013b8b6b5 100644
--- a/pl_bolts/datamodules/mnist_datamodule.py
+++ b/pl_bolts/datamodules/mnist_datamodule.py
@@ -58,7 +58,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ):
         """
         Args:
             data_dir: Where to save/load the data
diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py
index bab4c722dc..ba95c26a20 100644
--- a/pl_bolts/datamodules/vision_datamodule.py
+++ b/pl_bolts/datamodules/vision_datamodule.py
@@ -36,7 +36,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ):
         """
         Args:
             data_dir: Where to save/load the data
@@ -63,14 +63,14 @@ def __init__(
         self.pin_memory = pin_memory
         self.drop_last = drop_last
 
-    def prepare_data(self) -> None:
+    def prepare_data(self):
         """
         Saves files to data_dir
         """
         self.dataset_cls(self.data_dir, train=True, download=True)
         self.dataset_cls(self.data_dir, train=False, download=True)
 
-    def setup(self, stage: Optional[str] = None) -> None:
+    def setup(self, stage: Optional[str] = None):
         """
         Creates train, val, and test dataset
         """

From 3443883e97cd37047ae258b851dd3eec7764416c Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Fri, 18 Dec 2020 20:20:20 +0900
Subject: [PATCH 57/75] Fixing typing imports

---
 pl_bolts/datamodules/vocdetection_datamodule.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py
index be91eac99b..9c6e4b501b 100644
--- a/pl_bolts/datamodules/vocdetection_datamodule.py
+++ b/pl_bolts/datamodules/vocdetection_datamodule.py
@@ -199,7 +199,7 @@ def val_dataloader(self, batch_size: int = 1, transforms=None) -> DataLoader:
         )
         return loader
 
-    def _default_transforms(self):
+    def _default_transforms(self) -> T.Compose:
         if self.normalize:
             return (
                 lambda image, target: (

From 51f8f167ee745ec81af993408683642fa865e22b Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Sun, 20 Dec 2020 03:03:00 +0900
Subject: [PATCH 58/75] Removing torchvision.transforms from return typing

---
 pl_bolts/datamodules/vocdetection_datamodule.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py
index 9c6e4b501b..be91eac99b 100644
--- a/pl_bolts/datamodules/vocdetection_datamodule.py
+++ b/pl_bolts/datamodules/vocdetection_datamodule.py
@@ -199,7 +199,7 @@ def val_dataloader(self, batch_size: int = 1, transforms=None) -> DataLoader:
         )
         return loader
 
-    def _default_transforms(self) -> T.Compose:
+    def _default_transforms(self):
         if self.normalize:
             return (
                 lambda image, target: (

From 3062dba63676eba615c569d834177bca6df1e3ff Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Mon, 21 Dec 2020 10:13:11 +0900
Subject: [PATCH 59/75] Removing return typing

---
 pl_bolts/datamodules/cifar10_datamodule.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py
index cbe0333050..cdca1b61a6 100644
--- a/pl_bolts/datamodules/cifar10_datamodule.py
+++ b/pl_bolts/datamodules/cifar10_datamodule.py
@@ -73,7 +73,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ) -> None:
+    ):
         """
         Args:
             data_dir: Where to save/load the data

From 53ebe33ff4f30936335b96421afe8606a7a630a9 Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Mon, 21 Dec 2020 16:51:36 +0900
Subject: [PATCH 60/75] Add `None` return type

---
 pl_bolts/datamodules/binary_mnist_datamodule.py  | 2 +-
 pl_bolts/datamodules/cifar10_datamodule.py       | 4 ++--
 pl_bolts/datamodules/fashion_mnist_datamodule.py | 2 +-
 pl_bolts/datamodules/mnist_datamodule.py         | 2 +-
 pl_bolts/datamodules/vision_datamodule.py        | 6 +++---
 5 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py
index 88ff0b359d..cdd07ff1e2 100644
--- a/pl_bolts/datamodules/binary_mnist_datamodule.py
+++ b/pl_bolts/datamodules/binary_mnist_datamodule.py
@@ -58,7 +58,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: Where to save/load the data
diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py
index cdca1b61a6..9dbf10b670 100644
--- a/pl_bolts/datamodules/cifar10_datamodule.py
+++ b/pl_bolts/datamodules/cifar10_datamodule.py
@@ -73,7 +73,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: Where to save/load the data
@@ -155,7 +155,7 @@ def __init__(
         labels: Optional[Sequence] = (1, 5, 8),
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: where to save/load the data
diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py
index 8de6c99bd7..b31a5aa792 100644
--- a/pl_bolts/datamodules/fashion_mnist_datamodule.py
+++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py
@@ -59,7 +59,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: Where to save/load the data
diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py
index 5013b8b6b5..d52315f41c 100644
--- a/pl_bolts/datamodules/mnist_datamodule.py
+++ b/pl_bolts/datamodules/mnist_datamodule.py
@@ -58,7 +58,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: Where to save/load the data
diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py
index ba95c26a20..bab4c722dc 100644
--- a/pl_bolts/datamodules/vision_datamodule.py
+++ b/pl_bolts/datamodules/vision_datamodule.py
@@ -36,7 +36,7 @@ def __init__(
         drop_last: bool = False,
         *args: Any,
         **kwargs: Any,
-    ):
+    ) -> None:
         """
         Args:
             data_dir: Where to save/load the data
@@ -63,14 +63,14 @@ def __init__(
         self.pin_memory = pin_memory
         self.drop_last = drop_last
 
-    def prepare_data(self):
+    def prepare_data(self) -> None:
         """
         Saves files to data_dir
         """
         self.dataset_cls(self.data_dir, train=True, download=True)
         self.dataset_cls(self.data_dir, train=False, download=True)
 
-    def setup(self, stage: Optional[str] = None):
+    def setup(self, stage: Optional[str] = None) -> None:
         """
         Creates train, val, and test dataset
         """

From c15efdb5959b510e965ec742f6469ae18c016a70 Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Tue, 12 Jan 2021 19:52:18 +0900
Subject: [PATCH 61/75] Fix rebasing mistakes

---
 pl_bolts/datamodules/cityscapes_datamodule.py   | 4 ++--
 pl_bolts/datamodules/imagenet_datamodule.py     | 4 ++--
 pl_bolts/datamodules/kitti_datamodule.py        | 4 ++--
 pl_bolts/datamodules/ssl_imagenet_datamodule.py | 8 ++++----
 pl_bolts/datamodules/stl10_datamodule.py        | 4 ++--
 5 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py
index 3721f897af..41465840dc 100644
--- a/pl_bolts/datamodules/cityscapes_datamodule.py
+++ b/pl_bolts/datamodules/cityscapes_datamodule.py
@@ -73,8 +73,8 @@ def __init__(
         shuffle: bool = False,
         pin_memory: bool = False,
         drop_last: bool = False,
-        *args,
-        **kwargs,
+        *args: Any,
+        **kwargs: Any,
     ) -> None:
         """
         Args:
diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py
index b63611b060..301894c074 100644
--- a/pl_bolts/datamodules/imagenet_datamodule.py
+++ b/pl_bolts/datamodules/imagenet_datamodule.py
@@ -60,8 +60,8 @@ def __init__(
         shuffle: bool = False,
         pin_memory: bool = False,
         drop_last: bool = False,
-        *args,
-        **kwargs,
+        *args: Any,
+        **kwargs: Any,
     ) -> None:
         """
         Args:
diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py
index cab7529e91..461b2244a8 100644
--- a/pl_bolts/datamodules/kitti_datamodule.py
+++ b/pl_bolts/datamodules/kitti_datamodule.py
@@ -33,8 +33,8 @@ def __init__(
         shuffle: bool = False,
         pin_memory: bool = False,
         drop_last: bool = False,
-        *args,
-        **kwargs,
+        *args: Any,
+        **kwargs: Any,
     ) -> None:
         """
         Kitti train, validation and test dataloaders.
diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py
index 116a1f1614..4ab0bcbd9b 100644
--- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py
+++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py
@@ -23,15 +23,15 @@ class SSLImagenetDataModule(LightningDataModule):  # pragma: no cover
 
     def __init__(
         self,
-        data_dir,
-        meta_dir=None,
+        data_dir: str,
+        meta_dir: Optional[str] = None,
         num_workers=16,
         batch_size: int = 32,
         shuffle: bool = False,
         pin_memory: bool = False,
         drop_last: bool = False,
-        *args,
-        **kwargs,
+        *args: Any,
+        **kwargs: Any,
     ) -> None:
         super().__init__(*args, **kwargs)
 
diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py
index 133f3ed6f8..a32c89e5be 100644
--- a/pl_bolts/datamodules/stl10_datamodule.py
+++ b/pl_bolts/datamodules/stl10_datamodule.py
@@ -65,8 +65,8 @@ def __init__(
         shuffle: bool = False,
         pin_memory: bool = False,
         drop_last: bool = False,
-        *args,
-        **kwargs,
+        *args: Any,
+        **kwargs: Any,
     ) -> None:
         """
         Args:
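
The rebase fix above re-applies `Any` to the variadic parameters. In this style the annotation describes each forwarded argument, not the tuple/dict container. A small sketch with stand-in names (`Base` replaces `LightningDataModule` purely to keep the example self-contained):

    from typing import Any

    class Base:
        # stand-in for LightningDataModule so the sketch runs without extras
        def __init__(self, *args: Any, **kwargs: Any) -> None:
            pass

    class ForwardingDataModule(Base):
        def __init__(self, batch_size: int = 32, *args: Any, **kwargs: Any) -> None:
            # `*args: Any` types each extra positional argument (args itself is
            # Tuple[Any, ...]); `**kwargs: Any` types each extra keyword value.
            super().__init__(*args, **kwargs)
            self.batch_size = batch_size

    dm = ForwardingDataModule(64, "extra", pin_memory=True)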

From 7bc0c370b302b3f18e772c6a2f0eb39679ef881d Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Tue, 12 Jan 2021 19:54:42 +0900
Subject: [PATCH 62/75] Fix flake8

---
 pl_bolts/datamodules/kitti_datamodule.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py
index 461b2244a8..df07852085 100644
--- a/pl_bolts/datamodules/kitti_datamodule.py
+++ b/pl_bolts/datamodules/kitti_datamodule.py
@@ -24,7 +24,7 @@ class KittiDataModule(LightningDataModule):
 
     def __init__(
         self,
-        data_dir: str,
+        data_dir: Optional[str] = None,
         val_split: float = 0.2,
         test_split: float = 0.1,
         num_workers: int = 16,
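
The flake8 fix above turns `data_dir` into an optional parameter; once the default is `None`, the annotation has to say `Optional[str]` to stay truthful. A tiny sketch of the pattern (the `os.getcwd()` fallback is only an illustration, not what KittiDataModule does):

    import os
    from typing import Optional

    def resolve_data_dir(data_dir: Optional[str] = None) -> str:
        # hypothetical helper: the Optional[...] annotation matches the None
        # default, and callers get a concrete str back either way
        return data_dir if data_dir is not None else os.getcwd()

    print(resolve_data_dir())               # falls back to the current directory
    print(resolve_data_dir("/data/kitti"))  # passes an explicit path through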

From a5f3e4f98d4d56d2c41f3e93da42e66b1e12e7ce Mon Sep 17 00:00:00 2001
From: Brian Ko <briankosw@gmail.com>
Date: Tue, 12 Jan 2021 19:56:55 +0900
Subject: [PATCH 63/75] Fix yapf format

---
 pl_bolts/datamodules/async_dataloader.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/pl_bolts/datamodules/async_dataloader.py b/pl_bolts/datamodules/async_dataloader.py
index 224f34d5ee..72d823c9ae 100644
--- a/pl_bolts/datamodules/async_dataloader.py
+++ b/pl_bolts/datamodules/async_dataloader.py
@@ -28,12 +28,12 @@ class AsynchronousLoader(object):
     """
 
     def __init__(
-            self,
-            data: Union[DataLoader, Dataset],
-            device: torch.device = torch.device('cuda', 0),
-            q_size: int = 10,
-            num_batches: Optional[int] = None,
-            **kwargs: Any
+        self,
+        data: Union[DataLoader, Dataset],
+        device: torch.device = torch.device('cuda', 0),
+        q_size: int = 10,
+        num_batches: Optional[int] = None,
+        **kwargs: Any
     ) -> None:
         if isinstance(data, torch.utils.data.DataLoader):
             self.dataloader = data
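
The yapf change above is purely layout: the wrapped parameter list moves to a 4-space continuation indent instead of hanging further in. Roughly, the accepted shape looks like this (hypothetical class, CPU device so the sketch runs anywhere):

    from typing import Any, Optional

    import torch

    class LoaderSketch:
        def __init__(
            self,
            device: torch.device = torch.device("cpu"),
            q_size: int = 10,
            num_batches: Optional[int] = None,
            **kwargs: Any,
        ) -> None:
            # each parameter sits on its own line, indented 4 spaces relative
            # to `def`, with the closing parenthesis back at the `def` column
            self.device = device
            self.q_size = q_size
            self.num_batches = num_batches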

From b9c910dde176d02c48e5f0e0f5a7a4e61b1d8cf3 Mon Sep 17 00:00:00 2001
From: Akihiro Nitta <nitta@akihironitta.com>
Date: Wed, 20 Jan 2021 20:11:53 +0900
Subject: [PATCH 64/75] Add types and skip mypy checks on some files

---
 pl_bolts/callbacks/byol_updates.py               |  4 ++--
 pl_bolts/callbacks/variational.py                |  5 +++--
 pl_bolts/datamodules/async_dataloader.py         | 13 +++++++------
 pl_bolts/datamodules/binary_mnist_datamodule.py  |  2 +-
 pl_bolts/datamodules/cifar10_datamodule.py       |  4 ++--
 pl_bolts/datamodules/cityscapes_datamodule.py    |  3 ++-
 pl_bolts/datamodules/experience_source.py        |  2 +-
 pl_bolts/datamodules/fashion_mnist_datamodule.py |  2 +-
 pl_bolts/datamodules/imagenet_datamodule.py      |  7 ++++---
 pl_bolts/datamodules/kitti_datamodule.py         |  1 +
 pl_bolts/datamodules/mnist_datamodule.py         |  2 +-
 pl_bolts/datamodules/ssl_imagenet_datamodule.py  |  9 +++++----
 pl_bolts/datamodules/stl10_datamodule.py         |  3 ++-
 pl_bolts/datamodules/vision_datamodule.py        | 16 ++++++++--------
 pl_bolts/datamodules/vocdetection_datamodule.py  | 16 ++++++++--------
 setup.cfg                                        | 13 +++++++++++++
 16 files changed, 61 insertions(+), 41 deletions(-)

diff --git a/pl_bolts/callbacks/byol_updates.py b/pl_bolts/callbacks/byol_updates.py
index 8f47815521..2b4a1953f8 100644
--- a/pl_bolts/callbacks/byol_updates.py
+++ b/pl_bolts/callbacks/byol_updates.py
@@ -66,7 +66,7 @@ def update_tau(self, pl_module: LightningModule, trainer: Trainer) -> float:
     def update_weights(self, online_net: Union[Module, Tensor], target_net: Union[Module, Tensor]) -> None:
         # apply MA weight update
         for (name, online_p), (_, target_p) in zip(
-            online_net.named_parameters(), target_net.named_parameters()
-        ):  # type: ignore[union-attr]
+            online_net.named_parameters(), target_net.named_parameters()  # type: ignore[union-attr]
+        ):
             if 'weight' in name:
                 target_p.data = self.current_tau * target_p.data + (1 - self.current_tau) * online_p.data
diff --git a/pl_bolts/callbacks/variational.py b/pl_bolts/callbacks/variational.py
index 5947f40be8..4d5f4c6e23 100644
--- a/pl_bolts/callbacks/variational.py
+++ b/pl_bolts/callbacks/variational.py
@@ -62,8 +62,9 @@ def __init__(
     def on_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
         if (trainer.current_epoch + 1) % self.interpolate_epoch_interval == 0:
             images = self.interpolate_latent_space(
-                pl_module, latent_dim=pl_module.hparams.latent_dim
-            )  # type: ignore[union-attr]
+                pl_module,
+                latent_dim=pl_module.hparams.latent_dim  # type: ignore[union-attr]
+            )
             images = torch.cat(images, dim=0)  # type: ignore[assignment]
 
             num_images = (self.range_end - self.range_start)**2
diff --git a/pl_bolts/datamodules/async_dataloader.py b/pl_bolts/datamodules/async_dataloader.py
index 72d823c9ae..137429dc51 100644
--- a/pl_bolts/datamodules/async_dataloader.py
+++ b/pl_bolts/datamodules/async_dataloader.py
@@ -3,6 +3,7 @@
 from threading import Thread
 from typing import Any, Optional, Union
 
+import numpy as np
 import torch
 from torch._six import container_abcs, string_classes
 from torch.utils.data import DataLoader, Dataset
@@ -33,7 +34,7 @@ def __init__(
         device: torch.device = torch.device('cuda', 0),
         q_size: int = 10,
         num_batches: Optional[int] = None,
-        **kwargs: Any
+        **kwargs: Any,
     ) -> None:
         if isinstance(data, torch.utils.data.DataLoader):
             self.dataloader = data
@@ -51,7 +52,7 @@ def __init__(
         self.q_size = q_size
 
         self.load_stream = torch.cuda.Stream(device=device)
-        self.queue = Queue(maxsize=self.q_size)
+        self.queue: Queue = Queue(maxsize=self.q_size)
 
         self.idx = 0
 
@@ -64,7 +65,7 @@ def load_loop(self) -> None:  # The loop that will load into the queue in the ba
                 break
 
     # Recursive loading for each instance based on torch.utils.data.default_collate
-    def load_instance(self, sample):
+    def load_instance(self, sample: Any) -> Any:
         elem_type = type(sample)
 
         if torch.is_tensor(sample):
@@ -88,16 +89,16 @@ def load_instance(self, sample):
         else:
             return sample
 
-    def __iter__(self):
+    def __iter__(self) -> "AsynchronousLoader":
         # We don't want to run the thread more than once
         # Start a new thread if we are at the beginning of a new epoch, and our current worker is dead
-        if (not hasattr(self, 'worker') or not self.worker.is_alive()) and self.queue.empty() and self.idx == 0:
+        if (not hasattr(self, 'worker') or not self.worker.is_alive()) and self.queue.empty() and self.idx == 0:  # type: ignore[has-type]
             self.worker = Thread(target=self.load_loop)
             self.worker.daemon = True
             self.worker.start()
         return self
 
-    def __next__(self):
+    def __next__(self) -> torch.Tensor:
         # If we've reached the number of batches to return
         # or the queue is empty and the worker is dead then exit
         done = not self.worker.is_alive() and self.queue.empty()
diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py
index a72c02b53a..7b1b963f49 100644
--- a/pl_bolts/datamodules/binary_mnist_datamodule.py
+++ b/pl_bolts/datamodules/binary_mnist_datamodule.py
@@ -78,7 +78,7 @@ def __init__(
                 "You want to use transforms loaded from `torchvision` which is not installed yet."
             )
 
-        super().__init__(
+        super().__init__(  # type: ignore[misc]
             data_dir=data_dir,
             val_split=val_split,
             num_workers=num_workers,
diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py
index d7e3ed2d7a..1c17658a30 100644
--- a/pl_bolts/datamodules/cifar10_datamodule.py
+++ b/pl_bolts/datamodules/cifar10_datamodule.py
@@ -87,7 +87,7 @@ def __init__(
                         returning them
             drop_last: If true drops the last incomplete batch
         """
-        super().__init__(
+        super().__init__(  # type: ignore[misc]
             data_dir=data_dir,
             val_split=val_split,
             num_workers=num_workers,
@@ -166,7 +166,7 @@ def __init__(
         """
         super().__init__(data_dir, val_split, num_workers, *args, **kwargs)
 
-        self.num_samples = num_samples
+        self.num_samples = num_samples  # type: ignore[misc]
         self.labels = sorted(labels) if labels is not None else set(range(10))
         self.extra_args = dict(num_samples=self.num_samples, labels=self.labels)
 
diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py
index eaf7cd5fa3..462d4ca982 100644
--- a/pl_bolts/datamodules/cityscapes_datamodule.py
+++ b/pl_bolts/datamodules/cityscapes_datamodule.py
@@ -1,3 +1,4 @@
+# type: ignore[override]
 from typing import Any
 
 from pytorch_lightning import LightningDataModule
@@ -60,7 +61,7 @@ class CityscapesDataModule(LightningDataModule):
     """
 
     name = 'Cityscapes'
-    extra_args = {}
+    extra_args: dict = {}
 
     def __init__(
         self,
diff --git a/pl_bolts/datamodules/experience_source.py b/pl_bolts/datamodules/experience_source.py
index 1bc7b0f8a8..50ed2a6a7b 100644
--- a/pl_bolts/datamodules/experience_source.py
+++ b/pl_bolts/datamodules/experience_source.py
@@ -299,5 +299,5 @@ def discount_rewards(self, experiences: Tuple[Experience]) -> float:
         """
         total_reward = 0.0
         for exp in reversed(experiences):
-            total_reward = (self.gamma * total_reward) + exp.reward
+            total_reward = (self.gamma * total_reward) + exp.reward  # type: ignore[attr-defined]
         return total_reward
diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py
index eabecdbf60..b209e09f37 100644
--- a/pl_bolts/datamodules/fashion_mnist_datamodule.py
+++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py
@@ -78,7 +78,7 @@ def __init__(
                 'You want to use FashionMNIST dataset loaded from `torchvision` which is not installed yet.'
             )
 
-        super().__init__(
+        super().__init__(  # type: ignore[misc]
             data_dir=data_dir,
             val_split=val_split,
             num_workers=num_workers,
diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py
index 41ac3d42db..ab8e8b9921 100644
--- a/pl_bolts/datamodules/imagenet_datamodule.py
+++ b/pl_bolts/datamodules/imagenet_datamodule.py
@@ -1,3 +1,4 @@
+# type: ignore[override]
 import os
 from typing import Any, Optional
 
@@ -158,7 +159,7 @@ def train_dataloader(self) -> DataLoader:
             split='train',
             transform=transforms
         )
-        loader = DataLoader(
+        loader: DataLoader = DataLoader(
             dataset,
             batch_size=self.batch_size,
             shuffle=self.shuffle,
@@ -185,7 +186,7 @@ def val_dataloader(self) -> DataLoader:
             split='val',
             transform=transforms
         )
-        loader = DataLoader(
+        loader: DataLoader = DataLoader(
             dataset,
             batch_size=self.batch_size,
             shuffle=False,
@@ -204,7 +205,7 @@ def test_dataloader(self) -> DataLoader:
         dataset = UnlabeledImagenet(
             self.data_dir, num_imgs_per_class=-1, meta_dir=self.meta_dir, split='test', transform=transforms
         )
-        loader = DataLoader(
+        loader: DataLoader = DataLoader(
             dataset,
             batch_size=self.batch_size,
             shuffle=False,
diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py
index 34d64c3c00..893f224047 100644
--- a/pl_bolts/datamodules/kitti_datamodule.py
+++ b/pl_bolts/datamodules/kitti_datamodule.py
@@ -1,3 +1,4 @@
+# type: ignore[override]
 import os
 from typing import Any, Optional
 
diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py
index 4f8e9e19a0..c813cee685 100644
--- a/pl_bolts/datamodules/mnist_datamodule.py
+++ b/pl_bolts/datamodules/mnist_datamodule.py
@@ -77,7 +77,7 @@ def __init__(
                 'You want to use MNIST dataset loaded from `torchvision` which is not installed yet.'
             )
 
-        super().__init__(
+        super().__init__(  # type: ignore[misc]
             data_dir=data_dir,
             val_split=val_split,
             num_workers=num_workers,
diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py
index c8b05b904a..c21d1af46b 100644
--- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py
+++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py
@@ -1,3 +1,4 @@
+# type: ignore[override]
 import os
 from typing import Any, Optional
 
@@ -25,7 +26,7 @@ def __init__(
         self,
         data_dir: str,
         meta_dir: Optional[str] = None,
-        num_workers=16,
+        num_workers: int = 16,
         batch_size: int = 32,
         shuffle: bool = False,
         pin_memory: bool = False,
@@ -96,7 +97,7 @@ def train_dataloader(self, num_images_per_class: int = -1, add_normalize: bool =
             split='train',
             transform=transforms
         )
-        loader = DataLoader(
+        loader: DataLoader = DataLoader(
             dataset,
             batch_size=self.batch_size,
             shuffle=self.shuffle,
@@ -116,7 +117,7 @@ def val_dataloader(self, num_images_per_class: int = 50, add_normalize: bool = F
             split='val',
             transform=transforms
         )
-        loader = DataLoader(
+        loader: DataLoader = DataLoader(
             dataset,
             batch_size=self.batch_size,
             shuffle=False,
@@ -136,7 +137,7 @@ def test_dataloader(self, num_images_per_class: int, add_normalize: bool = False
             split='test',
             transform=transforms
         )
-        loader = DataLoader(
+        loader: DataLoader = DataLoader(
             dataset,
             batch_size=self.batch_size,
             shuffle=False,
diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py
index 5aafa1380b..0433278af6 100644
--- a/pl_bolts/datamodules/stl10_datamodule.py
+++ b/pl_bolts/datamodules/stl10_datamodule.py
@@ -1,3 +1,4 @@
+# type: ignore[override]
 import os
 from typing import Any, Optional
 
@@ -194,7 +195,7 @@ def val_dataloader(self) -> DataLoader:
             batch_size=self.batch_size,
             shuffle=False,
             num_workers=self.num_workers,
-            drpo_last=self.drop_last,
+            drop_last=self.drop_last,
             pin_memory=self.pin_memory
         )
         return loader
diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py
index bab4c722dc..73ba424c96 100644
--- a/pl_bolts/datamodules/vision_datamodule.py
+++ b/pl_bolts/datamodules/vision_datamodule.py
@@ -16,12 +16,12 @@
 
 class VisionDataModule(LightningDataModule):
 
-    EXTRA_ARGS = {}
+    EXTRA_ARGS: dict = {}
     name: str = ""
     #: Dataset class to use
-    dataset_cls = ...
+    dataset_cls: type
     #: A tuple describing the shape of the data
-    dims: tuple = ...
+    dims: tuple
 
     def __init__(
         self,
@@ -63,7 +63,7 @@ def __init__(
         self.pin_memory = pin_memory
         self.drop_last = drop_last
 
-    def prepare_data(self) -> None:
+    def prepare_data(self, *args: Any, **kwargs: Any) -> None:
         """
         Saves files to data_dir
         """
@@ -95,7 +95,7 @@ def _split_dataset(self, dataset: Dataset, train: bool = True) -> Dataset:
         """
         Splits the dataset into train and validation set
         """
-        len_dataset = len(dataset)
+        len_dataset = len(dataset)  # type: ignore[arg-type]
         splits = self._get_splits(len_dataset)
         dataset_train, dataset_val = random_split(dataset, splits, generator=torch.Generator().manual_seed(self.seed))
 
@@ -123,15 +123,15 @@ def _get_splits(self, len_dataset: int) -> List[int]:
     def default_transforms(self) -> Compose:
         """ Default transform for the dataset """
 
-    def train_dataloader(self) -> DataLoader:
+    def train_dataloader(self, *args: Any, **kwargs: Any) -> DataLoader:
         """ The train dataloader """
         return self._data_loader(self.dataset_train, shuffle=self.shuffle)
 
-    def val_dataloader(self) -> DataLoader:
+    def val_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]:
         """ The val dataloader """
         return self._data_loader(self.dataset_val)
 
-    def test_dataloader(self) -> DataLoader:
+    def test_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]:
         """ The test dataloader """
         return self._data_loader(self.dataset_test)
 
diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py
index c3ba78452c..c704540312 100644
--- a/pl_bolts/datamodules/vocdetection_datamodule.py
+++ b/pl_bolts/datamodules/vocdetection_datamodule.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict
+from typing import Any, Dict, List, Callable, Tuple, Optional
 
 import torch
 from pytorch_lightning import LightningDataModule
@@ -19,11 +19,11 @@ class Compose(object):
     Like `torchvision.transforms.compose` but works for (image, target)
     """
 
-    def __init__(self, transforms, image_transforms=None):
+    def __init__(self, transforms: List[Callable], image_transforms: Optional[Callable] = None) -> None:
         self.transforms = transforms
         self.image_transforms = image_transforms
 
-    def __call__(self, image, target):
+    def __call__(self, image: Any, target: Any) -> Tuple[torch.Tensor, torch.Tensor]:
         for t in self.transforms:
             image, target = t(image, target)
         if self.image_transforms:
@@ -31,7 +31,7 @@ def __call__(self, image, target):
         return image, target
 
 
-def _collate_fn(batch):
+def _collate_fn(batch: List[torch.Tensor]) -> tuple:
     return tuple(zip(*batch))
 
 
@@ -60,7 +60,7 @@ def _collate_fn(batch):
 )
 
 
-def _prepare_voc_instance(image, target: Dict[str, Any]):
+def _prepare_voc_instance(image: Any, target: Dict[str, Any]):
     """
     Prepares VOC dataset into appropriate target for fasterrcnn
 
@@ -151,7 +151,7 @@ def prepare_data(self) -> None:
         VOCDetection(self.data_dir, year=self.year, image_set="train", download=True)
         VOCDetection(self.data_dir, year=self.year, image_set="val", download=True)
 
-    def train_dataloader(self, batch_size: int = 1, image_transforms=None) -> DataLoader:
+    def train_dataloader(self, batch_size: int = 1, image_transforms: Union[List[Callable], Callable]=None) -> DataLoader:
         """
         VOCDetection train set uses the `train` subset
 
@@ -174,7 +174,7 @@ def train_dataloader(self, batch_size: int = 1, image_transforms=None) -> DataLo
         )
         return loader
 
-    def val_dataloader(self, batch_size: int = 1, image_transforms=None) -> DataLoader:
+    def val_dataloader(self, batch_size: int = 1, image_transforms: Optional[List[Callable]] = None) -> DataLoader:
         """
         VOCDetection val set uses the `val` subset
 
@@ -197,7 +197,7 @@ def val_dataloader(self, batch_size: int = 1, image_transforms=None) -> DataLoad
         )
         return loader
 
-    def _default_transforms(self):
+    def _default_transforms(self) -> transform_lib.Compose:
         if self.normalize:
             voc_transforms = transform_lib.Compose([
                 transform_lib.ToTensor(),
diff --git a/setup.cfg b/setup.cfg
index 12525b46ae..957564a046 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -75,6 +75,19 @@ ignore_missing_imports = True
 [mypy-pl_bolts.datasets.*]
 ignore_errors = True
 
+[mypy-pl_bolts.datamodules]
+ # pl_bolts/datamodules/__init__.py
+ ignore_errors = True
+
+[mypy-pl_bolts.datamodules.experience_source]
+ignore_errors = True
+
+[mypy-pl_bolts.datamodules.sklearn_datamodule]
+ignore_errors = True
+
+[mypy-pl_bolts.datamodules.vocdetection_datamodule]
+ignore_errors = True
+
 [mypy-pl_bolts.losses.*]
 ignore_errors = True
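
Patch 64 bundles several techniques: class-level annotations such as `extra_args: dict = {}`, an explicitly typed `Queue` attribute, a forward-referenced return type on `__iter__`, targeted `# type: ignore[...]` comments, and per-module `ignore_errors = True` sections in setup.cfg for files that are not ready yet. A minimal, self-contained sketch of the iterator-typing part, with a hypothetical class:

    from queue import Queue

    class BatchQueueSketch:
        """Hypothetical loader, only to illustrate the annotations."""

        def __init__(self, q_size: int = 10) -> None:
            # Queue is generic and nothing here pins down its item type, so
            # mypy asks for an explicit annotation on the attribute
            self.queue: Queue = Queue(maxsize=q_size)
            self.idx = 0

        def __iter__(self) -> "BatchQueueSketch":
            # the class name is quoted (a forward reference) because the class
            # object does not exist yet while its own body is being defined
            self.idx = 0
            return self

        def __next__(self) -> int:
            if self.idx >= 3:  # pretend there are only three batches
                raise StopIteration
            self.idx += 1
            return self.idx

    print(list(BatchQueueSketch()))  # [1, 2, 3]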
 

From 9e222d02f5a2f6721ca5b932d717061b3a23b1d0 Mon Sep 17 00:00:00 2001
From: Akihiro Nitta <nitta@akihironitta.com>
Date: Wed, 20 Jan 2021 20:22:57 +0900
Subject: [PATCH 65/75] Fix setup.cfg

---
 setup.cfg | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/setup.cfg b/setup.cfg
index 957564a046..5883253ce5 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -76,8 +76,8 @@ ignore_missing_imports = True
 ignore_errors = True
 
 [mypy-pl_bolts.datamodules]
- # pl_bolts/datamodules/__init__.py
- ignore_errors = True
+# pl_bolts/datamodules/__init__.py
+ignore_errors = True
 
 [mypy-pl_bolts.datamodules.experience_source]
 ignore_errors = True

From 0c54fdd87906a518c7d836bb2478496e57682995 Mon Sep 17 00:00:00 2001
From: Akihiro Nitta <nitta@akihironitta.com>
Date: Wed, 20 Jan 2021 20:27:24 +0900
Subject: [PATCH 66/75] Add missing import

---
 pl_bolts/datamodules/vocdetection_datamodule.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py
index c704540312..eec63f1891 100644
--- a/pl_bolts/datamodules/vocdetection_datamodule.py
+++ b/pl_bolts/datamodules/vocdetection_datamodule.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Callable, Tuple, Optional
+from typing import Any, Dict, List, Callable, Tuple, Optional, Union
 
 import torch
 from pytorch_lightning import LightningDataModule

From 8b2e1964af51c771af83002e58539d970646d3e1 Mon Sep 17 00:00:00 2001
From: Akihiro Nitta <nitta@akihironitta.com>
Date: Wed, 20 Jan 2021 20:27:35 +0900
Subject: [PATCH 67/75] isort

---
 pl_bolts/datamodules/vocdetection_datamodule.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py
index eec63f1891..3890ab640d 100644
--- a/pl_bolts/datamodules/vocdetection_datamodule.py
+++ b/pl_bolts/datamodules/vocdetection_datamodule.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Callable, Tuple, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 from pytorch_lightning import LightningDataModule

From 9c5dd5cafc49d0732143babb0f8f198d07c75a00 Mon Sep 17 00:00:00 2001
From: Akihiro Nitta <nitta@akihironitta.com>
Date: Wed, 20 Jan 2021 20:34:32 +0900
Subject: [PATCH 68/75] yapf

---
 pl_bolts/callbacks/byol_updates.py              | 3 ++-
 pl_bolts/datamodules/async_dataloader.py        | 3 ++-
 pl_bolts/datamodules/vocdetection_datamodule.py | 4 +++-
 3 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/pl_bolts/callbacks/byol_updates.py b/pl_bolts/callbacks/byol_updates.py
index 2b4a1953f8..2918d2aa75 100644
--- a/pl_bolts/callbacks/byol_updates.py
+++ b/pl_bolts/callbacks/byol_updates.py
@@ -66,7 +66,8 @@ def update_tau(self, pl_module: LightningModule, trainer: Trainer) -> float:
     def update_weights(self, online_net: Union[Module, Tensor], target_net: Union[Module, Tensor]) -> None:
         # apply MA weight update
         for (name, online_p), (_, target_p) in zip(
-            online_net.named_parameters(), target_net.named_parameters()  # type: ignore[union-attr]
+            online_net.named_parameters(),
+            target_net.named_parameters()  # type: ignore[union-attr]
         ):
             if 'weight' in name:
                 target_p.data = self.current_tau * target_p.data + (1 - self.current_tau) * online_p.data
diff --git a/pl_bolts/datamodules/async_dataloader.py b/pl_bolts/datamodules/async_dataloader.py
index 137429dc51..b8a35ffa9b 100644
--- a/pl_bolts/datamodules/async_dataloader.py
+++ b/pl_bolts/datamodules/async_dataloader.py
@@ -92,7 +92,8 @@ def load_instance(self, sample: Any) -> Any:
     def __iter__(self) -> "AsynchronousLoader":
         # We don't want to run the thread more than once
         # Start a new thread if we are at the beginning of a new epoch, and our current worker is dead
-        if (not hasattr(self, 'worker') or not self.worker.is_alive()) and self.queue.empty() and self.idx == 0:  # type: ignore[has-type]
+        if (not hasattr(self, 'worker')
+            or not self.worker.is_alive()) and self.queue.empty() and self.idx == 0:  # type: ignore[has-type]
             self.worker = Thread(target=self.load_loop)
             self.worker.daemon = True
             self.worker.start()
diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py
index 3890ab640d..f70806a7a7 100644
--- a/pl_bolts/datamodules/vocdetection_datamodule.py
+++ b/pl_bolts/datamodules/vocdetection_datamodule.py
@@ -151,7 +151,9 @@ def prepare_data(self) -> None:
         VOCDetection(self.data_dir, year=self.year, image_set="train", download=True)
         VOCDetection(self.data_dir, year=self.year, image_set="val", download=True)
 
-    def train_dataloader(self, batch_size: int = 1, image_transforms: Union[List[Callable], Callable]=None) -> DataLoader:
+    def train_dataloader(
+        self, batch_size: int = 1, image_transforms: Union[List[Callable], Callable] = None
+    ) -> DataLoader:
         """
         VOCDetection train set uses the `train` subset
 

From c6c97e1e524208828ff4bd6302072657ab728c83 Mon Sep 17 00:00:00 2001
From: Akihiro Nitta <nitta@akihironitta.com>
Date: Wed, 20 Jan 2021 20:41:33 +0900
Subject: [PATCH 69/75] mypy please...

---
 pl_bolts/callbacks/byol_updates.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pl_bolts/callbacks/byol_updates.py b/pl_bolts/callbacks/byol_updates.py
index 2918d2aa75..1c4d3ba7d4 100644
--- a/pl_bolts/callbacks/byol_updates.py
+++ b/pl_bolts/callbacks/byol_updates.py
@@ -66,7 +66,7 @@ def update_tau(self, pl_module: LightningModule, trainer: Trainer) -> float:
     def update_weights(self, online_net: Union[Module, Tensor], target_net: Union[Module, Tensor]) -> None:
         # apply MA weight update
         for (name, online_p), (_, target_p) in zip(
-            online_net.named_parameters(),
+            online_net.named_parameters(),  # type: ignore[union-attr]
             target_net.named_parameters()  # type: ignore[union-attr]
         ):
             if 'weight' in name:
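
Patches 68 and 69 are really about where a trailing `# type: ignore[...]` may live: it silences errors only on the physical line it ends, so once yapf splits the `zip(...)` call, the comment has to move onto the argument that actually produces the `union-attr` error. A compressed illustration with a hypothetical helper:

    from typing import Optional

    def maybe_text() -> Optional[str]:
        # hypothetical helper that the type checker sees as Optional[str]
        return "lorem ipsum dolor"

    # the ignore sits on the argument line, not on the closing parenthesis
    words = list(map(
        str.upper,
        maybe_text().split(),  # type: ignore[union-attr]
    ))
    print(words)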

From 4ac8a5ba8d00dd8cc60a87c30f7d4910fb205f48 Mon Sep 17 00:00:00 2001
From: Akihiro Nitta <nitta@akihironitta.com>
Date: Wed, 20 Jan 2021 20:52:19 +0900
Subject: [PATCH 70/75] Please be quiet mypy and flake8

---
 pl_bolts/datamodules/async_dataloader.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/pl_bolts/datamodules/async_dataloader.py b/pl_bolts/datamodules/async_dataloader.py
index b8a35ffa9b..6034c29d3e 100644
--- a/pl_bolts/datamodules/async_dataloader.py
+++ b/pl_bolts/datamodules/async_dataloader.py
@@ -3,7 +3,6 @@
 from threading import Thread
 from typing import Any, Optional, Union
 
-import numpy as np
 import torch
 from torch._six import container_abcs, string_classes
 from torch.utils.data import DataLoader, Dataset
@@ -92,8 +91,7 @@ def load_instance(self, sample: Any) -> Any:
     def __iter__(self) -> "AsynchronousLoader":
         # We don't want to run the thread more than once
         # Start a new thread if we are at the beginning of a new epoch, and our current worker is dead
-        if (not hasattr(self, 'worker')
-            or not self.worker.is_alive()) and self.queue.empty() and self.idx == 0:  # type: ignore[has-type]
+        if (not hasattr(self, 'worker') or not self.worker.is_alive()) and self.queue.empty() and self.idx == 0:  # type: ignore[has-type] # noqa: E501
             self.worker = Thread(target=self.load_loop)
             self.worker.daemon = True
             self.worker.start()

From e847eadf2c47ea9560019d4782f14e1a0e087c00 Mon Sep 17 00:00:00 2001
From: Akihiro Nitta <nitta@akihironitta.com>
Date: Wed, 20 Jan 2021 20:53:53 +0900
Subject: [PATCH 71/75] yapf...

---
 pl_bolts/datamodules/async_dataloader.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/pl_bolts/datamodules/async_dataloader.py b/pl_bolts/datamodules/async_dataloader.py
index 6034c29d3e..410425e3cb 100644
--- a/pl_bolts/datamodules/async_dataloader.py
+++ b/pl_bolts/datamodules/async_dataloader.py
@@ -91,7 +91,8 @@ def load_instance(self, sample: Any) -> Any:
     def __iter__(self) -> "AsynchronousLoader":
         # We don't want to run the thread more than once
         # Start a new thread if we are at the beginning of a new epoch, and our current worker is dead
-        if (not hasattr(self, 'worker') or not self.worker.is_alive()) and self.queue.empty() and self.idx == 0:  # type: ignore[has-type] # noqa: E501
+        if (not hasattr(self, 'worker') or
+            not self.worker.is_alive()) and self.queue.empty() and self.idx == 0:  # type: ignore[has-type] # noqa: E501
             self.worker = Thread(target=self.load_loop)
             self.worker.daemon = True
             self.worker.start()

From 1839438d6f194e8ec3cab5a13d9e27f3b69a8e4b Mon Sep 17 00:00:00 2001
From: Akihiro Nitta <nitta@akihironitta.com>
Date: Wed, 20 Jan 2021 21:05:26 +0900
Subject: [PATCH 72/75] Disable all of yapf, flake8, and mypy

---
 pl_bolts/datamodules/async_dataloader.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/pl_bolts/datamodules/async_dataloader.py b/pl_bolts/datamodules/async_dataloader.py
index 410425e3cb..24fa820d67 100644
--- a/pl_bolts/datamodules/async_dataloader.py
+++ b/pl_bolts/datamodules/async_dataloader.py
@@ -91,9 +91,11 @@ def load_instance(self, sample: Any) -> Any:
     def __iter__(self) -> "AsynchronousLoader":
         # We don't want to run the thread more than once
         # Start a new thread if we are at the beginning of a new epoch, and our current worker is dead
-        if (not hasattr(self, 'worker') or
-            not self.worker.is_alive()) and self.queue.empty() and self.idx == 0:  # type: ignore[has-type] # noqa: E501
+
+        # yapf: disable
+        if (not hasattr(self, 'worker') or not self.worker.is_alive()) and self.queue.empty() and self.idx == 0:  # type: ignore[has-type] # noqa: E501
             self.worker = Thread(target=self.load_loop)
+            # yapf: enable
             self.worker.daemon = True
             self.worker.start()
         return self
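
When a single long condition has to get past yapf, flake8, and mypy at once, the patch simply stacks the suppressions: `# yapf: disable` / `# yapf: enable` around the statement, a trailing `# type: ignore[...]`, and `# noqa: E501` for the line length. A condensed sketch of the same stacking (the ignore code mirrors the one in the patch and is only meaningful in its original context):

    from queue import Queue
    from threading import Thread

    queue: Queue = Queue()
    worker = Thread(target=lambda: None)
    idx = 0

    # yapf: disable
    if (not hasattr(worker, "started_once") or not worker.is_alive()) and queue.empty() and idx == 0:  # type: ignore[has-type] # noqa: E501
        worker.start()
    # yapf: enable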

From 097df6d3e9a8b7e4a286913066dcdc203330ef71 Mon Sep 17 00:00:00 2001
From: Akihiro Nitta <nitta@akihironitta.com>
Date: Wed, 20 Jan 2021 21:12:04 +0900
Subject: [PATCH 73/75] Use Callable

---
 pl_bolts/datamodules/vocdetection_datamodule.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py
index f70806a7a7..97b63cc86e 100644
--- a/pl_bolts/datamodules/vocdetection_datamodule.py
+++ b/pl_bolts/datamodules/vocdetection_datamodule.py
@@ -199,7 +199,7 @@ def val_dataloader(self, batch_size: int = 1, image_transforms: Optional[List[Ca
         )
         return loader
 
-    def _default_transforms(self) -> transform_lib.Compose:
+    def _default_transforms(self) -> Callable:
         if self.normalize:
             voc_transforms = transform_lib.Compose([
                 transform_lib.ToTensor(),
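
Returning `Callable` instead of `transform_lib.Compose` keeps the annotation valid even when torchvision is absent and the old `Compose = object` fallback would otherwise have been the declared type; it also covers the branches that return a bare lambda. A sketch of the guarded-import pattern, written as a hypothetical free function rather than the real datamodule method:

    from typing import Callable

    try:  # optional dependency, mirroring the guarded torchvision import
        from torchvision import transforms as transform_lib
        _TORCHVISION_AVAILABLE = True
    except ImportError:  # pragma: no cover
        _TORCHVISION_AVAILABLE = False

    def default_transforms(normalize: bool = False) -> Callable:
        # every branch returns something callable (Compose, ToTensor, or a
        # lambda), so the annotation no longer depends on the optional import
        if not _TORCHVISION_AVAILABLE:
            return lambda x: x
        if normalize:
            return transform_lib.Compose([
                transform_lib.ToTensor(),
                transform_lib.Normalize(mean=(0.5, ), std=(0.5, )),
            ])
        return transform_lib.ToTensor()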

From 9e00c5dfd366d821ed600795c4f57e3fe9eda996 Mon Sep 17 00:00:00 2001
From: Akihiro Nitta <nitta@akihironitta.com>
Date: Wed, 20 Jan 2021 21:41:13 +0900
Subject: [PATCH 74/75] Use Callable

---
 pl_bolts/datamodules/binary_mnist_datamodule.py  |  6 ++----
 pl_bolts/datamodules/cifar10_datamodule.py       |  6 ++----
 pl_bolts/datamodules/cityscapes_datamodule.py    |  8 +++-----
 pl_bolts/datamodules/fashion_mnist_datamodule.py |  6 ++----
 pl_bolts/datamodules/imagenet_datamodule.py      |  6 ++----
 pl_bolts/datamodules/kitti_datamodule.py         |  6 ++----
 pl_bolts/datamodules/mnist_datamodule.py         |  6 ++----
 pl_bolts/datamodules/ssl_imagenet_datamodule.py  |  6 ++----
 pl_bolts/datamodules/stl10_datamodule.py         |  4 +---
 pl_bolts/datamodules/vision_datamodule.py        | 11 ++---------
 10 files changed, 20 insertions(+), 45 deletions(-)

diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py
index 7b1b963f49..c43065984b 100644
--- a/pl_bolts/datamodules/binary_mnist_datamodule.py
+++ b/pl_bolts/datamodules/binary_mnist_datamodule.py
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Union
+from typing import Any, Callable, Optional, Union
 
 from pl_bolts.datamodules.vision_datamodule import VisionDataModule
 from pl_bolts.datasets.mnist_dataset import BinaryMNIST
@@ -7,10 +7,8 @@
 
 if _TORCHVISION_AVAILABLE:
     from torchvision import transforms as transform_lib
-    from torchvision.transforms import Compose
 else:  # pragma: no cover
     warn_missing_pkg('torchvision')
-    Compose = object
 
 
 class BinaryMNISTDataModule(VisionDataModule):
@@ -100,7 +98,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self) -> Compose:
+    def default_transforms(self) -> Callable:
         if self.normalize:
             mnist_transforms = transform_lib.Compose([
                 transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5, ), std=(0.5, ))
diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py
index 1c17658a30..e54eb37deb 100644
--- a/pl_bolts/datamodules/cifar10_datamodule.py
+++ b/pl_bolts/datamodules/cifar10_datamodule.py
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Sequence, Union
+from typing import Any, Callable, Optional, Sequence, Union
 
 from pl_bolts.datamodules.vision_datamodule import VisionDataModule
 from pl_bolts.datasets.cifar10_dataset import TrialCIFAR10
@@ -9,11 +9,9 @@
 if _TORCHVISION_AVAILABLE:
     from torchvision import transforms as transform_lib
     from torchvision.datasets import CIFAR10
-    from torchvision.transforms import Compose
 else:  # pragma: no cover
     warn_missing_pkg('torchvision')
     CIFAR10 = None
-    Compose = object
 
 
 class CIFAR10DataModule(VisionDataModule):
@@ -114,7 +112,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self) -> Compose:
+    def default_transforms(self) -> Callable:
         if self.normalize:
             cf10_transforms = transform_lib.Compose([transform_lib.ToTensor(), cifar10_normalization()])
         else:
diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py
index 462d4ca982..3f1c223baf 100644
--- a/pl_bolts/datamodules/cityscapes_datamodule.py
+++ b/pl_bolts/datamodules/cityscapes_datamodule.py
@@ -1,5 +1,5 @@
 # type: ignore[override]
-from typing import Any
+from typing import Any, Callable
 
 from pytorch_lightning import LightningDataModule
 from torch.utils.data import DataLoader
@@ -10,10 +10,8 @@
 if _TORCHVISION_AVAILABLE:
     from torchvision import transforms as transform_lib
     from torchvision.datasets import Cityscapes
-    from torchvision.transforms import Compose
 else:  # pragma: no cover
     warn_missing_pkg('torchvision')
-    Compose = object
 
 
 class CityscapesDataModule(LightningDataModule):
@@ -201,7 +199,7 @@ def test_dataloader(self) -> DataLoader:
         )
         return loader
 
-    def _default_transforms(self) -> Compose:
+    def _default_transforms(self) -> Callable:
         cityscapes_transforms = transform_lib.Compose([
             transform_lib.ToTensor(),
             transform_lib.Normalize(
@@ -210,7 +208,7 @@ def _default_transforms(self) -> Compose:
         ])
         return cityscapes_transforms
 
-    def _default_target_transforms(self) -> Compose:
+    def _default_target_transforms(self) -> Callable:
         cityscapes_target_trasnforms = transform_lib.Compose([
             transform_lib.ToTensor(), transform_lib.Lambda(lambda t: t.squeeze())
         ])
diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py
index b209e09f37..f945e00912 100644
--- a/pl_bolts/datamodules/fashion_mnist_datamodule.py
+++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Union
+from typing import Any, Callable, Optional, Union
 
 from pl_bolts.datamodules.vision_datamodule import VisionDataModule
 from pl_bolts.utils import _TORCHVISION_AVAILABLE
@@ -7,11 +7,9 @@
 if _TORCHVISION_AVAILABLE:
     from torchvision import transforms as transform_lib
     from torchvision.datasets import FashionMNIST
-    from torchvision.transforms import Compose
 else:  # pragma: no cover
     warn_missing_pkg('torchvision')
     FashionMNIST = None
-    Compose = object
 
 
 class FashionMNISTDataModule(VisionDataModule):
@@ -100,7 +98,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self) -> Compose:
+    def default_transforms(self) -> Callable:
         if self.normalize:
             mnist_transforms = transform_lib.Compose([
                 transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5, ), std=(0.5, ))
diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py
index ab8e8b9921..066aa3cd0a 100644
--- a/pl_bolts/datamodules/imagenet_datamodule.py
+++ b/pl_bolts/datamodules/imagenet_datamodule.py
@@ -12,10 +12,8 @@
 
 if _TORCHVISION_AVAILABLE:
     from torchvision import transforms as transform_lib
-    from torchvision.transforms import Compose
 else:  # pragma: no cover
     warn_missing_pkg('torchvision')
-    Compose = object
 
 
 class ImagenetDataModule(LightningDataModule):
@@ -215,7 +213,7 @@ def test_dataloader(self) -> DataLoader:
         )
         return loader
 
-    def train_transform(self) -> Compose:
+    def train_transform(self) -> Callable:
         """
         The standard imagenet transforms
 
@@ -241,7 +239,7 @@ def train_transform(self) -> Compose:
 
         return preprocessing
 
-    def val_transform(self) -> Compose:
+    def val_transform(self) -> Callable:
         """
         The standard imagenet transforms for validation
 
diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py
index 893f224047..cd6d198185 100644
--- a/pl_bolts/datamodules/kitti_datamodule.py
+++ b/pl_bolts/datamodules/kitti_datamodule.py
@@ -1,6 +1,6 @@
 # type: ignore[override]
 import os
-from typing import Any, Optional
+from typing import Any, Callable, Optional
 
 import torch
 from pytorch_lightning import LightningDataModule
@@ -13,10 +13,8 @@
 
 if _TORCHVISION_AVAILABLE:
     from torchvision import transforms as transforms
-    from torchvision.transforms import Compose
 else:  # pragma: no cover
     warn_missing_pkg('torchvision')
-    Compose = object
 
 
 class KittiDataModule(LightningDataModule):
@@ -129,7 +127,7 @@ def test_dataloader(self) -> DataLoader:
         )
         return loader
 
-    def _default_transforms(self) -> Compose:
+    def _default_transforms(self) -> Callable:
         kitti_transforms = transforms.Compose([
             transforms.ToTensor(),
             transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753], std=[0.32064945, 0.32098866, 0.32325324])
diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py
index c813cee685..0889d71d09 100644
--- a/pl_bolts/datamodules/mnist_datamodule.py
+++ b/pl_bolts/datamodules/mnist_datamodule.py
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Union
+from typing import Any, Callable, Optional, Union
 
 from pl_bolts.datamodules.vision_datamodule import VisionDataModule
 from pl_bolts.utils import _TORCHVISION_AVAILABLE
@@ -7,11 +7,9 @@
 if _TORCHVISION_AVAILABLE:
     from torchvision import transforms as transform_lib
     from torchvision.datasets import MNIST
-    from torchvision.transforms import Compose
 else:  # pragma: no cover
     warn_missing_pkg('torchvision')
     MNIST = None
-    Compose = object
 
 
 class MNISTDataModule(VisionDataModule):
@@ -99,7 +97,7 @@ def num_classes(self) -> int:
         """
         return 10
 
-    def default_transforms(self) -> Compose:
+    def default_transforms(self) -> Callable:
         if self.normalize:
             mnist_transforms = transform_lib.Compose([
                 transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5, ), std=(0.5, ))
diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py
index c21d1af46b..fc14dd2cae 100644
--- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py
+++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py
@@ -1,6 +1,6 @@
 # type: ignore[override]
 import os
-from typing import Any, Optional
+from typing import Any, Callable, Optional
 
 from pytorch_lightning import LightningDataModule
 from torch.utils.data import DataLoader
@@ -12,10 +12,8 @@
 
 if _TORCHVISION_AVAILABLE:
     from torchvision import transforms as transform_lib
-    from torchvision.transforms import Compose
 else:  # pragma: no cover
     warn_missing_pkg('torchvision')
-    Compose = object
 
 
 class SSLImagenetDataModule(LightningDataModule):  # pragma: no cover
@@ -147,6 +145,6 @@ def test_dataloader(self, num_images_per_class: int, add_normalize: bool = False
         )
         return loader
 
-    def _default_transforms(self) -> Compose:
+    def _default_transforms(self) -> Callable:
         imagenet_transforms = transform_lib.Compose([transform_lib.ToTensor(), imagenet_normalization()])
         return imagenet_transforms
diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py
index 0433278af6..f7b9f9963f 100644
--- a/pl_bolts/datamodules/stl10_datamodule.py
+++ b/pl_bolts/datamodules/stl10_datamodule.py
@@ -14,10 +14,8 @@
 if _TORCHVISION_AVAILABLE:
     from torchvision import transforms as transform_lib
     from torchvision.datasets import STL10
-    from torchvision.transforms import Compose
 else:  # pragma: no cover
     warn_missing_pkg('torchvision')
-    Compose = object
 
 
 class STL10DataModule(LightningDataModule):  # pragma: no cover
@@ -301,6 +299,6 @@ def val_dataloader_labeled(self) -> DataLoader:
         )
         return loader
 
-    def _default_transforms(self) -> Compose:
+    def _default_transforms(self) -> Callable:
         data_transforms = transform_lib.Compose([transform_lib.ToTensor(), stl10_normalization()])
         return data_transforms
diff --git a/pl_bolts/datamodules/vision_datamodule.py b/pl_bolts/datamodules/vision_datamodule.py
index 73ba424c96..d15d9c59d4 100644
--- a/pl_bolts/datamodules/vision_datamodule.py
+++ b/pl_bolts/datamodules/vision_datamodule.py
@@ -1,18 +1,11 @@
 import os
 from abc import abstractmethod
-from typing import Any, List, Optional, Union
+from typing import Any, Callable, List, Optional, Union
 
 import torch
 from pytorch_lightning import LightningDataModule
 from torch.utils.data import DataLoader, Dataset, random_split
 
-from pl_bolts.utils import _TORCHVISION_AVAILABLE
-
-if _TORCHVISION_AVAILABLE:
-    from torchvision.transforms import Compose
-else:
-    Compose = object
-
 
 class VisionDataModule(LightningDataModule):
 
@@ -120,7 +113,7 @@ def _get_splits(self, len_dataset: int) -> List[int]:
         return splits
 
     @abstractmethod
-    def default_transforms(self) -> Compose:
+    def default_transforms(self) -> Callable:
         """ Default transform for the dataset """
 
     def train_dataloader(self, *args: Any, **kwargs: Any) -> DataLoader:
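
The change that recurs across these files is worth spelling out: annotating the transform factories as `typing.Callable` instead of `torchvision.transforms.Compose` means the signatures no longer reference a torchvision class, so the `Compose = object` fallback kept purely for typing can be deleted along with its conditional import. A minimal standalone sketch of the resulting pattern (illustrative only, not the pl_bolts code itself; the names mirror the ones in these diffs):

    from typing import Callable

    try:
        from torchvision import transforms as transform_lib
        _TORCHVISION_AVAILABLE = True
    except ImportError:  # pragma: no cover
        _TORCHVISION_AVAILABLE = False


    def default_transforms() -> Callable:
        # `Callable` keeps the annotation importable even when torchvision is
        # missing; only the body needs the optional dependency.
        if not _TORCHVISION_AVAILABLE:
            raise ModuleNotFoundError('You want to use `torchvision` which is not installed yet.')
        return transform_lib.Compose([transform_lib.ToTensor()])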

From af563a57d64c76113c2b4daff02759422f62d90c Mon Sep 17 00:00:00 2001
From: Akihiro Nitta <nitta@akihironitta.com>
Date: Wed, 20 Jan 2021 21:44:38 +0900
Subject: [PATCH 75/75] Add missing import

---
 pl_bolts/datamodules/imagenet_datamodule.py | 2 +-
 pl_bolts/datamodules/stl10_datamodule.py    | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py
index 066aa3cd0a..b9cd811335 100644
--- a/pl_bolts/datamodules/imagenet_datamodule.py
+++ b/pl_bolts/datamodules/imagenet_datamodule.py
@@ -1,6 +1,6 @@
 # type: ignore[override]
 import os
-from typing import Any, Optional
+from typing import Any, Callable, Optional
 
 from pytorch_lightning import LightningDataModule
 from torch.utils.data import DataLoader
diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py
index f7b9f9963f..43ff3ebb6a 100644
--- a/pl_bolts/datamodules/stl10_datamodule.py
+++ b/pl_bolts/datamodules/stl10_datamodule.py
@@ -1,6 +1,6 @@
 # type: ignore[override]
 import os
-from typing import Any, Optional
+from typing import Any, Callable, Optional
 
 import torch
 from pytorch_lightning import LightningDataModule
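
The final patch only adds the `Callable` import that the earlier annotation changes started relying on in `imagenet_datamodule.py` and `stl10_datamodule.py`. Without it, importing those modules raises `NameError`, because return annotations such as `-> Callable` are evaluated when the method is defined, i.e. at import time (the hunks above show no `from __future__ import annotations` at the top of either file). A quick, purely illustrative way to catch this class of omission (not part of the patch or the project's test suite):

    # Importing the touched modules is enough to surface a missing name in an
    # annotation; this assumes pl_bolts is installed in the environment.
    import importlib

    for name in (
        'pl_bolts.datamodules.imagenet_datamodule',
        'pl_bolts.datamodules.stl10_datamodule',
    ):
        importlib.import_module(name)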