Add Type Annotations to pl_bolts/utils/* #455

Merged · 5 commits · Dec 20, 2020
Changes from 3 commits
12 changes: 6 additions & 6 deletions pl_bolts/utils/__init__.py
@@ -1,10 +1,10 @@
 import torch
 from pytorch_lightning.utilities import _module_available

-_NATIVE_AMP_AVAILABLE = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast")
+_NATIVE_AMP_AVAILABLE: bool = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast")

-_TORCHVISION_AVAILABLE = _module_available("torchvision")
-_GYM_AVAILABLE = _module_available("gym")
-_SKLEARN_AVAILABLE = _module_available("sklearn")
-_PIL_AVAILABLE = _module_available("PIL")
-_OPENCV_AVAILABLE = _module_available("cv2")
+_TORCHVISION_AVAILABLE: bool = _module_available("torchvision")
+_GYM_AVAILABLE: bool = _module_available("gym")
+_SKLEARN_AVAILABLE: bool = _module_available("sklearn")
+_PIL_AVAILABLE: bool = _module_available("PIL")
+_OPENCV_AVAILABLE: bool = _module_available("cv2")
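
For context, these flags drive the optional-dependency pattern used throughout the repo. A minimal sketch of a guarded import, assuming a module that wants torchvision (the fallback mirrors `pl_bolts.utils.warnings.warn_missing_pkg` below):

    from pl_bolts.utils import _TORCHVISION_AVAILABLE
    from pl_bolts.utils.warnings import warn_missing_pkg

    if _TORCHVISION_AVAILABLE:
        from torchvision import transforms  # noqa: F401
    else:  # pragma: no-cover
        warn_missing_pkg('torchvision')  # warn once instead of failing at import time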
20 changes: 10 additions & 10 deletions pl_bolts/utils/arguments.py
@@ -1,7 +1,7 @@
 import inspect
 from argparse import ArgumentParser, Namespace
 from dataclasses import dataclass
-from typing import Any, Optional
+from typing import Any, Dict, List, Optional

 import pytorch_lightning as pl
@@ -32,19 +32,19 @@ class LightningArgumentParser(ArgumentParser):
     # args.data -> data args
     # args.model -> model args
     """
-    def __init__(self, *args, ignore_required_init_args=True, **kwargs):
+    def __init__(self, *args: Any, ignore_required_init_args: bool = True, **kwargs: Any) -> None:
         """
         Args:
-            ignore_required_init_args (bool, optional): Whether to include positional args when adding
-                object args. Defaults to True.
+            ignore_required_init_args: Whether to include positional args when adding
+                object args. Defaults to ``True``.
         """
         super().__init__(*args, **kwargs)
         self.ignore_required_init_args = ignore_required_init_args

-        self._default_obj_args = dict()
-        self._added_arg_names = []
+        self._default_obj_args: Dict[str, List[LitArg]] = dict()
+        self._added_arg_names: List[str] = []

-    def add_object_args(self, name, obj):
+    def add_object_args(self, name: str, obj: Any) -> None:
         default_args = gather_lit_args(obj)
         self._default_obj_args[name] = default_args
         for arg in default_args:
@@ -58,7 +58,7 @@ def add_object_args(self, name, obj):
kwargs["default"] = arg.default
self.add_argument(f"--{arg.name}", **kwargs)

def parse_lit_args(self, *args, **kwargs):
def parse_lit_args(self, *args: Any, **kwargs: Any) -> Namespace:
parsed_args_dict = vars(self.parse_args(*args, **kwargs))
lit_args = Namespace()
for name, default_args in self._default_obj_args.items():
@@ -72,7 +72,7 @@ def parse_lit_args(self, *args, **kwargs):
         return lit_args


-def gather_lit_args(cls, root_cls=None):
+def gather_lit_args(cls: Any, root_cls: Optional[Any] = None) -> List[LitArg]:

     if root_cls is None:
         if issubclass(cls, pl.LightningModule):
@@ -83,7 +83,7 @@ def gather_lit_args(cls, root_cls=None):
             root_cls = cls

     blacklisted_args = ["self", "args", "kwargs"]
-    arguments = []
+    arguments: List[LitArg] = []
     argument_names = []
     for obj in inspect.getmro(cls):

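
To see the annotated API end to end, here is a hedged usage sketch. `MyModel` is a hypothetical class; with `ignore_required_init_args=True` (the default), only defaulted `__init__` arguments become CLI flags:

    import pytorch_lightning as pl

    from pl_bolts.utils.arguments import LightningArgumentParser


    class MyModel(pl.LightningModule):  # hypothetical example class
        def __init__(self, learning_rate: float = 1e-3):
            super().__init__()
            self.learning_rate = learning_rate


    parser = LightningArgumentParser()
    parser.add_object_args("model", MyModel)  # adds --learning_rate
    args = parser.parse_lit_args([])          # args.model -> model args

    model = MyModel(**vars(args.model))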
4 changes: 3 additions & 1 deletion pl_bolts/utils/pretrained_weights.py
@@ -1,4 +1,6 @@
+from typing import Optional
+
 from pytorch_lightning import LightningModule

 vae_imagenet2012 = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/' \
                    'vae/imagenet_06_22_2019/checkpoints/epoch%3D63.ckpt'
@@ -11,7 +13,7 @@
 }


-def load_pretrained(model, class_name=None):  # pragma: no-cover
+def load_pretrained(model: LightningModule, class_name: Optional[str] = None) -> None:  # pragma: no-cover
     if class_name is None:
         class_name = model.__class__.__name__
     ckpt_url = urls[class_name]
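
A minimal sketch of the call convention, assuming the model's class name (or an explicit `class_name`) matches a key in the module-level `urls` dict; `MyVAE` and its key are hypothetical:

    from pytorch_lightning import LightningModule

    from pl_bolts.utils.pretrained_weights import load_pretrained


    class MyVAE(LightningModule):  # hypothetical; assumes a "MyVAE" entry in `urls`
        pass


    model = MyVAE()
    load_pretrained(model)  # class_name defaults to model.__class__.__name__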
8 changes: 7 additions & 1 deletion pl_bolts/utils/self_supervised.py
@@ -1,7 +1,13 @@
+from torch.nn import Module
+
 from pl_bolts.utils.semi_supervised import Identity


-def torchvision_ssl_encoder(name, pretrained=False, return_all_feature_maps=False):
+def torchvision_ssl_encoder(
+    name: str,
+    pretrained: bool = False,
+    return_all_feature_maps: bool = False,
+) -> Module:
     from pl_bolts.models.self_supervised import resnets

     pretrained_model = getattr(resnets, name)(pretrained=pretrained, return_all_feature_maps=return_all_feature_maps)
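
A quick sketch of the new signature in use. The name "resnet18" assumes a constructor of that name in `pl_bolts.models.self_supervised.resnets`; the returned backbone is a plain `nn.Module`, presumably with its classification head replaced by the imported `Identity`:

    import torch

    from pl_bolts.utils.self_supervised import torchvision_ssl_encoder

    encoder = torchvision_ssl_encoder("resnet18")
    out = encoder(torch.randn(2, 3, 32, 32))  # standard nn.Module forward pass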
28 changes: 15 additions & 13 deletions pl_bolts/utils/semi_supervised.py
@@ -1,7 +1,9 @@
 import math
+from typing import Any, List, Sequence, Tuple

 import numpy as np
 import torch
+from torch import Tensor

 from pl_bolts.utils import _SKLEARN_AVAILABLE
 from pl_bolts.utils.warnings import warn_missing_pkg
@@ -24,46 +26,46 @@ class Identity(torch.nn.Module):
         model.fc = Identity()

     """
-    def __init__(self):
+    def __init__(self) -> None:
         super(Identity, self).__init__()

-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         return x


-def balance_classes(X: np.ndarray, Y: list, batch_size: int):
+def balance_classes(X: np.ndarray, labels: Sequence[int], batch_size: int) -> Tuple[np.ndarray, np.ndarray]:
     """
     Makes sure each batch has an equal amount of data from each class.
     Perfect balance

     Args:
         X: input features
-        Y: mixed labels (ints)
+        labels: mixed labels (ints)
         batch_size: the ultimate batch size
     """
     if not _SKLEARN_AVAILABLE:
         raise ModuleNotFoundError(  # pragma: no-cover
             'You want to use `shuffle` function from `scikit-learn` which is not installed yet.'
         )

-    nb_classes = len(set(Y))
+    nb_classes = len(set(labels))

-    nb_batches = math.ceil(len(Y) / batch_size)
+    nb_batches = math.ceil(len(labels) / batch_size)

     # sort by classes
-    final_batches_x = [[] for i in range(nb_batches)]
-    final_batches_y = [[] for i in range(nb_batches)]
+    final_batches_x: List[Any] = [[] for i in range(nb_batches)]
+    final_batches_y: List[Any] = [[] for i in range(nb_batches)]
Contributor: Since these variables are the output of the function, I think they should match the return type of balance_classes():

    ) -> Tuple[np.ndarray, np.ndarray]:

but in order to do that, you might need to change/define other variables, too. We can tighten these types later in another PR, but it'd be nice to do it in this PR. Do you think you can improve the typing here?

Contributor Author: Sure, I will see what I can do. :)


     # Y needs to be np arr
-    Y = np.asarray(Y)
+    Y = np.asarray(labels)

     # pick chunk size for each class using the largest split
-    chunk_size = []
+    chunk_sizes = []
     for class_i in range(nb_classes):
         mask = Y == class_i
         y = Y[mask]
-        chunk_size.append(math.ceil(len(y) / nb_batches))
-    chunk_size = max(chunk_size)
+        chunk_sizes.append(math.ceil(len(y) / nb_batches))
+    chunk_size = max(chunk_sizes)
     # force chunk size to be even
     if chunk_size % 2 != 0:
         chunk_size -= 1
@@ -102,7 +104,7 @@ def generate_half_labeled_batches(
     larger_set_X: np.ndarray,
     larger_set_Y: np.ndarray,
     batch_size: int,
-):
+) -> Tuple[np.ndarray, np.ndarray]:
     """
     Given a labeled dataset and an unlabeled dataset, this function generates
     a joint pair where half the batches are labeled and the other half is not
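
To exercise the renamed `labels` parameter and the `Tuple[np.ndarray, np.ndarray]` return, a hedged sketch (scikit-learn must be installed for the internal shuffle):

    import numpy as np

    from pl_bolts.utils.semi_supervised import balance_classes

    X = np.random.rand(100, 3)            # input features
    labels = [i % 2 for i in range(100)]  # two evenly mixed classes
    balanced_X, balanced_Y = balance_classes(X, labels, batch_size=10)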
3 changes: 2 additions & 1 deletion pl_bolts/utils/shaping.py
@@ -1,8 +1,9 @@
 import numpy as np
 import torch
+from torch import Tensor


-def tile(a, dim, n_tile):
+def tile(a: Tensor, dim: int, n_tile: int) -> Tensor:
     init_dim = a.size(dim)
     repeat_idx = [1] * a.dim()
     repeat_idx[dim] = n_tile
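
For reference, the behaviour suggested by the visible body (interleaved repetition along `dim`); the expected output is an assumption, since the function is truncated here:

    import torch

    from pl_bolts.utils.shaping import tile

    t = torch.tensor([1, 2, 3])
    print(tile(t, dim=0, n_tile=2))  # expected: tensor([1, 1, 2, 2, 3, 3])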
8 changes: 7 additions & 1 deletion pl_bolts/utils/warnings.py
@@ -1,10 +1,16 @@
 import os
 import warnings
+from typing import Any, Callable, Optional

 MISSING_PACKAGE_WARNINGS = {}


-def warn_missing_pkg(pkg_name: str, pypi_name: str = None, extra_text: str = None, stdout_func=warnings.warn):
+def warn_missing_pkg(
+    pkg_name: str,
+    pypi_name: Optional[str] = None,
+    extra_text: Optional[str] = None,
+    stdout_func: Callable = warnings.warn,
+) -> int:
     """
     Template for warning on missing packages, show them just once.

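
A brief usage sketch. Reading the `-> int` annotation together with the "show them just once" docstring, the return is presumably the per-package warning count tracked in `MISSING_PACKAGE_WARNINGS`:

    from pl_bolts.utils.warnings import warn_missing_pkg

    warn_missing_pkg("gym")  # emits warnings.warn(...) the first time
    warn_missing_pkg("gym")  # deduplicated on subsequent calls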
3 changes: 0 additions & 3 deletions setup.cfg
@@ -115,8 +115,5 @@ ignore_errors = True
 [mypy-pl_bolts.transforms.*]
 ignore_errors = True

-[mypy-pl_bolts.utils.*]
-ignore_errors = True
-
 [mypy-tests.*]
 ignore_errors = True
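
With the `[mypy-pl_bolts.utils.*]` override removed, mypy now type-checks `pl_bolts.utils`, which is what the annotations added in this PR enable.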