Skip to content

Commit

Permalink
Merge pull request #195 from jina-ai/refactor-encoder-116
Browse files Browse the repository at this point in the history
refactor(encoder): refactoring torch encoders
  • Loading branch information
hanxiao authored Apr 5, 2020
2 parents ee2f658 + 154b4b9 commit 93382b0
Show file tree
Hide file tree
Showing 29 changed files with 101 additions and 72 deletions.
2 changes: 0 additions & 2 deletions jina/executors/encoders/image/paddlehub.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@ def __init__(self,
``densenet264_imagenet``, ``densenet201_imagenet``, ``densenet169_imagenet``, ``densenet161_imagenet``,
``densenet121_imagenet``, ``darknet53_imagenet``,
``alexnet_imagenet``,
# ``pnasnet_imagenet``,
# ``nasnet_imagenet``
"""
Expand Down
30 changes: 12 additions & 18 deletions jina/executors/encoders/image/torchvision.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
import numpy as np

from .. import BaseImageEncoder
from ..torchvision import TorchEncoder


class TorchImageEncoder(BaseImageEncoder):
class ImageTorchEncoder(TorchEncoder):
"""
:class:`TorchImageEncoder` encodes data from a ndarray, potentially B x (Channel x Height x Width) into a
:class:`ImageTorchEncoder` encodes data from a ndarray, potentially B x (Channel x Height x Width) into a
ndarray of `B x D`.
Internally, :class:`TorchImageEncoder` wraps the models from `torchvision.models`.
Internally, :class:`ImageTorchEncoder` wraps the models from `torchvision.models`.
https://pytorch.org/docs/stable/torchvision/models.html
"""

def __init__(self, model_name: str = 'mobilenet_v2', pool_strategy: str = 'mean', *args, **kwargs):
def __init__(self,
model_name: str = 'mobilenet_v2',
pool_strategy: str = 'mean', *args, **kwargs):
"""
:param model_name: the name of the model. Supported models include
Expand All @@ -33,28 +35,20 @@ def __init__(self, model_name: str = 'mobilenet_v2', pool_strategy: str = 'mean'
thus the output of the model will be a 2D tensor.
- `max` means that global max pooling will be applied.
"""
super().__init__(*args, **kwargs)
self.model_name = model_name
super().__init__(model_name, *args, **kwargs)
self.pool_strategy = pool_strategy
if pool_strategy not in ('mean', 'max', None):
raise NotImplementedError('unknown pool_strategy: {}'.format(self.pool_strategy))

def post_init(self):
def _build_model(self):
import torchvision.models as models
import torch
model = getattr(models, self.model_name)(pretrained=True)
self.model = model.features.eval()
device = 'cuda:0' if self.on_gpu else 'cpu'
self.model.to(torch.device(device))

def encode(self, data: 'np.ndarray', *args, **kwargs) -> 'np.ndarray':
"""
def _get_features(self, data):
return self.model(data)

:param data: a `B x (Channel x Height x Width)` numpy ``ndarray``, `B` is the size of the batch
:return: a `B x D` numpy ``ndarray``, `D` is the output dimension
"""
import torch
feature_map = self.model(torch.from_numpy(data.astype('float32'))).detach().numpy()
def _get_pooling(self, feature_map: 'np.ndarray') -> 'np.ndarray':
if feature_map.ndim == 2 or self.pool_strategy is None:
return feature_map
return getattr(np, self.pool_strategy)(feature_map, axis=(2, 3))
42 changes: 28 additions & 14 deletions jina/executors/encoders/nlp/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import numpy as np

from .. import BaseTextEncoder
from ..helper import reduce_mean, reduce_max, reduce_cls
from ..helper import reduce_mean, reduce_max, reduce_min, reduce_cls
from ...decorators import batching, as_ndarray


class TransformerEncoder(BaseTextEncoder):
Expand All @@ -13,16 +14,17 @@ class TransformerEncoder(BaseTextEncoder):

def __init__(self,
model_name: str = 'bert-base-uncased',
pooling_strategy: str = 'reduce-mean',
pooling_strategy: str = 'mean',
max_length: int = 64,
model_path: str = 'transformer',
*args, **kwargs):
"""
:param model_name: the name of the model. Supported models include 'bert-base-uncased', 'openai-gpt', 'gpt2',
'xlm-mlm-enfr-1024', 'distilbert-base-cased', 'roberta-base', 'xlm-roberta-base' .
'xlm-mlm-enfr-1024', 'distilbert-base-cased', 'roberta-base', 'xlm-roberta-base', 'flaubert-base-cased',
'camembert-base', 'ctrl'.
:param pooling_strategy: the strategy to merge the word embeddings into the chunk embedding. Supported
strategies include 'cls', 'reduce-mean', 'reduce-max'.
strategies include 'cls', 'mean', 'max', 'min'.
:param max_length: the max length to truncate the tokenized sequences to.
:param model_path: the path of the encoder model. If a valid path is given, the encoder will be loaded from the
given path.
Expand All @@ -38,7 +40,8 @@ def __init__(self,

def post_init(self):
from transformers import BertTokenizer, OpenAIGPTTokenizer, GPT2Tokenizer, \
XLNetTokenizer, XLMTokenizer, DistilBertTokenizer, RobertaTokenizer, XLMRobertaTokenizer
XLNetTokenizer, XLMTokenizer, DistilBertTokenizer, RobertaTokenizer, XLMRobertaTokenizer, \
FlaubertTokenizer, CamembertTokenizer, CTRLTokenizer

tokenizer_dict = {
'bert-base-uncased': BertTokenizer,
Expand All @@ -48,7 +51,10 @@ def post_init(self):
'xlm-mlm-enfr-1024': XLMTokenizer,
'distilbert-base-cased': DistilBertTokenizer,
'roberta-base': RobertaTokenizer,
'xlm-roberta-base': XLMRobertaTokenizer
'xlm-roberta-base': XLMRobertaTokenizer,
'flaubert-base-cased': FlaubertTokenizer,
'camembert-base': CamembertTokenizer,
'ctrl': CTRLTokenizer
}

if self.model_name not in tokenizer_dict:
Expand All @@ -62,14 +68,18 @@ def post_init(self):
self.tokenizer = tokenizer_dict[self.model_name].from_pretrained(self._tmp_model_path)
self.tokenizer.padding_side = 'right'

if self.model_name in ('bert-base-uncased', 'distilbert-base-cased', 'roberta-base', 'xlm-roberta-base'):
if self.model_name in (
'bert-base-uncased', 'distilbert-base-cased', 'roberta-base', 'xlm-roberta-base', 'flaubert-base-cased',
'camembert-base'):
self.cls_pos = 'head'
elif self.model_name in ('xlnet-base-cased'):
self.cls_pos = 'tail'

if self.model_name in ('openai-gpt', 'gpt2', 'xlm-mlm-enfr-1024', 'xlnet-base-cased'):
self.tokenizer.pad_token = '<PAD>'

@batching
@as_ndarray
def encode(self, data: 'np.ndarray', *args, **kwargs) -> 'np.ndarray':
"""
Expand All @@ -86,17 +96,16 @@ def encode(self, data: 'np.ndarray', *args, **kwargs) -> 'np.ndarray':
mask_ids_batch.append(mask_ids)
token_ids_batch = self._tensor_func(token_ids_batch)
mask_ids_batch = self._tensor_func(mask_ids_batch)

with self._sess_func():
# seq_output, cls_output = self.model(token_ids_batch, attention_mask=mask_ids_batch)
seq_output, *extra_output = self.model(token_ids_batch, attention_mask=mask_ids_batch)
if self.pooling_strategy == 'cls':
output = reduce_cls(seq_output.numpy(), mask_ids_batch.numpy(), self.cls_pos)

elif self.pooling_strategy == 'reduce-mean':
elif self.pooling_strategy == 'mean':
output = reduce_mean(seq_output.numpy(), mask_ids_batch.numpy())
elif self.pooling_strategy == 'reduce-max':
elif self.pooling_strategy == 'max':
output = reduce_max(seq_output.numpy(), mask_ids_batch.numpy())
elif self.pooling_strategy == 'min':
output = reduce_min(seq_output.numpy(), mask_ids_batch.numpy())
else:
self.logger.error("pooling strategy not found: {}".format(self.pooling_strategy))
raise NotImplementedError
Expand Down Expand Up @@ -128,7 +137,7 @@ def post_init(self):

import tensorflow as tf
from transformers import TFBertModel, TFOpenAIGPTModel, TFGPT2Model, TFXLNetModel, TFXLMModel, \
TFDistilBertModel, TFRobertaModel, TFXLMRobertaModel
TFDistilBertModel, TFRobertaModel, TFXLMRobertaModel, TFCamembertModel, TFCTRLModel
model_dict = {
'bert-base-uncased': TFBertModel,
'openai-gpt': TFOpenAIGPTModel,
Expand All @@ -138,6 +147,8 @@ def post_init(self):
'distilbert-base-cased': TFDistilBertModel,
'roberta-base': TFRobertaModel,
'xlm-roberta-base': TFXLMRobertaModel,
'camembert-base': TFCamembertModel,
'ctrl': TFCTRLModel
}
self.model = model_dict[self.model_name].from_pretrained(self._tmp_model_path)
self._tensor_func = tf.constant
Expand All @@ -157,7 +168,7 @@ def post_init(self):

import torch
from transformers import BertModel, OpenAIGPTModel, GPT2Model, XLNetModel, XLMModel, DistilBertModel, \
RobertaModel, XLMRobertaModel
RobertaModel, XLMRobertaModel, FlaubertModel, CamembertModel, CTRLModel

model_dict = {
'bert-base-uncased': BertModel,
Expand All @@ -168,6 +179,9 @@ def post_init(self):
'distilbert-base-cased': DistilBertModel,
'roberta-base': RobertaModel,
'xlm-roberta-base': XLMRobertaModel,
'flaubert-base-cased': FlaubertModel,
'camembert-base': CamembertModel,
'ctrl': CTRLModel
}
self.model = model_dict[self.model_name].from_pretrained(self._tmp_model_path)
self._tensor_func = torch.tensor
Expand Down
39 changes: 39 additions & 0 deletions jina/executors/encoders/torchvision.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import numpy as np

from . import BaseNumericEncoder
from ..decorators import batching, as_ndarray


class TorchEncoder(BaseNumericEncoder):
def __init__(self,
model_name: str,
channel_axis: int = 1,
*args, **kwargs):
super().__init__(*args, **kwargs)
self.model_name = model_name
self.channel_axis = channel_axis
self._default_channel_axis = 1

def post_init(self):
import torch
self._build_model()
device = 'cuda:0' if self.on_gpu else 'cpu'
self.model.to(torch.device(device))

@batching
@as_ndarray
def encode(self, data: 'np.ndarray', *args, **kwargs) -> 'np.ndarray':
if self.channel_axis != self._default_channel_axis:
data = np.moveaxis(data, self.channel_axis, self._default_channel_axis)
import torch
feature_map = self._get_features(torch.from_numpy(data.astype('float32'))).detach().numpy()
return self._get_pooling(feature_map)

def _build_model(self):
raise NotImplementedError

def _get_features(self, data):
raise NotImplementedError

def _get_pooling(self, feature_map):
return feature_map
34 changes: 9 additions & 25 deletions jina/executors/encoders/video/torchvision.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,26 @@
import numpy as np

from .. import BaseVideoEncoder
from ..torchvision import TorchEncoder


class TorchVideoEncoder(BaseVideoEncoder):
class VideoTorchEncoder(TorchEncoder):
"""
:class:`TorchVideoEncoder` encodes data from a ndarray, potentially B x T x (Channel x Height x Width) into an
:class:`VideoTorchEncoder` encodes data from a ndarray, potentially B x T x (Channel x Height x Width) into an
ndarray of `B x D`.
Internally, :class:`TorchVideoEncoder` wraps the models from `torchvision.models`.
Internally, :class:`VideoTorchEncoder` wraps the models from `torchvision.models`.
https://pytorch.org/docs/stable/torchvision/models.html
"""
def __init__(self,
model_name: str = 'r3d_18',
*args, **kwargs):
def __init__(self, model_name: str = 'r3d_18', *args, **kwargs):
"""
:param model_name: the name of the model. Supported models include ``r3d_18``, ``mc3_18``, ``r2plus1d_18``
"""
super().__init__(*args, **kwargs)
self.model_name = model_name
super().__init__(model_name, *args, **kwargs)
self._default_channel_axis = 2

def post_init(self):
def _build_model(self):
import torchvision.models.video as models
import torch
model = getattr(models, self.model_name)(pretrained=True)
self.model = model.eval()
device = 'cuda:0' if self.on_gpu else 'cpu'
self.model.to(torch.device(device))

def encode(self, data: 'np.ndarray', *args, **kwargs) -> 'np.ndarray':
"""
:param data: a `B x T x (Channel x Height x Width)` numpy ``ndarray``, `B` is the size of the batch
:return: a `B x D` numpy ``ndarray``, `D` is the output dimension
"""
import torch
return self._get_features(
torch.from_numpy(np.moveaxis(data.astype('float32'), 1, 2))).detach().numpy()
self.model = getattr(models, self.model_name)(pretrained=True).eval()

def _get_features(self, x):
x = self.model.stem(x)
Expand Down
Empty file added tests/executors/__init__.py
Empty file.
Empty file.
Empty file.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,6 @@ def test_dummy_seg(self):
fl.index(raw_bytes=random_docs(10), in_proto=True, callback=self.get_chunk_id)

def test_dummy_seg_random(self):
f = Flow().add(yaml_path='yaml/dummy-seg-random.yml')
f = Flow().add(yaml_path='../../yaml/dummy-seg-random.yml')
with f.build() as fl:
fl.index(raw_bytes=random_docs(10), in_proto=True, callback=self.collect_chunk_id)
Empty file.
Empty file.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,21 @@
import numpy as np

from jina.executors import BaseExecutor
from jina.executors.encoders.image.torchvision import TorchImageEncoder
from jina.executors.encoders.image.torchvision import ImageTorchEncoder
from tests import JinaTestCase


class MyTestCase(JinaTestCase):
@unittest.skipUnless('JINA_TEST_PRETRAINED' in os.environ, 'skip the pretrained test if not set')
def test_encoding_results(self):
encoder = TorchImageEncoder()
encoder = ImageTorchEncoder()
test_data = np.random.rand(2, 3, 224, 224)
encoded_data = encoder.encode(test_data)
self.assertEqual(encoded_data.shape, (2, 1280))

@unittest.skipUnless('JINA_TEST_PRETRAINED' in os.environ, 'skip the pretrained test if not set')
def test_save_and_load(self):
encoder = TorchImageEncoder()
encoder = ImageTorchEncoder()
test_data = np.random.rand(2, 3, 224, 224)
encoded_data_control = encoder.encode(test_data)
encoder.touch()
Expand All @@ -33,7 +33,7 @@ def test_save_and_load(self):

@unittest.skipUnless('JINA_TEST_PRETRAINED' in os.environ, 'skip the pretrained test if not set')
def test_save_and_load_config(self):
encoder = TorchImageEncoder()
encoder = ImageTorchEncoder()
encoder.save_config()
self.assertTrue(os.path.exists(encoder.config_abspath))
encoder_loaded = BaseExecutor.load_config(encoder.config_abspath)
Expand Down
Empty file.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Empty file.
File renamed without changes.
File renamed without changes.
Empty file.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,21 @@
import numpy as np

from jina.executors import BaseExecutor
from jina.executors.encoders.video.torchvision import TorchVideoEncoder
from jina.executors.encoders.video.torchvision import VideoTorchEncoder
from tests import JinaTestCase


class MyTestCase(JinaTestCase):
@unittest.skipIf(os.getenv('JINA_SKIP_TEST_PRETRAINED', True), 'skip the pretrained test if not set')
@unittest.skipUnless('JINA_TEST_PRETRAINED' in os.environ, 'skip the pretrained test if not set')
def test_encoding_results(self):
encoder = TorchVideoEncoder()
encoder = VideoTorchEncoder()
test_data = np.random.rand(2, 3, 3, 112, 112)
encoded_data = encoder.encode(test_data)
self.assertEqual(encoded_data.shape, (2, 512))

@unittest.skipIf(os.getenv('JINA_SKIP_TEST_PRETRAINED', True), 'skip the pretrained test if not set')
@unittest.skipUnless('JINA_TEST_PRETRAINED' in os.environ, 'skip the pretrained test if not set')
def test_save_and_load(self):
encoder = TorchVideoEncoder()
encoder = VideoTorchEncoder()
test_data = np.random.rand(2, 3, 3, 112, 112)
encoded_data_control = encoder.encode(test_data)
encoder.touch()
Expand All @@ -31,9 +31,9 @@ def test_save_and_load(self):
self.add_tmpfile(
encoder.config_abspath, encoder.save_abspath, encoder_loaded.config_abspath, encoder_loaded.save_abspath)

@unittest.skipIf(os.getenv('JINA_SKIP_TEST_PRETRAINED', True), 'skip the pretrained test if not set')
@unittest.skipUnless('JINA_TEST_PRETRAINED' in os.environ, 'skip the pretrained test if not set')
def test_save_and_load_config(self):
encoder = TorchVideoEncoder()
encoder = VideoTorchEncoder()
encoder.save_config()
self.assertTrue(os.path.exists(encoder.config_abspath))
encoder_loaded = BaseExecutor.load_config(encoder.config_abspath)
Expand Down
Empty file.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def test_compositional_dump(self):
self.assertTrue(os.path.exists(a.config_abspath))

def test_compound_from_yaml(self):
a = BaseExecutor.load_config('yaml/npvec.yml')
a = BaseExecutor.load_config('../yaml/npvec.yml')
for c in a.components:
self.add_tmpfile(c.index_abspath)
self.assertTrue(isinstance(a, CompoundExecutor))
Expand Down

0 comments on commit 93382b0

Please sign in to comment.