-
Notifications
You must be signed in to change notification settings - Fork 322
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Revision of SimCLR transforms #857
Changes from 9 commits
68d839c
3e6bba1
6dbef2c
c5a8dff
a107930
821e915
9d881dd
5f7445b
cd582f5
1472a4b
32ca5d7
99baa9e
4a0bcb9
a439d19
cb223d1
1db2a15
8db8465
71540d8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,31 +1,22 @@ | ||
import numpy as np | ||
|
||
from pl_bolts.utils import _OPENCV_AVAILABLE, _TORCHVISION_AVAILABLE | ||
from pl_bolts.utils.stability import under_review | ||
from pl_bolts.utils import _TORCHVISION_AVAILABLE | ||
from pl_bolts.utils.warnings import warn_missing_pkg | ||
|
||
if _TORCHVISION_AVAILABLE: | ||
from torchvision import transforms | ||
else: # pragma: no cover | ||
warn_missing_pkg("torchvision") | ||
|
||
if _OPENCV_AVAILABLE: | ||
import cv2 | ||
else: # pragma: no cover | ||
warn_missing_pkg("cv2", pypi_name="opencv-python") | ||
|
||
|
||
@under_review() | ||
class SimCLRTrainDataTransform: | ||
"""Transforms for SimCLR. | ||
"""Transforms for SimCLR during training step of the pre-training stage. | ||
|
||
Transform:: | ||
|
||
RandomResizedCrop(size=self.input_height) | ||
RandomHorizontalFlip() | ||
RandomApply([color_jitter], p=0.8) | ||
RandomGrayscale(p=0.2) | ||
GaussianBlur(kernel_size=int(0.1 * self.input_height)) | ||
RandomApply([GaussianBlur(kernel_size=int(0.1 * self.input_height))], p=0.5) | ||
transforms.ToTensor() | ||
|
||
Example:: | ||
|
@@ -34,7 +25,7 @@ class SimCLRTrainDataTransform: | |
|
||
transform = SimCLRTrainDataTransform(input_height=32) | ||
x = sample() | ||
(xi, xj) = transform(x) | ||
(xi, xj, xk) = transform(x) # xk is only for the online evaluator if used | ||
""" | ||
|
||
def __init__( | ||
|
@@ -68,7 +59,7 @@ def __init__( | |
if kernel_size % 2 == 0: | ||
kernel_size += 1 | ||
|
||
data_transforms.append(GaussianBlur(kernel_size=kernel_size, p=0.5)) | ||
data_transforms.append(transforms.RandomApply([transforms.GaussianBlur(kernel_size=kernel_size)], p=0.5)) | ||
|
||
data_transforms = transforms.Compose(data_transforms) | ||
|
||
|
@@ -93,9 +84,8 @@ def __call__(self, sample): | |
return xi, xj, self.online_transform(sample) | ||
|
||
|
||
@under_review() | ||
class SimCLREvalDataTransform(SimCLRTrainDataTransform): | ||
"""Transforms for SimCLR. | ||
"""Transforms for SimCLR during the validation step of the pre-training stage. | ||
|
||
Transform:: | ||
|
||
|
@@ -109,7 +99,7 @@ class SimCLREvalDataTransform(SimCLRTrainDataTransform): | |
|
||
transform = SimCLREvalDataTransform(input_height=32) | ||
x = sample() | ||
(xi, xj) = transform(x) | ||
(xi, xj, xk) = transform(x) # xk is only for the online evaluator if used | ||
""" | ||
|
||
def __init__( | ||
|
@@ -129,8 +119,24 @@ def __init__( | |
) | ||
|
||
|
||
@under_review() | ||
class SimCLRFinetuneTransform: | ||
"""Transforms for SimCLR during the fine-tuning stage. | ||
|
||
Transform:: | ||
|
||
Resize(input_height + 10, interpolation=3) | ||
transforms.CenterCrop(input_height), | ||
transforms.ToTensor() | ||
|
||
Example:: | ||
|
||
from pl_bolts.models.self_supervised.simclr.transforms import SimCLREvalDataTransform | ||
|
||
transform = SimCLREvalDataTransform(input_height=32) | ||
x = sample() | ||
(_, _, xk) = transform(x) | ||
""" | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So, I'm not too versed in SimCLR, but either this docstring is wrong or the example references the wrong class — it imports and instantiates `SimCLREvalDataTransform` inside the `SimCLRFinetuneTransform` docstring. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for spotting this ❤️ |
||
def __init__( | ||
self, input_height: int = 224, jitter_strength: float = 1.0, normalize=None, eval_transform: bool = False | ||
) -> None: | ||
|
@@ -169,30 +175,3 @@ def __init__( | |
|
||
def __call__(self, sample): | ||
return self.transform(sample) | ||
|
||
|
||
@under_review() | ||
class GaussianBlur: | ||
# Implements Gaussian blur as described in the SimCLR paper | ||
def __init__(self, kernel_size, p=0.5, min=0.1, max=2.0): | ||
if not _TORCHVISION_AVAILABLE: # pragma: no cover | ||
raise ModuleNotFoundError("You want to use `GaussianBlur` from `cv2` which is not installed yet.") | ||
|
||
self.min = min | ||
self.max = max | ||
|
||
# kernel size is set to be 10% of the image height/width | ||
self.kernel_size = kernel_size | ||
self.p = p | ||
|
||
def __call__(self, sample): | ||
sample = np.array(sample) | ||
|
||
# blur the image with a 50% chance | ||
prob = np.random.random_sample() | ||
|
||
if prob < self.p: | ||
sigma = (self.max - self.min) * np.random.random_sample() + self.min | ||
sample = cv2.GaussianBlur(sample, (self.kernel_size, self.kernel_size), sigma) | ||
|
||
return sample |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
import numpy as np | ||
import pytest | ||
import torch | ||
from PIL import Image | ||
|
||
from pl_bolts.models.self_supervised.simclr.transforms import ( | ||
SimCLREvalDataTransform, | ||
SimCLRFinetuneTransform, | ||
SimCLRTrainDataTransform, | ||
) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"transform_cls", | ||
[pytest.param(SimCLRTrainDataTransform, id="train-data"), pytest.param(SimCLREvalDataTransform, id="eval-data")], | ||
) | ||
def test_simclr_train_data_transform(catch_warnings, transform_cls): | ||
# dummy image | ||
img = np.random.randint(low=0, high=255, size=(32, 32, 3), dtype=np.uint8) | ||
img = Image.fromarray(img) | ||
|
||
# size of the generated views | ||
input_height = 96 | ||
transform = transform_cls(input_height=input_height) | ||
views = transform(img) | ||
|
||
# the transform must output a list or a tuple of images | ||
assert isinstance(views, (list, tuple)) | ||
|
||
# the transform must output three images | ||
# (1st view, 2nd view, online evaluation view) | ||
assert len(views) == 3 | ||
|
||
# all views are tensors | ||
assert all(torch.is_tensor(v) for v in views) | ||
|
||
# all views have expected sizes | ||
assert all(v.size(1) == v.size(2) == input_height for v in views) | ||
|
||
|
||
def test_simclr_finetune_transform(catch_warnings): | ||
# dummy image | ||
img = np.random.randint(low=0, high=255, size=(32, 32, 3), dtype=np.uint8) | ||
img = Image.fromarray(img) | ||
|
||
# size of the generated views | ||
input_height = 96 | ||
transform = SimCLRFinetuneTransform(input_height=input_height) | ||
view = transform(img) | ||
|
||
# the view generator is a tensor | ||
assert torch.is_tensor(view) | ||
|
||
# view has expected size | ||
assert view.size(1) == view.size(2) == input_height |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can this somehow also be a subclass of
SimCLRTrainTransform
? Seems like quite a bit of code is duplicated. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, will do.