From 30f2efd27550aa747ef0224944ec887f6ed4b65c Mon Sep 17 00:00:00 2001
From: haekula
Date: Wed, 16 Dec 2020 14:29:47 +0530
Subject: [PATCH 1/5] Add Type Annotations to pl_bolts/utils/*

---
 pl_bolts/utils/__init__.py           |  2 +-
 pl_bolts/utils/arguments.py          | 16 ++++++++--------
 pl_bolts/utils/pretrained_weights.py |  4 +++-
 pl_bolts/utils/self_supervised.py    |  8 +++++++-
 pl_bolts/utils/semi_supervised.py    | 20 +++++++++++---------
 pl_bolts/utils/shaping.py            |  3 ++-
 pl_bolts/utils/warnings.py           |  8 +++++++-
 setup.cfg                            |  3 ---
 8 files changed, 39 insertions(+), 25 deletions(-)

diff --git a/pl_bolts/utils/__init__.py b/pl_bolts/utils/__init__.py
index 7bfaeb94d3..3271fc0cc4 100644
--- a/pl_bolts/utils/__init__.py
+++ b/pl_bolts/utils/__init__.py
@@ -1,7 +1,7 @@
 import torch
 from pytorch_lightning.utilities import _module_available
 
-_NATIVE_AMP_AVAILABLE = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast")
+_NATIVE_AMP_AVAILABLE: bool = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast")
 
 _TORCHVISION_AVAILABLE = _module_available("torchvision")
 _GYM_AVAILABLE = _module_available("gym")
diff --git a/pl_bolts/utils/arguments.py b/pl_bolts/utils/arguments.py
index fa0c7e9ec1..decfcef10c 100644
--- a/pl_bolts/utils/arguments.py
+++ b/pl_bolts/utils/arguments.py
@@ -1,7 +1,7 @@
 import inspect
 from argparse import ArgumentParser, Namespace
 from dataclasses import dataclass
-from typing import Any, Optional
+from typing import Any, Dict, List, Optional
 
 import pytorch_lightning as pl
 
@@ -32,7 +32,7 @@ class LightningArgumentParser(ArgumentParser):
         # args.data -> data args
         # args.model -> model args
     """
-    def __init__(self, *args, ignore_required_init_args=True, **kwargs):
+    def __init__(self, *args: Any, ignore_required_init_args: bool = True, **kwargs: Any):
         """
         Args:
             ignore_required_init_args (bool, optional): Whether to include positional args when adding
@@ -41,10 +41,10 @@ def __init__(self, *args, ignore_required_init_args=True, **kwargs):
         super().__init__(*args, **kwargs)
         self.ignore_required_init_args = ignore_required_init_args
 
-        self._default_obj_args = dict()
-        self._added_arg_names = []
+        self._default_obj_args: Dict[str, List[LitArg]] = dict()
+        self._added_arg_names: List[str] = []
 
-    def add_object_args(self, name, obj):
+    def add_object_args(self, name: str, obj: Any) -> None:
         default_args = gather_lit_args(obj)
         self._default_obj_args[name] = default_args
         for arg in default_args:
@@ -58,7 +58,7 @@ def add_object_args(self, name, obj):
             kwargs["default"] = arg.default
         self.add_argument(f"--{arg.name}", **kwargs)
 
-    def parse_lit_args(self, *args, **kwargs):
+    def parse_lit_args(self, *args: Any, **kwargs: Any) -> Namespace:
         parsed_args_dict = vars(self.parse_args(*args, **kwargs))
         lit_args = Namespace()
         for name, default_args in self._default_obj_args.items():
@@ -72,7 +72,7 @@ def parse_lit_args(self, *args, **kwargs):
         return lit_args
 
 
-def gather_lit_args(cls, root_cls=None):
+def gather_lit_args(cls: Any, root_cls: Optional[Any] = None) -> List[LitArg]:
 
     if root_cls is None:
         if issubclass(cls, pl.LightningModule):
@@ -83,7 +83,7 @@ def gather_lit_args(cls, root_cls=None):
             root_cls = cls
 
     blacklisted_args = ["self", "args", "kwargs"]
-    arguments = []
+    arguments: List[LitArg] = []
     argument_names = []
 
     for obj in inspect.getmro(cls):
diff --git a/pl_bolts/utils/pretrained_weights.py b/pl_bolts/utils/pretrained_weights.py
index 8abfdb18f5..ce69c004ec 100644
--- a/pl_bolts/utils/pretrained_weights.py
+++ b/pl_bolts/utils/pretrained_weights.py
@@ -1,4 +1,6 @@
+from typing import Optional
+from pytorch_lightning import LightningModule
 
 vae_imagenet2012 = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/' \
                    'vae/imagenet_06_22_2019/checkpoints/epoch%3D63.ckpt'
 
@@ -11,7 +13,7 @@
 }
 
 
-def load_pretrained(model, class_name=None):  # pragma: no-cover
+def load_pretrained(model: LightningModule, class_name: Optional[str] = None) -> None:  # pragma: no-cover
     if class_name is None:
         class_name = model.__class__.__name__
     ckpt_url = urls[class_name]
diff --git a/pl_bolts/utils/self_supervised.py b/pl_bolts/utils/self_supervised.py
index b1d72e38c5..5f412594bc 100644
--- a/pl_bolts/utils/self_supervised.py
+++ b/pl_bolts/utils/self_supervised.py
@@ -1,7 +1,13 @@
+from torch.nn import Module
+
 from pl_bolts.utils.semi_supervised import Identity
 
 
-def torchvision_ssl_encoder(name, pretrained=False, return_all_feature_maps=False):
+def torchvision_ssl_encoder(
+    name: str,
+    pretrained: bool = False,
+    return_all_feature_maps: bool = False,
+) -> Module:
     from pl_bolts.models.self_supervised import resnets
 
     pretrained_model = getattr(resnets, name)(pretrained=pretrained, return_all_feature_maps=return_all_feature_maps)
diff --git a/pl_bolts/utils/semi_supervised.py b/pl_bolts/utils/semi_supervised.py
index 8363fa44b5..bdf5d928af 100644
--- a/pl_bolts/utils/semi_supervised.py
+++ b/pl_bolts/utils/semi_supervised.py
@@ -1,7 +1,9 @@
 import math
+from typing import Any, List, Tuple
 
 import numpy as np
 import torch
+from torch import Tensor
 
 from pl_bolts.utils import _SKLEARN_AVAILABLE
 from pl_bolts.utils.warnings import warn_missing_pkg
@@ -24,14 +26,14 @@ class Identity(torch.nn.Module):
         model.fc = Identity()
     """
 
-    def __init__(self):
+    def __init__(self) -> None:
         super(Identity, self).__init__()
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         return x
 
 
-def balance_classes(X: np.ndarray, Y: list, batch_size: int):
+def balance_classes(X: np.ndarray, Y: list, batch_size: int) -> Tuple[np.ndarray, np.ndarray]:
     """
     Makes sure each batch has an equal amount of data from each class.
     Perfect balance
@@ -51,19 +53,19 @@ def balance_classes(X: np.ndarray, Y: list, batch_size: int):
     nb_batches = math.ceil(len(Y) / batch_size)
 
     # sort by classes
-    final_batches_x = [[] for i in range(nb_batches)]
-    final_batches_y = [[] for i in range(nb_batches)]
+    final_batches_x: List[Any] = [[] for i in range(nb_batches)]
+    final_batches_y: List[Any] = [[] for i in range(nb_batches)]
 
     # Y needs to be np arr
     Y = np.asarray(Y)
 
     # pick chunk size for each class using the largest split
-    chunk_size = []
+    chunk_size_class = []
     for class_i in range(nb_classes):
         mask = Y == class_i
         y = Y[mask]
-        chunk_size.append(math.ceil(len(y) / nb_batches))
-    chunk_size = max(chunk_size)
+        chunk_size_class.append(math.ceil(len(y) / nb_batches))
+    chunk_size = max(chunk_size_class)
     # force chunk size to be even
     if chunk_size % 2 != 0:
         chunk_size -= 1
@@ -102,7 +104,7 @@ def generate_half_labeled_batches(
     larger_set_X: np.ndarray,
     larger_set_Y: np.ndarray,
     batch_size: int,
-):
+) -> Tuple[np.ndarray, np.ndarray]:
     """
     Given a labeled dataset and an unlabeled dataset, this function generates
     a joint pair where half the batches are labeled and the other half is not
diff --git a/pl_bolts/utils/shaping.py b/pl_bolts/utils/shaping.py
index fe06991cd0..abf395467d 100644
--- a/pl_bolts/utils/shaping.py
+++ b/pl_bolts/utils/shaping.py
@@ -1,8 +1,9 @@
 import numpy as np
 import torch
+from torch import Tensor
 
 
-def tile(a, dim, n_tile):
+def tile(a: Tensor, dim: int, n_tile: int) -> Tensor:
     init_dim = a.size(dim)
     repeat_idx = [1] * a.dim()
     repeat_idx[dim] = n_tile
diff --git a/pl_bolts/utils/warnings.py b/pl_bolts/utils/warnings.py
index 9d9dabbc25..c87ab1cc69 100644
--- a/pl_bolts/utils/warnings.py
+++ b/pl_bolts/utils/warnings.py
@@ -1,10 +1,16 @@
 import os
 import warnings
+from typing import Any, Callable
 
 MISSING_PACKAGE_WARNINGS = {}
 
 
-def warn_missing_pkg(pkg_name: str, pypi_name: str = None, extra_text: str = None, stdout_func=warnings.warn):
+def warn_missing_pkg(
+    pkg_name: str,
+    pypi_name: str = None,
+    extra_text: str = None,
+    stdout_func: Callable = warnings.warn,
+) -> int:
     """
     Template for warning on missing packages, show them just once.
 
diff --git a/setup.cfg b/setup.cfg
index 1661aad951..95850c7bdb 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -115,8 +115,5 @@ ignore_errors = True
 [mypy-pl_bolts.transforms.*]
 ignore_errors = True
 
-[mypy-pl_bolts.utils.*]
-ignore_errors = True
-
 [mypy-tests.*]
 ignore_errors = True

From 301cf5b05459716759eb3ab299e15f36ccf3dc7a Mon Sep 17 00:00:00 2001
From: hassiahk
Date: Wed, 16 Dec 2020 18:50:27 +0530
Subject: [PATCH 2/5] Fixed Requested Changes

---
 pl_bolts/utils/arguments.py       |  6 +++---
 pl_bolts/utils/semi_supervised.py | 18 +++++++++---------
 pl_bolts/utils/warnings.py        |  6 +++---
 3 files changed, 15 insertions(+), 15 deletions(-)

diff --git a/pl_bolts/utils/arguments.py b/pl_bolts/utils/arguments.py
index decfcef10c..c8bac77bfb 100644
--- a/pl_bolts/utils/arguments.py
+++ b/pl_bolts/utils/arguments.py
@@ -32,11 +32,11 @@ class LightningArgumentParser(ArgumentParser):
         # args.data -> data args
         # args.model -> model args
     """
-    def __init__(self, *args: Any, ignore_required_init_args: bool = True, **kwargs: Any):
+    def __init__(self, *args: Any, ignore_required_init_args: bool = True, **kwargs: Any) -> None:
         """
         Args:
-            ignore_required_init_args (bool, optional): Whether to include positional args when adding
-                object args. Defaults to True.
+            ignore_required_init_args: Whether to include positional args when adding
+                object args. Defaults to ``True``.
""" super().__init__(*args, **kwargs) self.ignore_required_init_args = ignore_required_init_args diff --git a/pl_bolts/utils/semi_supervised.py b/pl_bolts/utils/semi_supervised.py index bdf5d928af..0f119adc8b 100644 --- a/pl_bolts/utils/semi_supervised.py +++ b/pl_bolts/utils/semi_supervised.py @@ -1,5 +1,5 @@ import math -from typing import Any, List, Tuple +from typing import Any, List, Sequence, Tuple import numpy as np import torch @@ -33,14 +33,14 @@ def forward(self, x: Tensor) -> Tensor: return x -def balance_classes(X: np.ndarray, Y: list, batch_size: int) -> Tuple[np.ndarray, np.ndarray]: +def balance_classes(X: np.ndarray, labels: Sequence[int], batch_size: int) -> Tuple[np.ndarray, np.ndarray]: """ Makes sure each batch has an equal amount of data from each class. Perfect balance Args: X: input features - Y: mixed labels (ints) + labels: mixed labels (ints) batch_size: the ultimate batch size """ if not _SKLEARN_AVAILABLE: @@ -48,24 +48,24 @@ def balance_classes(X: np.ndarray, Y: list, batch_size: int) -> Tuple[np.ndarray 'You want to use `shuffle` function from `scikit-learn` which is not installed yet.' ) - nb_classes = len(set(Y)) + nb_classes = len(set(labels)) - nb_batches = math.ceil(len(Y) / batch_size) + nb_batches = math.ceil(len(labels) / batch_size) # sort by classes final_batches_x: List[Any] = [[] for i in range(nb_batches)] final_batches_y: List[Any] = [[] for i in range(nb_batches)] # Y needs to be np arr - Y = np.asarray(Y) + Y = np.asarray(labels) # pick chunk size for each class using the largest split - chunk_size_class = [] + chunk_sizes = [] for class_i in range(nb_classes): mask = Y == class_i y = Y[mask] - chunk_size_class.append(math.ceil(len(y) / nb_batches)) - chunk_size = max(chunk_size_class) + chunk_sizes.append(math.ceil(len(y) / nb_batches)) + chunk_size = max(chunk_sizes) # force chunk size to be even if chunk_size % 2 != 0: chunk_size -= 1 diff --git a/pl_bolts/utils/warnings.py b/pl_bolts/utils/warnings.py index c87ab1cc69..baa31ddf54 100644 --- a/pl_bolts/utils/warnings.py +++ b/pl_bolts/utils/warnings.py @@ -1,14 +1,14 @@ import os import warnings -from typing import Any, Callable +from typing import Any, Callable, Optional MISSING_PACKAGE_WARNINGS = {} def warn_missing_pkg( pkg_name: str, - pypi_name: str = None, - extra_text: str = None, + pypi_name: Optional[str] = None, + extra_text: Optional[str] = None, stdout_func: Callable = warnings.warn, ) -> int: """ From da46851418ca784ef74230a2f214bcee38e41821 Mon Sep 17 00:00:00 2001 From: hassiahk Date: Wed, 16 Dec 2020 20:03:42 +0530 Subject: [PATCH 3/5] Made changes for resolving conflicts --- pl_bolts/utils/__init__.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pl_bolts/utils/__init__.py b/pl_bolts/utils/__init__.py index 3271fc0cc4..88aaa1b583 100644 --- a/pl_bolts/utils/__init__.py +++ b/pl_bolts/utils/__init__.py @@ -3,8 +3,8 @@ _NATIVE_AMP_AVAILABLE: bool = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast") -_TORCHVISION_AVAILABLE = _module_available("torchvision") -_GYM_AVAILABLE = _module_available("gym") -_SKLEARN_AVAILABLE = _module_available("sklearn") -_PIL_AVAILABLE = _module_available("PIL") -_OPENCV_AVAILABLE = _module_available("cv2") +_TORCHVISION_AVAILABLE: bool = _module_available("torchvision") +_GYM_AVAILABLE: bool = _module_available("gym") +_SKLEARN_AVAILABLE: bool = _module_available("sklearn") +_PIL_AVAILABLE: bool = _module_available("PIL") +_OPENCV_AVAILABLE: bool = _module_available("cv2") From 
de5ff43a59c39d1a11a60981cb870bdb32b35a7d Mon Sep 17 00:00:00 2001 From: hassiahk Date: Fri, 18 Dec 2020 12:14:51 +0530 Subject: [PATCH 4/5] Resolved Requested Changes --- pl_bolts/utils/arguments.py | 10 +++++----- pl_bolts/utils/semi_supervised.py | 16 ++++++++++------ 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/pl_bolts/utils/arguments.py b/pl_bolts/utils/arguments.py index c8bac77bfb..76d5659e7c 100644 --- a/pl_bolts/utils/arguments.py +++ b/pl_bolts/utils/arguments.py @@ -92,7 +92,7 @@ def gather_lit_args(cls: Any, root_cls: Optional[Any] = None) -> List[LitArg]: if issubclass(obj, root_cls): - default_params = inspect.signature(obj.__init__).parameters + default_params = inspect.signature(obj.__init__).parameters # type: ignore for arg in default_params: arg_type = default_params[arg].annotation @@ -104,7 +104,7 @@ def gather_lit_args(cls: Any, root_cls: Optional[Any] = None) -> List[LitArg]: arg_types = (arg_type,) # If type is empty, that means it hasn't been given type hint. We skip these. - arg_is_missing_type_hint = arg_types == (inspect._empty,) + arg_is_missing_type_hint = arg_types == (inspect.Parameter.empty,) # Some args should be ignored by default (self, kwargs, args) arg_is_in_blacklist = arg in blacklisted_args and arg_is_missing_type_hint # We only keep the first arg we see of a given name, as it overrides the parents @@ -113,9 +113,9 @@ def gather_lit_args(cls: Any, root_cls: Optional[Any] = None) -> List[LitArg]: do_skip_this_arg = arg_is_in_blacklist or arg_is_missing_type_hint or arg_is_duplicate # Positional args have no default, but do have a known type or types. - arg_is_positional = arg_default == inspect._empty and not arg_is_missing_type_hint + arg_is_positional = arg_default == inspect.Parameter.empty and not arg_is_missing_type_hint # Kwargs have both a default + known type or types - arg_is_kwarg = arg_default != inspect._empty and not arg_is_missing_type_hint + arg_is_kwarg = arg_default != inspect.Parameter.empty and not arg_is_missing_type_hint if do_skip_this_arg: continue @@ -124,7 +124,7 @@ def gather_lit_args(cls: Any, root_cls: Optional[Any] = None) -> List[LitArg]: lit_arg = LitArg( name=arg, types=arg_types, - default=arg_default if arg_default != inspect._empty else None, + default=arg_default if arg_default != inspect.Parameter.empty else None, required=arg_is_positional, context=obj.__name__, ) diff --git a/pl_bolts/utils/semi_supervised.py b/pl_bolts/utils/semi_supervised.py index 0f119adc8b..bfd14d1696 100644 --- a/pl_bolts/utils/semi_supervised.py +++ b/pl_bolts/utils/semi_supervised.py @@ -1,5 +1,5 @@ import math -from typing import Any, List, Sequence, Tuple +from typing import Any, List, Sequence, Tuple, Union import numpy as np import torch @@ -33,14 +33,18 @@ def forward(self, x: Tensor) -> Tensor: return x -def balance_classes(X: np.ndarray, labels: Sequence[int], batch_size: int) -> Tuple[np.ndarray, np.ndarray]: +def balance_classes( + X: Union[Tensor, np.ndarray], + Y: Union[Tensor, np.ndarray, Sequence[int]], + batch_size: int +) -> Tuple[np.ndarray, np.ndarray]: """ Makes sure each batch has an equal amount of data from each class. Perfect balance Args: X: input features - labels: mixed labels (ints) + Y: mixed labels (ints) batch_size: the ultimate batch size """ if not _SKLEARN_AVAILABLE: @@ -48,16 +52,16 @@ def balance_classes(X: np.ndarray, labels: Sequence[int], batch_size: int) -> Tu 'You want to use `shuffle` function from `scikit-learn` which is not installed yet.' 
         )
-    nb_classes = len(set(labels))
+    nb_classes = len(set(Y))
 
-    nb_batches = math.ceil(len(labels) / batch_size)
+    nb_batches = math.ceil(len(Y) / batch_size)
 
     # sort by classes
     final_batches_x: List[Any] = [[] for i in range(nb_batches)]
     final_batches_y: List[Any] = [[] for i in range(nb_batches)]
 
     # Y needs to be np arr
-    Y = np.asarray(labels)
+    Y = np.asarray(Y)
 
     # pick chunk size for each class using the largest split
     chunk_sizes = []

From 4510e1fe1cd3eece2ce82328bc72ef8ff2b4ca7f Mon Sep 17 00:00:00 2001
From: Jirka Borovec
Date: Sun, 20 Dec 2020 22:47:38 +0100
Subject: [PATCH 5/5] typing

---
 pl_bolts/utils/semi_supervised.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pl_bolts/utils/semi_supervised.py b/pl_bolts/utils/semi_supervised.py
index bfd14d1696..163723fe17 100644
--- a/pl_bolts/utils/semi_supervised.py
+++ b/pl_bolts/utils/semi_supervised.py
@@ -57,8 +57,8 @@ def balance_classes(
     nb_batches = math.ceil(len(Y) / batch_size)
 
     # sort by classes
-    final_batches_x: List[Any] = [[] for i in range(nb_batches)]
-    final_batches_y: List[Any] = [[] for i in range(nb_batches)]
+    final_batches_x: List[list] = [[] for i in range(nb_batches)]
+    final_batches_y: List[list] = [[] for i in range(nb_batches)]
 
     # Y needs to be np arr
     Y = np.asarray(Y)
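
The series above only adds annotations and renames locals, so a short smoke test is enough to see the typed APIs in action. The sketch below is illustrative only and is not part of the patches: `MyModel` and its `hidden_dim` argument are made-up placeholders, while `LightningArgumentParser`, `add_object_args`, `parse_lit_args`, and `balance_classes` are the real functions shown in the diffs (`balance_classes` additionally needs scikit-learn installed at runtime).

    # Hypothetical usage sketch -- assumes pl_bolts with these patches applied
    # and scikit-learn available; `MyModel` / `hidden_dim` are invented names.
    import numpy as np
    import pytorch_lightning as pl

    from pl_bolts.utils.arguments import LightningArgumentParser
    from pl_bolts.utils.semi_supervised import balance_classes


    class MyModel(pl.LightningModule):
        # init args need type hints: gather_lit_args() skips unannotated params
        def __init__(self, hidden_dim: int = 128):
            super().__init__()
            self.hidden_dim = hidden_dim


    parser = LightningArgumentParser(ignore_required_init_args=True)
    parser.add_object_args("model", MyModel)  # registers --hidden_dim
    args = parser.parse_lit_args(["--hidden_dim", "64"])
    print(args.model.hidden_dim)  # parsed args are grouped per object name

    # balance_classes takes ndarray features plus int labels (patch 4 widens
    # both to accept tensors) and returns class-balanced arrays, matching the
    # new Tuple[np.ndarray, np.ndarray] return hint.
    X = np.random.rand(100, 3)
    Y = np.random.randint(0, 5, size=100).tolist()
    balanced_x, balanced_y = balance_classes(X, Y, batch_size=10)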