Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add center_crop to ImageFeatureExtractionMixin #11066

Merged
merged 1 commit into from
Apr 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions src/transformers/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,55 @@ def resize(self, image, size, resample=PIL.Image.BILINEAR):
image = self.to_pil_image(image)

return image.resize(size, resample=resample)

def center_crop(self, image, size):
"""
Crops :obj:`image` to the given size using a center crop. Note that if the image is too small to be cropped to
the size given, it will be padded (so the returned result has the size asked).

Args:
image (:obj:`PIL.Image.Image` or :obj:`np.ndarray` or :obj:`torch.Tensor`):
The image to resize.
size (:obj:`int` or :obj:`Tuple[int, int]`):
The size to which crop the image.
"""
self._ensure_format_supported(image)
if not isinstance(size, tuple):
size = (size, size)

# PIL Image.size is (width, height) but NumPy array and torch Tensors have (height, width)
image_shape = (image.size[1], image.size[0]) if isinstance(image, PIL.Image.Image) else image.shape[-2:]
top = (image_shape[0] - size[0]) // 2
bottom = top + size[0] # In case size is odd, (image_shape[0] + size[0]) // 2 won't give the proper result.
left = (image_shape[1] - size[1]) // 2
right = left + size[1] # In case size is odd, (image_shape[1] + size[1]) // 2 won't give the proper result.

# For PIL Images we have a method to crop directly.
if isinstance(image, PIL.Image.Image):
return image.crop((left, top, right, bottom))

# Check if all the dimensions are inside the image.
if top >= 0 and bottom <= image_shape[0] and left >= 0 and right <= image_shape[1]:
return image[..., top:bottom, left:right]

# Otherwise, we may need to pad if the image is too small. Oh joy...
new_shape = image.shape[:-2] + (max(size[0], image_shape[0]), max(size[1], image_shape[1]))
if isinstance(image, np.ndarray):
new_image = np.zeros_like(image, shape=new_shape)
elif is_torch_tensor(image):
new_image = image.new_zeros(new_shape)

top_pad = (new_shape[-2] - image_shape[0]) // 2
bottom_pad = top_pad + image_shape[0]
left_pad = (new_shape[-1] - image_shape[1]) // 2
right_pad = left_pad + image_shape[1]
new_image[..., top_pad:bottom_pad, left_pad:right_pad] = image

top += top_pad
bottom += top_pad
left += left_pad
right += left_pad

return new_image[
..., max(0, top) : min(new_image.shape[-2], bottom), max(0, left) : min(new_image.shape[-1], right)
]
52 changes: 52 additions & 0 deletions tests/test_image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,3 +315,55 @@ def test_normalize_tensor(self):

normalized_tensor = feature_extractor.normalize(tensor, torch.tensor(mean), torch.tensor(std))
self.assertTrue(torch.equal(normalized_tensor, expected))

def test_center_crop_image(self):
feature_extractor = ImageFeatureExtractionMixin()
image = get_random_image(16, 32)

# Test various crop sizes: bigger on all dimensions, on one of the dimensions only and on both dimensions.
crop_sizes = [8, (8, 64), 20, (32, 64)]
for size in crop_sizes:
cropped_image = feature_extractor.center_crop(image, size)
self.assertTrue(isinstance(cropped_image, PIL.Image.Image))

# PIL Image.size is transposed compared to NumPy or PyTorch (width first instead of height first).
expected_size = (size, size) if isinstance(size, int) else (size[1], size[0])
self.assertEqual(cropped_image.size, expected_size)

def test_center_crop_array(self):
feature_extractor = ImageFeatureExtractionMixin()
image = get_random_image(16, 32)
array = feature_extractor.to_numpy_array(image)

# Test various crop sizes: bigger on all dimensions, on one of the dimensions only and on both dimensions.
crop_sizes = [8, (8, 64), 20, (32, 64)]
for size in crop_sizes:
cropped_array = feature_extractor.center_crop(array, size)
self.assertTrue(isinstance(cropped_array, np.ndarray))

expected_size = (size, size) if isinstance(size, int) else size
self.assertEqual(cropped_array.shape[-2:], expected_size)

# Check result is consistent with PIL.Image.crop
cropped_image = feature_extractor.center_crop(image, size)
self.assertTrue(np.array_equal(cropped_array, feature_extractor.to_numpy_array(cropped_image)))

@require_torch
def test_center_crop_tensor(self):
feature_extractor = ImageFeatureExtractionMixin()
image = get_random_image(16, 32)
array = feature_extractor.to_numpy_array(image)
tensor = torch.tensor(array)

# Test various crop sizes: bigger on all dimensions, on one of the dimensions only and on both dimensions.
crop_sizes = [8, (8, 64), 20, (32, 64)]
for size in crop_sizes:
cropped_tensor = feature_extractor.center_crop(tensor, size)
self.assertTrue(isinstance(cropped_tensor, torch.Tensor))

expected_size = (size, size) if isinstance(size, int) else size
self.assertEqual(cropped_tensor.shape[-2:], expected_size)

# Check result is consistent with PIL.Image.crop
cropped_image = feature_extractor.center_crop(image, size)
self.assertTrue(torch.equal(cropped_tensor, torch.tensor(feature_extractor.to_numpy_array(cropped_image))))