From 93df9a50885d0345e31bba691576c83d5cee7737 Mon Sep 17 00:00:00 2001 From: RoiEX <8350879+RoiEXLab@users.noreply.github.com> Date: Tue, 17 Jan 2023 14:11:00 +0100 Subject: [PATCH] Add missing type hints to ColorJitter constructor (#7087) --- torchvision/transforms/transforms.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 2f513058ad0..cb2bfdb92a8 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -3,7 +3,7 @@ import random import warnings from collections.abc import Sequence -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Union import torch from torch import Tensor @@ -1172,7 +1172,13 @@ class ColorJitter(torch.nn.Module): or use an interpolation that generates negative values before using this function. """ - def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): + def __init__( + self, + brightness: Union[float, Tuple[float, float]] = 0, + contrast: Union[float, Tuple[float, float]] = 0, + saturation: Union[float, Tuple[float, float]] = 0, + hue: Union[float, Tuple[float, float]] = 0, + ) -> None: super().__init__() _log_api_usage_once(self) self.brightness = self._check_input(brightness, "brightness")