diff --git a/pl_bolts/utils/__init__.py b/pl_bolts/utils/__init__.py
index 7bfaeb94d3..88aaa1b583 100644
--- a/pl_bolts/utils/__init__.py
+++ b/pl_bolts/utils/__init__.py
@@ -1,10 +1,10 @@
 import torch
 from pytorch_lightning.utilities import _module_available
 
-_NATIVE_AMP_AVAILABLE = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast")
+_NATIVE_AMP_AVAILABLE: bool = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast")
 
-_TORCHVISION_AVAILABLE = _module_available("torchvision")
-_GYM_AVAILABLE = _module_available("gym")
-_SKLEARN_AVAILABLE = _module_available("sklearn")
-_PIL_AVAILABLE = _module_available("PIL")
-_OPENCV_AVAILABLE = _module_available("cv2")
+_TORCHVISION_AVAILABLE: bool = _module_available("torchvision")
+_GYM_AVAILABLE: bool = _module_available("gym")
+_SKLEARN_AVAILABLE: bool = _module_available("sklearn")
+_PIL_AVAILABLE: bool = _module_available("PIL")
+_OPENCV_AVAILABLE: bool = _module_available("cv2")
diff --git a/pl_bolts/utils/arguments.py b/pl_bolts/utils/arguments.py
index fa0c7e9ec1..76d5659e7c 100644
--- a/pl_bolts/utils/arguments.py
+++ b/pl_bolts/utils/arguments.py
@@ -1,7 +1,7 @@
 import inspect
 from argparse import ArgumentParser, Namespace
 from dataclasses import dataclass
-from typing import Any, Optional
+from typing import Any, Dict, List, Optional
 
 import pytorch_lightning as pl
 
@@ -32,19 +32,19 @@ class LightningArgumentParser(ArgumentParser):
         # args.data -> data args
         # args.model -> model args
     """
-    def __init__(self, *args, ignore_required_init_args=True, **kwargs):
+    def __init__(self, *args: Any, ignore_required_init_args: bool = True, **kwargs: Any) -> None:
        """
         Args:
-            ignore_required_init_args (bool, optional): Whether to include positional args when adding
-                object args. Defaults to True.
+            ignore_required_init_args: Whether to include positional args when adding
+                object args. Defaults to ``True``.
         """
         super().__init__(*args, **kwargs)
         self.ignore_required_init_args = ignore_required_init_args
 
-        self._default_obj_args = dict()
-        self._added_arg_names = []
+        self._default_obj_args: Dict[str, List[LitArg]] = dict()
+        self._added_arg_names: List[str] = []
 
-    def add_object_args(self, name, obj):
+    def add_object_args(self, name: str, obj: Any) -> None:
         default_args = gather_lit_args(obj)
         self._default_obj_args[name] = default_args
         for arg in default_args:
@@ -58,7 +58,7 @@ def add_object_args(self, name, obj):
             kwargs["default"] = arg.default
             self.add_argument(f"--{arg.name}", **kwargs)
 
-    def parse_lit_args(self, *args, **kwargs):
+    def parse_lit_args(self, *args: Any, **kwargs: Any) -> Namespace:
         parsed_args_dict = vars(self.parse_args(*args, **kwargs))
         lit_args = Namespace()
         for name, default_args in self._default_obj_args.items():
@@ -72,7 +72,7 @@ def parse_lit_args(self, *args, **kwargs):
         return lit_args
 
 
-def gather_lit_args(cls, root_cls=None):
+def gather_lit_args(cls: Any, root_cls: Optional[Any] = None) -> List[LitArg]:
 
     if root_cls is None:
         if issubclass(cls, pl.LightningModule):
@@ -83,7 +83,7 @@ def gather_lit_args(cls, root_cls=None):
             root_cls = cls
 
     blacklisted_args = ["self", "args", "kwargs"]
-    arguments = []
+    arguments: List[LitArg] = []
    argument_names = []
 
     for obj in inspect.getmro(cls):
@@ -92,7 +92,7 @@ def gather_lit_args(cls, root_cls=None):
 
         if issubclass(obj, root_cls):
 
-            default_params = inspect.signature(obj.__init__).parameters
+            default_params = inspect.signature(obj.__init__).parameters  # type: ignore
 
             for arg in default_params:
                 arg_type = default_params[arg].annotation
@@ -104,7 +104,7 @@ def gather_lit_args(cls, root_cls=None):
                     arg_types = (arg_type,)
 
                 # If type is empty, that means it hasn't been given type hint. We skip these.
-                arg_is_missing_type_hint = arg_types == (inspect._empty,)
+                arg_is_missing_type_hint = arg_types == (inspect.Parameter.empty,)
                 # Some args should be ignored by default (self, kwargs, args)
                 arg_is_in_blacklist = arg in blacklisted_args and arg_is_missing_type_hint
                 # We only keep the first arg we see of a given name, as it overrides the parents
@@ -113,9 +113,9 @@ def gather_lit_args(cls, root_cls=None):
                 do_skip_this_arg = arg_is_in_blacklist or arg_is_missing_type_hint or arg_is_duplicate
 
                 # Positional args have no default, but do have a known type or types.
-                arg_is_positional = arg_default == inspect._empty and not arg_is_missing_type_hint
+                arg_is_positional = arg_default == inspect.Parameter.empty and not arg_is_missing_type_hint
                 # Kwargs have both a default + known type or types
-                arg_is_kwarg = arg_default != inspect._empty and not arg_is_missing_type_hint
+                arg_is_kwarg = arg_default != inspect.Parameter.empty and not arg_is_missing_type_hint
 
                 if do_skip_this_arg:
                     continue
@@ -124,7 +124,7 @@ def gather_lit_args(cls, root_cls=None):
                 lit_arg = LitArg(
                     name=arg,
                     types=arg_types,
-                    default=arg_default if arg_default != inspect._empty else None,
+                    default=arg_default if arg_default != inspect.Parameter.empty else None,
                     required=arg_is_positional,
                     context=obj.__name__,
                 )
diff --git a/pl_bolts/utils/pretrained_weights.py b/pl_bolts/utils/pretrained_weights.py
index 8abfdb18f5..ce69c004ec 100644
--- a/pl_bolts/utils/pretrained_weights.py
+++ b/pl_bolts/utils/pretrained_weights.py
@@ -1,4 +1,6 @@
+from typing import Optional
 
+from pytorch_lightning import LightningModule
 
 vae_imagenet2012 = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/' \
     'vae/imagenet_06_22_2019/checkpoints/epoch%3D63.ckpt'
@@ -11,7 +13,7 @@
 }
 
 
-def load_pretrained(model, class_name=None):  # pragma: no-cover
+def load_pretrained(model: LightningModule, class_name: Optional[str] = None) -> None:  # pragma: no-cover
     if class_name is None:
         class_name = model.__class__.__name__
     ckpt_url = urls[class_name]
diff --git a/pl_bolts/utils/self_supervised.py b/pl_bolts/utils/self_supervised.py
index b1d72e38c5..5f412594bc 100644
--- a/pl_bolts/utils/self_supervised.py
+++ b/pl_bolts/utils/self_supervised.py
@@ -1,7 +1,13 @@
+from torch.nn import Module
+
 from pl_bolts.utils.semi_supervised import Identity
 
 
-def torchvision_ssl_encoder(name, pretrained=False, return_all_feature_maps=False):
+def torchvision_ssl_encoder(
+    name: str,
+    pretrained: bool = False,
+    return_all_feature_maps: bool = False,
+) -> Module:
     from pl_bolts.models.self_supervised import resnets
 
     pretrained_model = getattr(resnets, name)(pretrained=pretrained, return_all_feature_maps=return_all_feature_maps)
diff --git a/pl_bolts/utils/semi_supervised.py b/pl_bolts/utils/semi_supervised.py
index 8363fa44b5..163723fe17 100644
--- a/pl_bolts/utils/semi_supervised.py
+++ b/pl_bolts/utils/semi_supervised.py
@@ -1,7 +1,9 @@
 import math
+from typing import Any, List, Sequence, Tuple, Union
 
 import numpy as np
 import torch
+from torch import Tensor
 
 from pl_bolts.utils import _SKLEARN_AVAILABLE
 from pl_bolts.utils.warnings import warn_missing_pkg
@@ -24,14 +26,18 @@ class Identity(torch.nn.Module):
         model.fc = Identity()
     """
 
-    def __init__(self):
+    def __init__(self) -> None:
         super(Identity, self).__init__()
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         return x
 
 
-def balance_classes(X: np.ndarray, Y: list, batch_size: int):
+def balance_classes(
+    X: Union[Tensor, np.ndarray],
+    Y: Union[Tensor, np.ndarray, Sequence[int]],
+    batch_size: int
+) -> Tuple[np.ndarray, np.ndarray]:
     """
     Makes sure each batch has an equal amount of data from each class.
     Perfect balance
@@ -51,19 +57,19 @@ def balance_classes(X: np.ndarray, Y: list, batch_size: int):
     nb_batches = math.ceil(len(Y) / batch_size)
 
     # sort by classes
-    final_batches_x = [[] for i in range(nb_batches)]
-    final_batches_y = [[] for i in range(nb_batches)]
+    final_batches_x: List[list] = [[] for i in range(nb_batches)]
+    final_batches_y: List[list] = [[] for i in range(nb_batches)]
 
     # Y needs to be np arr
     Y = np.asarray(Y)
 
     # pick chunk size for each class using the largest split
-    chunk_size = []
+    chunk_sizes = []
     for class_i in range(nb_classes):
         mask = Y == class_i
         y = Y[mask]
-        chunk_size.append(math.ceil(len(y) / nb_batches))
-    chunk_size = max(chunk_size)
+        chunk_sizes.append(math.ceil(len(y) / nb_batches))
+    chunk_size = max(chunk_sizes)
     # force chunk size to be even
     if chunk_size % 2 != 0:
         chunk_size -= 1
@@ -102,7 +108,7 @@ def generate_half_labeled_batches(
     larger_set_X: np.ndarray,
     larger_set_Y: np.ndarray,
     batch_size: int,
-):
+) -> Tuple[np.ndarray, np.ndarray]:
     """
     Given a labeled dataset and an unlabeled dataset, this function generates
     a joint pair where half the batches are labeled and the other half is not
diff --git a/pl_bolts/utils/shaping.py b/pl_bolts/utils/shaping.py
index fe06991cd0..abf395467d 100644
--- a/pl_bolts/utils/shaping.py
+++ b/pl_bolts/utils/shaping.py
@@ -1,8 +1,9 @@
 import numpy as np
 import torch
+from torch import Tensor
 
 
-def tile(a, dim, n_tile):
+def tile(a: Tensor, dim: int, n_tile: int) -> Tensor:
     init_dim = a.size(dim)
     repeat_idx = [1] * a.dim()
     repeat_idx[dim] = n_tile
diff --git a/pl_bolts/utils/warnings.py b/pl_bolts/utils/warnings.py
index 9d9dabbc25..baa31ddf54 100644
--- a/pl_bolts/utils/warnings.py
+++ b/pl_bolts/utils/warnings.py
@@ -1,10 +1,16 @@
 import os
 import warnings
+from typing import Any, Callable, Optional
 
 MISSING_PACKAGE_WARNINGS = {}
 
 
-def warn_missing_pkg(pkg_name: str, pypi_name: str = None, extra_text: str = None, stdout_func=warnings.warn):
+def warn_missing_pkg(
+    pkg_name: str,
+    pypi_name: Optional[str] = None,
+    extra_text: Optional[str] = None,
+    stdout_func: Callable = warnings.warn,
+) -> int:
     """
     Template for warning on missing packages, show them just once.
 
diff --git a/setup.cfg b/setup.cfg
index 1661aad951..95850c7bdb 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -115,8 +115,5 @@ ignore_errors = True
 [mypy-pl_bolts.transforms.*]
 ignore_errors = True
 
-[mypy-pl_bolts.utils.*]
-ignore_errors = True
-
 [mypy-tests.*]
 ignore_errors = True
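
Usage sketch (not part of the patch): the now-annotated LightningArgumentParser from pl_bolts.utils.arguments can be exercised as in its docstring example. MyDataModule and MyModel below are hypothetical placeholders for a LightningDataModule and a LightningModule whose __init__ arguments carry type hints, which gather_lit_args requires.

    from pl_bolts.utils.arguments import LightningArgumentParser

    parser = LightningArgumentParser(ignore_required_init_args=True)
    parser.add_object_args("data", MyDataModule)   # hypothetical LightningDataModule
    parser.add_object_args("model", MyModel)       # hypothetical LightningModule
    args = parser.parse_lit_args()
    # args.data -> Namespace of MyDataModule init args
    # args.model -> Namespace of MyModel init args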