Skip to content

Commit

Permalink
feat(crafter): add ImageReader
Browse files Browse the repository at this point in the history
  • Loading branch information
nan-wang committed Apr 7, 2020
1 parent a747844 commit 09bc3c1
Show file tree
Hide file tree
Showing 8 changed files with 103 additions and 40 deletions.
File renamed without changes.
27 changes: 27 additions & 0 deletions jina/executors/crafters/image/io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from typing import Tuple, Dict, List

import numpy as np

from .. import BaseSegmenter


class ImageReader(BaseSegmenter):
    """:class:`ImageReader` loads an image file and crafts it into an image matrix."""

    def __init__(self, channel_axis: int = -1, *args, **kwargs):
        """
        :param channel_axis: the axis id of the color channel, ``-1`` indicates the color channel info at the last axis
        """
        super().__init__(*args, **kwargs)
        self.channel_axis = channel_axis

    def craft(self, raw_bytes, doc_id, *args, **kwargs) -> List[Dict]:
        """
        Read the image file named by ``raw_bytes`` and return it as one chunk.

        :param raw_bytes: the file name of the image, encoded as bytes
        :param doc_id: the doc id
        :return: a list holding a single chunk dict whose ``blob`` is the RGB image as a ``float32`` ndarray,
            with the color channel placed at ``channel_axis``
        """
        from PIL import Image
        # The context manager releases the file handle as soon as the pixel data
        # has been copied into the ndarray, instead of leaving it to the GC.
        with Image.open(raw_bytes.decode()) as raw_img:
            rgb_img = raw_img if raw_img.mode == 'RGB' else raw_img.convert('RGB')
            img = np.array(rgb_img).astype('float32')
        # PIL yields (height, width, channel); relocate the channel axis only on request.
        if self.channel_axis != -1:
            img = np.moveaxis(img, -1, self.channel_axis)
        return [dict(doc_id=doc_id, offset=0, weight=1., blob=img), ]
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

import numpy as np

from .. import BaseSegmenter
from .. import BaseChunkCrafter


class ImageNormalizer(BaseSegmenter):
class ImageNormalizer(BaseChunkCrafter):
""":class:`ImageNormalizer` works on doc-level,
it receives values of file names on the doc-level and returns image matrix on the chunk-level """

Expand All @@ -14,45 +14,53 @@ def __init__(self,
img_mean: Tuple[float] = (0, 0, 0),
img_std: Tuple[float] = (1, 1, 1),
resize_dim: int = 256,
channel_axis: int = -1,
*args,
**kwargs):
"""
:class:`ImageNormalizer` load an image file and craft into image matrix.
:class:`ImageNormalizer` normalize the image.
:param output_dim: the output size. Both height and width are set to `output_dim`
:param img_mean: the mean of the images in `RGB` channels. Set to `[0.485, 0.456, 0.406]` for the models trained
on `imagenet` from `paddlehub`
on `imagenet` with pytorch backbone.
:param img_std: the std of the images in `RGB` channels. Set to `[0.229, 0.224, 0.225]` for the models trained
on `imagenet` from `paddlehub`
on `imagenet` with pytorch backbone.
:param resize_dim: the size of images' height and width to resized to. The images are resized before cropping to
the output size
:param channel_axis: the axis id of the color channel, ``-1`` indicates the color channel info at the last axis
"""
super().__init__()
self.output_dim = output_dim
self.img_mean = np.array(img_mean).reshape((3, 1, 1))
self.img_std = np.array(img_std).reshape((3, 1, 1))
self.resize_dim = resize_dim
self.channel_axis = channel_axis
self.img_mean = np.array(img_mean).reshape((1, 1, 3))
self.img_std = np.array(img_std).reshape((1, 1, 3))

def craft(self, raw_bytes, doc_id, *args, **kwargs) -> List[Dict]:
def craft(self, blob, chunk_id, doc_id, *args, **kwargs) -> Dict:
"""
:param raw_bytes: the file name in bytes
:param blob: the ndarray of the image with the color channel at the last axis
:param chunk_id: the chunk id
:param doc_id: the doc id
:return: a list of chunks-level info represented by a dict
:return: a chunk dict with the normalized image
"""
from PIL import Image
raw_img = Image.open(raw_bytes.decode())
if self.channel_axis != -1:
blob = np.moveaxis(blob, self.channel_axis, -1)
raw_img = Image.fromarray(blob.astype('uint8'))
processed_img = self._normalize(raw_img)
return [dict(doc_id=doc_id, offset=0, weight=1., blob=processed_img), ]
return dict(doc_id=doc_id, offset=0, weight=1., blob=processed_img)

