-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
249 additions
and
39 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import numpy as np | ||
import numbers | ||
|
||
from .. import BaseChunkCrafter | ||
|
||
|
||
class ImageChunkCrafter(BaseChunkCrafter):
    """
    Base class for crafters that work on image blobs.

    It provides helpers to turn an ndarray into a PIL image, to resize an
    image by its shorter side, and to crop a rectangle out of an image.

    :param channel_axis: the axis of the blob that holds the color channels;
        when it is not ``-1`` the axis is moved to the last position before
        the blob is converted into an image
    """

    def __init__(self, channel_axis: int = -1, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.channel_axis = channel_axis

    def _load_image(self, blob: 'np.ndarray'):
        """
        Convert an ndarray into a PIL image.

        :param blob: the ndarray of the image
        :return: a ``PIL.Image.Image`` built from the blob cast to ``uint8``
        """
        # imported lazily so the module can be imported without PIL installed
        from PIL import Image
        if self.channel_axis != -1:
            blob = np.moveaxis(blob, self.channel_axis, -1)
        return Image.fromarray(blob.astype('uint8'))

    @staticmethod
    def _resize_short(img, target_size):
        """
        Resize the image so that its shorter side equals ``target_size``.

        Both sides are scaled by the same factor and rounded to whole pixels.

        :param img: the PIL image to resize
        :param target_size: the desired length of the shorter side
        :return: the resized PIL image (LANCZOS resampling)
        """
        from PIL.Image import LANCZOS
        percent = float(target_size) / min(img.size[0], img.size[1])
        resized_width = int(round(img.size[0] * percent))
        resized_height = int(round(img.size[1] * percent))
        return img.resize((resized_width, resized_height), LANCZOS)

    @staticmethod
    def _crop_image(img, target_size, left=None, top=None, how='precise'):
        """
        Crop a ``target_size`` rectangle out of the image.

        :param img: the PIL image to crop
        :param target_size: the ``(width, height)`` of the crop
        :param left: the x coordinate of the crop's left edge; only used when
            ``how`` is neither ``'center'`` nor ``'random'``
        :param top: the y coordinate of the crop's top edge; only used when
            ``how`` is neither ``'center'`` nor ``'random'``
        :param how: ``'center'`` crops around the image center, ``'random'``
            picks a random position fully inside the image, any other value
            uses ``left`` and ``top`` as given
        :return: the cropped PIL image
        :raises ValueError: if the resulting left or top offset is not a number
        """
        img_width, img_height = img.size
        width, height = target_size
        w_start = left
        h_start = top
        if how == 'center':
            # Integer division keeps the box on whole pixels, so the crop is
            # exactly ``target_size``; a fractional offset would leave the
            # final size to how the imaging library rounds each edge.
            w_start = (img_width - width) // 2
            h_start = (img_height - height) // 2
        elif how == 'random':
            w_start = np.random.randint(0, img_width - width + 1)
            h_start = np.random.randint(0, img_height - height + 1)
        if not isinstance(w_start, numbers.Number):
            raise ValueError('left must be int number between 0 and {}: {}'.format(img_width, left))
        if not isinstance(h_start, numbers.Number):
            raise ValueError('top must be int number between 0 and {}: {}'.format(img_height, top))
        w_end = w_start + width
        h_end = h_start + height
        return img.crop((w_start, h_start, w_end, h_end))
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
from typing import Tuple, Dict, List | ||
|
||
import numpy as np | ||
|
||
from . import ImageChunkCrafter | ||
|
||
|
||
class ImageCropper(ImageChunkCrafter):
    """
    Crop one fixed rectangle out of an image blob.

    :param left: the x coordinate of the rectangle's left edge
    :param top: the y coordinate of the rectangle's top edge
    :param width: the width of the rectangle
    :param height: the height of the rectangle
    :param channel_axis: the axis of the blob holding the color channels
    """

    def __init__(self,
                 left: int,
                 top: int,
                 width: int,
                 height: int,
                 channel_axis: int = -1,
                 *args,
                 **kwargs):
        super().__init__(channel_axis, *args, **kwargs)
        self.left, self.top = left, top
        self.width, self.height = width, height
        self.channel_axis = channel_axis

    def craft(self, blob: 'np.ndarray', chunk_id, doc_id, *args, **kwargs) -> Dict:
        """
        Crop the configured rectangle out of the image.

        :param blob: the ndarray of the image
        :param chunk_id: the chunk id
        :param doc_id: the doc id
        :return: a chunk dict whose ``blob`` is the cropped image as ``float32``
        """
        target_size = (self.width, self.height)
        cropped = self._crop_image(self._load_image(blob), target_size, self.left, self.top)
        return {
            'doc_id': doc_id,
            'offset': 0,
            'weight': 1.,
            'blob': np.asarray(cropped).astype('float32'),
        }
|
||
|
||
class CenterImageCropper(ImageChunkCrafter):
    """
    Crop a square patch around the center of an image blob.

    :param output_dim: the side length of the square patch
    :param channel_axis: the axis of the blob holding the color channels
    """

    def __init__(self,
                 output_dim: int,
                 channel_axis: int = -1,
                 *args,
                 **kwargs):
        super().__init__(channel_axis, *args, **kwargs)
        self.output_dim = output_dim

    def craft(self, blob: 'np.ndarray', chunk_id, doc_id, *args, **kwargs) -> Dict:
        """
        Crop the central ``output_dim`` x ``output_dim`` patch of the image.

        :param blob: the ndarray of the image
        :param chunk_id: the chunk id
        :param doc_id: the doc id
        :return: a chunk dict whose ``blob`` is the cropped image as ``float32``
        """
        square = (self.output_dim, self.output_dim)
        cropped = self._crop_image(self._load_image(blob), square, how='center')
        return {
            'doc_id': doc_id,
            'offset': 0,
            'weight': 1.,
            'blob': np.asarray(cropped).astype('float32'),
        }
|
||
|
||
class RandomImageCropper(ImageChunkCrafter):
    """
    Crop square patches at random positions out of an image blob.

    :param output_dim: the side length of each square patch
    :param num_patches: the number of patches to crop per image
    :param channel_axis: the axis of the blob holding the color channels
    """

    def __init__(self,
                 output_dim: int,
                 num_patches: int = 1,
                 channel_axis: int = -1,
                 *args,
                 **kwargs):
        super().__init__(channel_axis, *args, **kwargs)
        self.output_dim = output_dim
        # NOTE: the attribute keeps its historical (misspelled) name so that
        # anything already reading ``num_pathes`` keeps working.
        self.num_pathes = num_patches

    def craft(self, blob: 'np.ndarray', chunk_id, doc_id, *args, **kwargs) -> List[Dict]:
        """
        Crop ``num_patches`` random ``output_dim`` x ``output_dim`` patches.

        :param blob: the ndarray of the image
        :param chunk_id: the chunk id
        :param doc_id: the doc id
        :return: a list of chunk dicts, one per patch, each with the cropped
            image as a ``float32`` blob
        """
        raw_img = self._load_image(blob)
        target_size = (self.output_dim, self.output_dim)
        return [
            dict(doc_id=doc_id, offset=0, weight=1.,
                 blob=np.asarray(self._crop_image(raw_img, target_size, how='random')).astype('float32'))
            for _ in range(self.num_pathes)
        ]
|
||
|
||
class FiveImageCropper(ImageChunkCrafter):
    """
    Crop five square patches out of an image blob: the four corners and the center.

    :param output_dim: the side length of each square patch
    :param channel_axis: the axis of the blob holding the color channels
    """

    def __init__(self,
                 output_dim: int,
                 channel_axis: int = -1,
                 *args,
                 **kwargs):
        super().__init__(channel_axis, *args, **kwargs)
        self.output_dim = output_dim

    def craft(self, blob: 'np.ndarray', chunk_id, doc_id, *args, **kwargs) -> List[Dict]:
        """
        Crop the top-left, top-right, bottom-left, bottom-right and center patches.

        :param blob: the ndarray of the image
        :param chunk_id: the chunk id
        :param doc_id: the doc id
        :return: a list of five chunk dicts in the order top-left, top-right,
            bottom-left, bottom-right, center; each holds the cropped image
            as a ``float32`` blob
        :raises ValueError: if ``output_dim`` exceeds the image width or height
        """
        raw_img = self._load_image(blob)
        image_width, image_height = raw_img.size
        crop_width = crop_height = self.output_dim
        if crop_width > image_width or crop_height > image_height:
            msg = "Requested crop size {} is bigger than input size {}"
            raise ValueError(msg.format(self.output_dim, (image_height, image_width)))

        # Every patch has the same (width, height); only the offset differs.
        # Passing the full image size here would push the box past the image
        # bounds and produce patches of the wrong size.
        target_size = (crop_width, crop_height)
        right = image_width - crop_width
        bottom = image_height - crop_height
        patches = [
            self._crop_image(raw_img, target_size, 0, 0),
            self._crop_image(raw_img, target_size, right, 0),
            self._crop_image(raw_img, target_size, 0, bottom),
            self._crop_image(raw_img, target_size, right, bottom),
            self._crop_image(raw_img, target_size, how='center'),
        ]
        return [
            dict(doc_id=doc_id, offset=0, weight=1., blob=np.asarray(patch).astype('float32'))
            for patch in patches
        ]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import unittest | ||
import numpy as np | ||
|
||
from jina.executors.crafters.image.crop import ImageCropper, CenterImageCropper, RandomImageCropper, FiveImageCropper | ||
from tests.executors.crafters.image import JinaImageTestCase | ||
|
||
|
||
class MyTestCase(JinaImageTestCase):
    """Tests for the image cropping crafters."""

    def test_crop(self):
        """The fixed-rectangle cropper returns exactly the requested region."""
        img_size = 217
        img_array = self.create_test_img_array(img_size, img_size)
        left = 2
        top = 17
        width = 20
        height = 20
        crafter = ImageCropper(left, top, width, height)
        crafted_chunk = crafter.craft(img_array, 0, 0)
        # rows are selected by top/height and columns by left/width; the same
        # slice is used for the comparison and for the failure message
        expected = np.asarray(img_array[top:top + height, left:left + width, :])
        np.testing.assert_array_equal(
            crafted_chunk['blob'], expected,
            'img_array: {}\ntest: {}\ncontrol:{}'.format(
                img_array.shape,
                crafted_chunk['blob'].shape,
                expected.shape))

    def test_center_crop(self):
        """The center cropper returns a patch of the requested size."""
        img_size = 217
        img_array = self.create_test_img_array(img_size, img_size)
        output_dim = 20
        crafter = CenterImageCropper(output_dim)
        crafted_chunk = crafter.craft(img_array, 0, 0)
        self.assertEqual(crafted_chunk["blob"].shape, (20, 20, 3))

    def test_random_crop(self):
        """The random cropper returns the requested number of patches."""
        img_size = 217
        img_array = self.create_test_img_array(img_size, img_size)
        output_dim = 20
        num_pathes = 20
        crafter = RandomImageCropper(output_dim, num_pathes)
        crafted_chunk_list = crafter.craft(img_array, 0, 0)
        self.assertEqual(len(crafted_chunk_list), num_pathes)
        for crafted_chunk in crafted_chunk_list:
            self.assertEqual(crafted_chunk['blob'].shape, (output_dim, output_dim, 3))

    # Named distinctly from ``test_random_crop``: a second method with the
    # same name would replace the first one and silently drop that test.
    def test_five_crop(self):
        """The five-crop crafter returns five patches."""
        img_size = 217
        img_array = self.create_test_img_array(img_size, img_size)
        output_dim = 20
        crafter = FiveImageCropper(output_dim)
        crafted_chunk_list = crafter.craft(img_array, 0, 0)
        self.assertEqual(len(crafted_chunk_list), 5)
|
||
|
||
# Run the test cases when this module is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters