Skip to content

Commit

Permalink
feat(crafter): add cropers
Browse files Browse the repository at this point in the history
  • Loading branch information
nan-wang committed Apr 7, 2020
1 parent 09bc3c1 commit e929ecc
Show file tree
Hide file tree
Showing 7 changed files with 249 additions and 39 deletions.
17 changes: 17 additions & 0 deletions jina/executors/crafters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,20 @@ def craft(self, *args, **kwargs) -> List[Dict]:
:return: a list of chunks-level info represented by a dict
"""
raise NotImplementedError


class BaseChunkSegmenter(BaseCrafter):
    """Abstract base for crafters that split chunk-level input into several chunks.

    According to the original description, a segmenter receives a value on the
    chunk-level and returns a list of new values on the chunk-level.

    NOTE(review): the previous docstrings also described this class as working
    "on doc-level" and its :func:`craft` as consuming doc-level info. The two
    statements disagree -- confirm the intended input level against the callers.
    """

    def craft(self, *args, **kwargs) -> List[Dict]:
        """Apply the segmenter; concrete subclasses have to override this method.

        The keys of every returned dict should be valid keys of the chunk-level
        protobuf. The argument names are expected to be valid protobuf keys of
        the level this executor reads from (see the review note on the class).

        :return: a list of chunk-level info, each item represented by a dict
        :raises NotImplementedError: always, since this base class provides no implementation
        """
        raise NotImplementedError
48 changes: 48 additions & 0 deletions jina/executors/crafters/image/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import numpy as np
import numbers

from .. import BaseChunkCrafter


class ImageChunkCrafter(BaseChunkCrafter):
    """Base class bundling the PIL helpers shared by the image chunk crafters.

    It offers loading an ndarray as a PIL image, resizing by the short side and
    cropping a rectangle (at a given position, centered, or at random).
    """

    def __init__(self, channel_axis: int = -1, *args, **kwargs):
        """
        :param channel_axis: the axis id of the color channel in the incoming blob,
            ``-1`` indicates the color channel info at the last axis
        """
        super().__init__(*args, **kwargs)
        self.channel_axis = channel_axis

    def _load_image(self, blob: 'np.ndarray'):
        """Turn an ndarray into a PIL image.

        :param blob: the image array; its color channel is expected at ``self.channel_axis``
        :return: the PIL image built from ``blob`` after casting it to ``uint8``
        """
        from PIL import Image
        if self.channel_axis != -1:
            # PIL expects the color channel at the last axis
            blob = np.moveaxis(blob, self.channel_axis, -1)
        return Image.fromarray(blob.astype('uint8'))

    @staticmethod
    def _resize_short(img, target_size):
        """Resize ``img`` so that its shorter side becomes ``target_size``.

        Both sides are scaled by the same factor, so the aspect ratio is kept
        up to rounding to whole pixels.

        :param img: the PIL image to resize
        :param target_size: the desired length of the shorter side in pixels
        :return: the resized PIL image (LANCZOS resampling)
        """
        from PIL.Image import LANCZOS
        percent = float(target_size) / min(img.size[0], img.size[1])
        resized_width = int(round(img.size[0] * percent))
        resized_height = int(round(img.size[1] * percent))
        return img.resize((resized_width, resized_height), LANCZOS)

    @staticmethod
    def _crop_image(img, target_size, left=None, top=None, how='precise'):
        """Crop a ``target_size`` rectangle out of ``img``.

        :param img: the PIL image to crop
        :param target_size: the ``(width, height)`` of the crop in pixels
        :param left: x offset of the crop; only used when ``how`` is neither ``'center'`` nor ``'random'``
        :param top: y offset of the crop; only used when ``how`` is neither ``'center'`` nor ``'random'``
        :param how: ``'center'`` centers the crop, ``'random'`` draws the offsets uniformly
            so that the crop fits into the image, any other value uses ``left`` and ``top``
        :return: the cropped PIL image
        :raises ValueError: if the resulting offsets are not numbers, e.g. ``left``/``top`` were
            not given for a precise crop
        """
        img_width, img_height = img.size
        width, height = target_size
        w_start, h_start = left, top
        if how == 'center':
            # Floor division keeps the offsets integral. With true division an odd
            # size difference yields ``x.5`` edges, which Pillow rounds one by one,
            # so the crop could come out one pixel larger than requested.
            w_start = (img_width - width) // 2
            h_start = (img_height - height) // 2
        elif how == 'random':
            # the upper bound is exclusive, hence ``+ 1`` to allow the last valid offset
            w_start = np.random.randint(0, img_width - width + 1)
            h_start = np.random.randint(0, img_height - height + 1)
        if not isinstance(w_start, numbers.Number):
            raise ValueError('left must be int number between 0 and {}: {}'.format(img_width, left))
        if not isinstance(h_start, numbers.Number):
            raise ValueError('top must be int number between 0 and {}: {}'.format(img_height, top))
        return img.crop((w_start, h_start, w_start + width, h_start + height))

124 changes: 124 additions & 0 deletions jina/executors/crafters/image/crop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
from typing import Tuple, Dict, List

import numpy as np

from . import ImageChunkCrafter


class ImageCropper(ImageChunkCrafter):
    """Crop a fixed rectangle, given by its offsets and size, out of an image chunk."""

    def __init__(self,
                 left: int,
                 top: int,
                 width: int,
                 height: int,
                 channel_axis: int = -1,
                 *args,
                 **kwargs):
        """
        :param left: the x offset of the crop in pixels
        :param top: the y offset of the crop in pixels
        :param width: the width of the crop in pixels
        :param height: the height of the crop in pixels
        :param channel_axis: the axis id of the color channel, ``-1`` indicates the color channel info at the last axis
        """
        super().__init__(channel_axis, *args, **kwargs)
        self.left, self.top = left, top
        self.width, self.height = width, height
        # also assigned by the parent constructor; kept as in the original
        self.channel_axis = channel_axis

    def craft(self, blob: 'np.ndarray', chunk_id, doc_id, *args, **kwargs) -> Dict:
        """Cut the configured rectangle out of the image.

        NOTE(review): unlike ``ImageNormalizer._normalize`` the channel axis is not
        moved back to ``channel_axis`` after cropping -- confirm this is intended.

        :param blob: the ndarray of the image, color channel at ``channel_axis``
        :param chunk_id: the chunk id (not used here)
        :param doc_id: the doc id
        :return: a chunk dict with the cropped image as a ``float32`` array
        """
        source_img = self._load_image(blob)
        target_size = (self.width, self.height)
        cropped_img = self._crop_image(source_img, target_size, self.left, self.top)
        return {
            'doc_id': doc_id,
            'offset': 0,
            'weight': 1.,
            'blob': np.asarray(cropped_img).astype('float32'),
        }


class CenterImageCropper(ImageChunkCrafter):
    """Crop a square of side ``output_dim`` around the center of an image chunk."""

    def __init__(self,
                 output_dim: int,
                 channel_axis: int = -1,
                 *args,
                 **kwargs):
        """
        :param output_dim: the side length of the square crop in pixels
        :param channel_axis: the axis id of the color channel, ``-1`` indicates the color channel info at the last axis
        """
        super().__init__(channel_axis, *args, **kwargs)
        self.output_dim = output_dim

    def craft(self, blob: 'np.ndarray', chunk_id, doc_id, *args, **kwargs) -> Dict:
        """Cut the centered square out of the image.

        NOTE(review): unlike ``ImageNormalizer._normalize`` the channel axis is not
        moved back to ``channel_axis`` after cropping -- confirm this is intended.

        :param blob: the ndarray of the image, color channel at ``channel_axis``
        :param chunk_id: the chunk id (not used here)
        :param doc_id: the doc id
        :return: a chunk dict with the cropped image as a ``float32`` array
        """
        source_img = self._load_image(blob)
        square = (self.output_dim, self.output_dim)
        cropped_img = self._crop_image(source_img, square, how='center')
        return {
            'doc_id': doc_id,
            'offset': 0,
            'weight': 1.,
            'blob': np.asarray(cropped_img).astype('float32'),
        }


class RandomImageCropper(ImageChunkCrafter):
    """Crop ``num_patches`` randomly placed squares of side ``output_dim`` out of an image chunk."""

    def __init__(self,
                 output_dim: int,
                 num_patches: int = 1,
                 channel_axis: int = -1,
                 *args,
                 **kwargs):
        """
        :param output_dim: the side length of each square crop in pixels
        :param num_patches: the number of random crops to produce per image
        :param channel_axis: the axis id of the color channel, ``-1`` indicates the color channel info at the last axis
        """
        super().__init__(channel_axis, *args, **kwargs)
        self.output_dim = output_dim
        # attribute name (including its spelling) is kept so existing readers keep working
        self.num_pathes = num_patches

    def craft(self, blob: 'np.ndarray', chunk_id, doc_id, *args, **kwargs) -> List[Dict]:
        """Cut ``num_patches`` random squares out of the image.

        :param blob: the ndarray of the image, color channel at ``channel_axis``
        :param chunk_id: the chunk id (not used here)
        :param doc_id: the doc id
        :return: a list of chunk dicts, one per random crop, each holding the crop as a ``float32`` array
        """
        raw_img = self._load_image(blob)
        target_size = (self.output_dim, self.output_dim)
        # every iteration draws fresh offsets inside ``_crop_image``
        return [
            dict(doc_id=doc_id, offset=0, weight=1.,
                 blob=np.asarray(self._crop_image(raw_img, target_size, how='random')).astype('float32'))
            for _ in range(self.num_pathes)
        ]


class FiveImageCropper(ImageChunkCrafter):
    """Crop five squares of side ``output_dim`` out of an image chunk:
    the four corners and the center."""

    def __init__(self,
                 output_dim: int,
                 channel_axis: int = -1,
                 *args,
                 **kwargs):
        """
        :param output_dim: the side length of each square crop in pixels
        :param channel_axis: the axis id of the color channel, ``-1`` indicates the color channel info at the last axis
        """
        super().__init__(channel_axis, *args, **kwargs)
        self.output_dim = output_dim

    def craft(self, blob: 'np.ndarray', chunk_id, doc_id, *args, **kwargs) -> List[Dict]:
        """Cut the four corner squares and the center square out of the image.

        :param blob: the ndarray of the image, color channel at ``channel_axis``
        :param chunk_id: the chunk id (not used here)
        :param doc_id: the doc id
        :return: a list of five chunk dicts in the order top-left, top-right, bottom-left,
            bottom-right, center; each holds its crop as a ``float32`` array
        :raises ValueError: if ``output_dim`` exceeds the width or the height of the image
        """
        raw_img = self._load_image(blob)
        image_width, image_height = raw_img.size
        crop_height = self.output_dim
        crop_width = self.output_dim
        if crop_width > image_width or crop_height > image_height:
            msg = "Requested crop size {} is bigger than input size {}"
            raise ValueError(msg.format(self.output_dim, (image_height, image_width)))

        # ``_crop_image`` takes the crop *size* plus the left/top offset, not the
        # right/bottom edge, so every patch is requested with the same target size.
        target_size = (crop_width, crop_height)
        right = image_width - crop_width
        bottom = image_height - crop_height
        patches = [
            self._crop_image(raw_img, target_size, 0, 0),          # top-left
            self._crop_image(raw_img, target_size, right, 0),      # top-right
            self._crop_image(raw_img, target_size, 0, bottom),     # bottom-left
            self._crop_image(raw_img, target_size, right, bottom),  # bottom-right
            self._crop_image(raw_img, target_size, how='center'),  # center
        ]
        return [
            dict(doc_id=doc_id, offset=0, weight=1., blob=np.asarray(patch).astype('float32'))
            for patch in patches
        ]
41 changes: 5 additions & 36 deletions jina/executors/crafters/image/normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

import numpy as np

from .. import BaseChunkCrafter
from . import ImageChunkCrafter


class ImageNormalizer(BaseChunkCrafter):
class ImageNormalizer(ImageChunkCrafter):
""":class:`ImageNormalizer` works on doc-level,
it receives values of file names on the doc-level and returns image matrix on the chunk-level """

Expand All @@ -29,7 +29,7 @@ def __init__(self,
the output size
:param channel_axis: the axis id of the color channel, ``-1`` indicates the color channel info at the last axis
"""
super().__init__()
super().__init__(*args, **kwargs)
self.output_dim = output_dim
self.resize_dim = resize_dim
self.channel_axis = channel_axis
Expand All @@ -44,47 +44,16 @@ def craft(self, blob, chunk_id, doc_id, *args, **kwargs) -> Dict:
:param doc_id: the doc id
:return: a chunk dict with the normalized image
"""
from PIL import Image
if self.channel_axis != -1:
blob = np.moveaxis(blob, self.channel_axis, -1)
raw_img = Image.fromarray(blob.astype('uint8'))
raw_img = self._load_image(blob)
processed_img = self._normalize(raw_img)
return dict(doc_id=doc_id, offset=0, weight=1., blob=processed_img)

def _normalize(self, img):
img = self._resize_short(img, target_size=self.resize_dim)
img = self._crop_image(img, target_size=self.output_dim, center=True)
if img.mode != 'RGB':
img = img.convert('RGB')
img = self._crop_image(img, target_size=(self.output_dim, self.output_dim), how='center')
img = np.array(img).astype('float32') / 255
img -= self.img_mean
img /= self.img_std
if self.channel_axis != -1:
img = np.moveaxis(img, -1, self.channel_axis)
return img

@staticmethod
def _resize_short(img, target_size):
from PIL.Image import LANCZOS
percent = float(target_size) / min(img.size[0], img.size[1])
resized_width = int(round(img.size[0] * percent))
resized_height = int(round(img.size[1] * percent))
img = img.resize((resized_width, resized_height), LANCZOS)
return img

@staticmethod
def _crop_image(img, target_size, center):
width, height = img.size
size = target_size
if center:
w_start = (width - size) / 2
h_start = (height - size) / 2
else:
w_start = np.random.randint(0, width - size + 1)
h_start = np.random.randint(0, height - size + 1)
w_end = w_start + size
h_end = h_start + size
img = img.crop((w_start, h_start, w_end, h_end))
return img


53 changes: 53 additions & 0 deletions tests/executors/crafters/image/test_crop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import unittest
import numpy as np

from jina.executors.crafters.image.crop import ImageCropper, CenterImageCropper, RandomImageCropper, FiveImageCropper
from tests.executors.crafters.image import JinaImageTestCase


class MyTestCase(JinaImageTestCase):
    """Tests for the image cropping crafters."""

    def test_crop(self):
        """A precise crop returns exactly the requested rectangle of the input."""
        img_size = 217
        img_array = self.create_test_img_array(img_size, img_size)
        left = 2
        top = 17
        width = 20
        height = 20
        crafter = ImageCropper(left, top, width, height)
        crafted_chunk = crafter.craft(img_array, 0, 0)
        # rows are addressed by ``top``/``height``, columns by ``left``/``width``
        expected = np.asarray(img_array[top:top + height, left:left + width, :])
        np.testing.assert_array_equal(
            crafted_chunk['blob'], expected,
            'img_array: {}\ntest: {}\ncontrol:{}'.format(
                img_array.shape,
                crafted_chunk['blob'].shape,
                expected.shape))

    def test_center_crop(self):
        """A center crop has the requested square shape."""
        img_size = 217
        img_array = self.create_test_img_array(img_size, img_size)
        output_dim = 20
        crafter = CenterImageCropper(output_dim)
        crafted_chunk = crafter.craft(img_array, 0, 0)
        self.assertEqual(crafted_chunk["blob"].shape, (20, 20, 3))

    def test_random_crop(self):
        """The random cropper yields the requested number of square patches."""
        img_size = 217
        img_array = self.create_test_img_array(img_size, img_size)
        output_dim = 20
        num_patches = 20
        crafter = RandomImageCropper(output_dim, num_patches)
        crafted_chunk_list = crafter.craft(img_array, 0, 0)
        self.assertEqual(len(crafted_chunk_list), num_patches)
        for crafted_chunk in crafted_chunk_list:
            self.assertEqual(crafted_chunk['blob'].shape, (output_dim, output_dim, 3))

    # Previously this method was also called ``test_random_crop``; the second
    # definition replaced the first one, so the random-crop test never ran.
    def test_five_crop(self):
        """The five cropper yields five patches."""
        img_size = 217
        img_array = self.create_test_img_array(img_size, img_size)
        output_dim = 20
        crafter = FiveImageCropper(output_dim)
        crafted_chunk_list = crafter.craft(img_array, 0, 0)
        self.assertEqual(len(crafted_chunk_list), 5)


if __name__ == '__main__':
    unittest.main()
2 changes: 1 addition & 1 deletion tests/executors/crafters/image/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import unittest

from jina.executors.crafters.image.io import ImageReader
from . import JinaImageTestCase
from tests.executors.crafters.image import JinaImageTestCase


class MyTestCase(JinaImageTestCase):
Expand Down
3 changes: 1 addition & 2 deletions tests/executors/crafters/image/test_normalize.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import os
import unittest

from jina.executors.crafters.image.normalize import ImageNormalizer
from . import JinaImageTestCase
from tests.executors.crafters.image import JinaImageTestCase


class MyTestCase(JinaImageTestCase):
Expand Down

0 comments on commit e929ecc

Please sign in to comment.