def _normalize(self, img):
img = self._resize_short(img, target_size=self.resize_dim)
img = self._crop_image(img, target_size=self.output_dim, center=True)
if img.mode != 'RGB':
img = img.convert('RGB')
img = np.array(img).astype('float32').transpose((2, 0, 1)) / 255
img = np.array(img).astype('float32') / 255
img -= self.img_mean
img /= self.img_std
if self.channel_axis != -1:
img = np.moveaxis(img, -1, self.channel_axis)
return img

@staticmethod
Expand All @@ -78,3 +86,5 @@ def _crop_image(img, target_size, center):
h_end = h_start + size
img = img.crop((w_start, h_start, w_end, h_end))
return img


Empty file.
27 changes: 0 additions & 27 deletions tests/executors/crafters/cv/test_image.py

This file was deleted.

15 changes: 15 additions & 0 deletions tests/executors/crafters/image/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from tests import JinaTestCase


class JinaImageTestCase(JinaTestCase):
    """Base test case providing helpers that build image fixtures."""

    @staticmethod
    def create_test_image(output_fn, size=50):
        """
        Write a solid-color square RGB image to ``output_fn`` in JPEG format.

        :param output_fn: the path of the file to write
        :param size: the height and width of the image in pixels
        """
        from PIL import Image
        dimensions = (size, size)
        fill_color = (155, 0, 0)
        canvas = Image.new('RGB', size=dimensions, color=fill_color)
        with open(output_fn, "wb") as fp:
            canvas.save(fp, 'jpeg')

    @staticmethod
    def create_test_img_array(img_height, img_width):
        """
        Build a random integer array shaped ``(img_height, img_width, 3)`` with values in ``[0, 256)``.

        :param img_height: the number of rows
        :param img_width: the number of columns
        :return: the random ndarray
        """
        import numpy as np
        shape = (img_height, img_width, 3)
        return np.random.randint(0, 256, shape)
20 changes: 20 additions & 0 deletions tests/executors/crafters/image/test_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import os
import unittest

from jina.executors.crafters.image.io import ImageReader
from . import JinaImageTestCase


class MyTestCase(JinaImageTestCase):
    """Tests for :class:`ImageReader`."""

    def test_io(self):
        """Crafting a square RGB JPEG gives a first chunk whose blob is ``(size, size, 3)``."""
        crafter = ImageReader()
        tmp_fn = os.path.join(crafter.current_workspace, "test.jpeg")
        img_size = 50
        self.create_test_image(tmp_fn, size=img_size)
        # Register the file right after it is created, so a failing craft() or
        # assertion below cannot skip the registration.
        # NOTE(review): presumably add_tmpfile schedules the file for removal
        # during teardown -- confirm against JinaTestCase.
        self.add_tmpfile(tmp_fn)
        test_chunk, *_ = crafter.craft(tmp_fn.encode("utf8"), doc_id=0)
        self.assertEqual(test_chunk["blob"].shape, (img_size, img_size, 3))


# Run the tests in this module when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
18 changes: 18 additions & 0 deletions tests/executors/crafters/image/test_normalize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import os
import unittest

from jina.executors.crafters.image.normalize import ImageNormalizer
from . import JinaImageTestCase


class MyTestCase(JinaImageTestCase):
    """Tests for :class:`ImageNormalizer`."""

    def test_transform_results(self):
        """A 224x224 input array crafted with ``output_dim=224`` keeps the ``(224, 224, 3)`` shape."""
        side = 224
        normalizer = ImageNormalizer(output_dim=side)
        source_array = self.create_test_img_array(side, side)
        result = normalizer.craft(source_array, chunk_id=0, doc_id=0)
        expected_shape = (side, side, 3)
        self.assertEqual(result["blob"].shape, expected_shape)


# Run the tests in this module when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()

0 comments on commit 09bc3c1

Please sign in to comment.