Skip to content

Commit

Permalink
Add missing type hints to ColorJitter constructor (#7087)
Browse files Browse the repository at this point in the history
  • Loading branch information
RoiEXLab authored Jan 17, 2023
1 parent 8985b59 commit 93df9a5
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions torchvision/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import random
import warnings
from collections.abc import Sequence
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Union

import torch
from torch import Tensor
Expand Down Expand Up @@ -1172,7 +1172,13 @@ class ColorJitter(torch.nn.Module):
or use an interpolation that generates negative values before using this function.
"""

def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
def __init__(
self,
brightness: Union[float, Tuple[float, float]] = 0,
contrast: Union[float, Tuple[float, float]] = 0,
saturation: Union[float, Tuple[float, float]] = 0,
hue: Union[float, Tuple[float, float]] = 0,
) -> None:
super().__init__()
_log_api_usage_once(self)
self.brightness = self._check_input(brightness, "brightness")
Expand Down

0 comments on commit 93df9a5

Please sign in to comment.