refactor(encoder): refactoring the paddlehub encoder for nlp
nan-wang committed Apr 5, 2020
1 parent 7a68b83 commit a73690e
Showing 4 changed files with 76 additions and 155 deletions.
146 changes: 0 additions & 146 deletions jina/executors/encoders/nlp/ernie.py

This file was deleted.

64 changes: 64 additions & 0 deletions jina/executors/encoders/nlp/paddlehub.py
@@ -0,0 +1,64 @@
import os

import numpy as np

from .. import BaseTextEncoder
from ...decorators import batching, as_ndarray


class TextPaddlehubEncoder(BaseTextEncoder):
"""
    :class:`TextPaddlehubEncoder` encodes data from an array of strings of size `B` into an ndarray of size `B x D`.
    Internally, :class:`TextPaddlehubEncoder` wraps the text embedding modules (ERNIE, BERT, RoBERTa, etc.) from PaddleHub:
    https://github.com/PaddlePaddle/PaddleHub
"""

def __init__(self,
model_name: str = 'ernie_tiny',
max_length: int = 128,
*args,
**kwargs):
"""
:param model_name: the name of the model. Supported models include
``ernie``, ``ernie_tiny``, ``ernie_v2_eng_base``, ``ernie_v2_eng_large``,
``bert_chinese_L-12_H-768_A-12``, ``bert_multi_cased_L-12_H-768_A-12``,
``bert_multi_uncased_L-12_H-768_A-12``, ``bert_uncased_L-12_H-768_A-12``,
``bert_uncased_L-24_H-1024_A-16``,
``chinese-bert-wwm``, ``chinese-bert-wwm-ext``,
``chinese-electra-base``, ``chinese-electra-small``,
``chinese-roberta-wwm-ext``, ``chinese-roberta-wwm-ext-large``,
            ``rbt3``, ``rbtl3``.
            For details of these models, refer to
            https://www.paddlepaddle.org.cn/hublist?filter=en_category&value=SemanticModel
        :param max_length: the maximum length to truncate the tokenized sequences to.
"""
super().__init__(*args, **kwargs)
self.model_name = model_name
self.max_seq_length = max_length
self.tokenizer = None

def post_init(self):
import paddlehub as hub
self.model = hub.Module(name=self.model_name)
self.model.MAX_SEQ_LEN = self.max_seq_length

@batching
@as_ndarray
    def encode(self, data: 'np.ndarray', *args, **kwargs) -> 'np.ndarray':
        """
        :param data: a 1d array of strings of size `B`
        :return: an ndarray of size `B x D`
        """
        results = []
        # ``get_embedding`` expects a list of text fields per example, so the 1d
        # string array is reshaped into a list of single-element lists.
        _raw_results = self.model.get_embedding(
            texts=np.atleast_2d(data).reshape(-1, 1).tolist(), use_gpu=self.on_gpu, batch_size=data.shape[0])
        for emb in _raw_results:
            # each result is a (pooled_feature, sequence_feature) pair; keep only
            # the pooled sentence-level embedding.
            _pooled_feature, _seq_feature = emb
            results.append(_pooled_feature)
        return np.array(results)

def close(self):
pass
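
For context, a minimal usage sketch of the new encoder (assuming PaddleHub is installed and the ``ernie_tiny`` module can be downloaded; this mirrors the test below):

import numpy as np

from jina.executors.encoders.nlp.paddlehub import TextPaddlehubEncoder

# build the encoder; the PaddleHub module is loaded in post_init by the framework
encoder = TextPaddlehubEncoder(model_name='ernie_tiny', max_length=128)

texts = np.array(['it is a good day!', 'the dog sits on the floor.'])
embeddings = encoder.encode(texts)
print(embeddings.shape)  # (2, D), where D is the model's embedding size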
3 changes: 3 additions & 0 deletions jina/executors/encoders/paddlehub.py
@@ -1,6 +1,7 @@
import numpy as np

from . import BaseNumericEncoder
+from ..decorators import batching, as_ndarray


class PaddlehubEncoder(BaseNumericEncoder):
@@ -31,6 +32,8 @@ def post_init(self):
def get_inputs_and_outputs_name(self, input_dict, output_dict):
raise NotImplementedError

+    @batching
+    @as_ndarray
def encode(self, data: 'np.ndarray', *args, **kwargs) -> 'np.ndarray':
"""
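
Both ``encode`` methods are now wrapped by the ``batching`` and ``as_ndarray`` decorators from ``jina.executors.decorators``. Roughly (a simplified sketch, not Jina's actual implementation): ``batching`` splits the input into mini-batches and concatenates the per-batch results, while ``as_ndarray`` coerces whatever ``encode`` returns into a ``numpy.ndarray``:

import functools

import numpy as np

def batching(func, batch_size=32):
    # simplified stand-in: the real decorator also resolves the batch size from
    # the executor's attributes and supports different slicing axes
    @functools.wraps(func)
    def wrapper(self, data, *args, **kwargs):
        parts = [func(self, data[i:i + batch_size], *args, **kwargs)
                 for i in range(0, len(data), batch_size)]
        return np.concatenate(parts, axis=0)
    return wrapper

def as_ndarray(func):
    # simplified stand-in: ensure the decorated function returns an ndarray
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        return np.asarray(func(self, *args, **kwargs))
    return wrapper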
18 changes: 9 additions & 9 deletions tests/test_exec_encoder_ernie.py
@@ -4,23 +4,22 @@
import numpy as np

from jina.executors import BaseExecutor
-from jina.executors.encoders.nlp.ernie import ErnieTextEncoder
+from jina.executors.encoders.nlp.paddlehub import TextPaddlehubEncoder
from tests import JinaTestCase


class MyTestCase(JinaTestCase):
@unittest.skipUnless('JINA_TEST_PRETRAINED' in os.environ, 'skip the pretrained test if not set')
def test_encoding_results(self):
-        encoder = ErnieTextEncoder(max_length=10, workspace=os.environ['TEST_WORKDIR'])
+        encoder = TextPaddlehubEncoder(max_length=10, workspace=os.environ['TEST_WORKDIR'])
test_data = np.array(['it is a good day!', 'the dog sits on the floor.'])
encoded_data = encoder.encode(test_data)
self.assertEqual(encoded_data.shape[0], 2)
self.assertIs(type(encoded_data), np.ndarray)
-        self.add_tmpfile(encoder.vocab_abspath)

@unittest.skipUnless('JINA_TEST_PRETRAINED' in os.environ, 'skip the pretrained test if not set')
def test_save_and_load(self):
-        encoder = ErnieTextEncoder(
+        encoder = TextPaddlehubEncoder(
max_length=10, workspace=os.environ['TEST_WORKDIR'])
encoder.save_config()
self.assertTrue(os.path.exists(encoder.config_abspath))
@@ -33,22 +32,23 @@ def test_save_and_load(self):
encoder_loaded = BaseExecutor.load(encoder.save_abspath)
encoded_data_test = encoder_loaded.encode(test_data)

-        self.assertEqual(encoder_loaded.vocab_abspath, encoder.vocab_abspath)
+        self.assertEqual(encoder_loaded.max_seq_length, encoder.max_seq_length)
np.testing.assert_array_equal(encoded_data_control, encoded_data_test)

self.add_tmpfile(
-            encoder.config_abspath, encoder.save_abspath, encoder_loaded.config_abspath, encoder_loaded.save_abspath, encoder.vocab_abspath)
+            encoder.config_abspath, encoder.save_abspath, encoder_loaded.config_abspath, encoder_loaded.save_abspath)

@unittest.skipUnless('JINA_TEST_PRETRAINED' in os.environ, 'skip the pretrained test if not set')
def test_save_and_load_config(self):
-        encoder = ErnieTextEncoder(
+        encoder = TextPaddlehubEncoder(
max_length=10, workspace=os.environ['TEST_WORKDIR'])
encoder.save_config()
self.assertTrue(os.path.exists(encoder.config_abspath))

encoder_loaded = BaseExecutor.load_config(encoder.config_abspath)
-        self.assertEqual(encoder_loaded.vocab_abspath, encoder.vocab_abspath)
+        self.assertEqual(encoder_loaded.max_seq_length, encoder.max_seq_length)

-        self.add_tmpfile(encoder_loaded.config_abspath, encoder_loaded.save_abspath, encoder.vocab_abspath)
+        self.add_tmpfile(encoder_loaded.config_abspath, encoder_loaded.save_abspath)


if __name__ == '__main__':
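
To run these tests locally (a sketch; the pretrained tests are skipped unless ``JINA_TEST_PRETRAINED`` is set, and the workspace path below is hypothetical):

import os
import unittest

os.environ['JINA_TEST_PRETRAINED'] = '1'            # opt in to the pretrained-model tests
os.environ['TEST_WORKDIR'] = '/tmp/jina_test_work'  # hypothetical scratch workspace

suite = unittest.defaultTestLoader.discover('tests', pattern='test_exec_encoder_ernie.py')
unittest.TextTestRunner(verbosity=2).run(suite)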
