fix(executor): add tf/pytorch detect in transformertextencoder
Xiong Ma committed Apr 3, 2020
1 parent 7d4e714 commit be1a4ec
Showing 2 changed files with 56 additions and 38 deletions.
jina/executors/encoders/nlp/transformer.py (55 additions, 37 deletions)

@@ -17,7 +17,6 @@ def __init__(self,
                  pooling_strategy: str = 'reduce-mean',
                  max_length: int = 64,
                  encoder_abspath: str = '',
-                 backend: str = 'pytorch',
                  *args, **kwargs):
         """
@@ -28,46 +27,74 @@
         :param max_length: the max length to truncate the tokenized sequences to.
         :param encoder_abspath: the absolute saving path of the encoder. If a valid path is given, the encoder will be
             loaded from the given path.
-        :param backend: whether use tensorflow to load pretraining model, just support tensorflow or pytorch to load
-            model
         """

         super().__init__(*args, **kwargs)
         self.model_name = model_name
         self.pooling_strategy = pooling_strategy
         self.model = None
         self.tokenizer = None
-        if backend not in ('tensorflow', 'pytorch'):
-            raise ValueError('unknown backend: {}'.format(backend))
-        self.backend = backend
         self.max_length = max_length
         self.encoder_abspath = encoder_abspath

     def post_init(self):
-        from transformers import BertModel, BertTokenizer, OpenAIGPTModel, \
-            OpenAIGPTTokenizer, GPT2Model, GPT2Tokenizer, \
-            XLNetModel, XLNetTokenizer, XLMModel, \
-            XLMTokenizer, DistilBertModel, DistilBertTokenizer, RobertaModel, \
-            RobertaTokenizer, XLMRobertaModel, XLMRobertaTokenizer, TFBertModel, \
-            TFOpenAIGPTModel, TFGPT2Model, TFXLNetModel, TFXLMModel, TFDistilBertModel, \
-            TFRobertaModel, TFXLMRobertaModel
-
-        model_dict = {
-            'bert-base-uncased': (TFBertModel, BertModel, BertTokenizer),
-            'openai-gpt': (TFOpenAIGPTModel, OpenAIGPTModel, OpenAIGPTTokenizer),
-            'gpt2': (TFGPT2Model, GPT2Model, GPT2Tokenizer),
-            'xlnet-base-cased': (TFXLNetModel, XLNetModel, XLNetTokenizer),
-            'xlm-mlm-enfr-1024': (TFXLMModel, XLMModel, XLMTokenizer),
-            'distilbert-base-cased': (TFDistilBertModel, DistilBertModel, DistilBertTokenizer),
-            'roberta-base': (TFRobertaModel, RobertaModel, RobertaTokenizer),
-            'xlm-roberta-base': (TFXLMRobertaModel, XLMRobertaModel, XLMRobertaTokenizer)
+        from transformers import BertTokenizer, OpenAIGPTTokenizer, GPT2Tokenizer, \
+            XLNetTokenizer, XLMTokenizer, DistilBertTokenizer, RobertaTokenizer, XLMRobertaTokenizer
+
+        tokenizer_dict = {
+            'bert-base-uncased': BertTokenizer,
+            'openai-gpt': OpenAIGPTTokenizer,
+            'gpt2': GPT2Tokenizer,
+            'xlnet-base-cased': XLNetTokenizer,
+            'xlm-mlm-enfr-1024': XLMTokenizer,
+            'distilbert-base-cased': DistilBertTokenizer,
+            'roberta-base': RobertaTokenizer,
+            'xlm-roberta-base': XLMRobertaTokenizer
         }

-        if self.model_name not in model_dict:
-            self.logger.error('{} not in our supports: {}'.format(self.model_name, ','.join(model_dict.keys())))
+        if self.model_name not in tokenizer_dict:
+            self.logger.error('{} not in our supports: {}'.format(self.model_name, ','.join(tokenizer_dict.keys())))
             raise ValueError

-        tf_model_class, model_class, tokenizer_class = model_dict[self.model_name]
+        try:
+            import tensorflow as tf
+            from transformers import TFBertModel, TFOpenAIGPTModel, TFGPT2Model, TFXLNetModel, TFXLMModel, \
+                TFDistilBertModel, TFRobertaModel, TFXLMRobertaModel
+            tf_model_dict = {
+                'bert-base-uncased': TFBertModel,
+                'openai-gpt': TFOpenAIGPTModel,
+                'gpt2': TFGPT2Model,
+                'xlnet-base-cased': TFXLNetModel,
+                'xlm-mlm-enfr-1024': TFXLMModel,
+                'distilbert-base-cased': TFDistilBertModel,
+                'roberta-base': TFRobertaModel,
+                'xlm-roberta-base': TFXLMRobertaModel,
+            }
+            model_class = tf_model_dict[self.model_name]
+            self._tensor_func = tf.constant
+            self._sess_func = tf.GradientTape
+
+        except:
+            try:
+                import torch
+                from transformers import BertModel, OpenAIGPTModel, GPT2Model, XLNetModel, XLMModel, DistilBertModel, \
+                    RobertaModel, XLMRobertaModel
+
+                model_dict = {
+                    'bert-base-uncased': BertModel,
+                    'openai-gpt': OpenAIGPTModel,
+                    'gpt2': GPT2Model,
+                    'xlnet-base-cased': XLNetModel,
+                    'xlm-mlm-enfr-1024': XLMModel,
+                    'distilbert-base-cased': DistilBertModel,
+                    'roberta-base': RobertaModel,
+                    'xlm-roberta-base': XLMRobertaModel,
+                }
+                model_class = model_dict[self.model_name]
+                self._tensor_func = torch.tensor
+                self._sess_func = torch.no_grad
+            except:
+                raise ModuleNotFoundError('Tensorflow or Pytorch is required!')

         if self.encoder_abspath:
             if not os.path.exists(self.encoder_abspath):
@@ -78,18 +105,9 @@ def post_init(self):
             else:
                 tmp = self.model_name

+        self.model = model_class.from_pretrained(tmp)
+        tokenizer_class = tokenizer_dict[self.model_name]
         self.tokenizer = tokenizer_class.from_pretrained(tmp)
-        if self.backend == 'tensorflow':
-            import tensorflow as tf
-            self.model = tf_model_class.from_pretrained(tmp)
-            self._tensor_func = tf.constant
-            self._sess_func = tf.GradientTape
-
-        else:
-            import torch
-            self.model = model_class.from_pretrained(tmp)
-            self._tensor_func = torch.tensor
-            self._sess_func = torch.no_grad

         self.tokenizer.padding_side = 'right'
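The pattern this commit introduces is import-based backend detection: try the TensorFlow model classes first, fall back to PyTorch, and fail only when neither framework is installed. A minimal standalone sketch of that pattern (detect_backend is an illustrative name, not part of the Jina codebase; the commit itself uses bare except: clauses where this sketch narrows to ImportError):

def detect_backend():
    # Prefer TensorFlow if it imports, otherwise fall back to PyTorch.
    try:
        import tensorflow  # noqa: F401, availability check only
        return 'tensorflow'
    except ImportError:
        try:
            import torch  # noqa: F401
            return 'pytorch'
        except ImportError:
            raise ModuleNotFoundError('Tensorflow or Pytorch is required!')

The _tensor_func/_sess_func pair stored next to the model class pins down the framework-specific calls the rest of the encoder needs. The encode body is not part of this diff, so the following consumption pattern is an assumption, not the executor's actual code:

ids = self._tensor_func(token_ids)  # tf.constant(...) or torch.tensor(...)
with self._sess_func():             # tf.GradientTape() or torch.no_grad()
    output = self.model(ids)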
tests/test_exec_encoder_transformer.py (1 addition, 1 deletion)

@@ -19,7 +19,7 @@ def test_pytorch_encoding_results(self):

     @unittest.skipUnless('JINA_TEST_PRETRAINED' in os.environ, 'skip the pretrained test if not set')
     def test_tf_encoding_results(self):
-        encoder = TransformerTextEncoder(model_name='bert-base-uncased', backend='tensorflow')
+        encoder = TransformerTextEncoder(model_name='bert-base-uncased')
         test_data = np.array(['a', 'b', 'xy'])
         encoded_data = encoder.encode(test_data)
         self.assertEqual(encoded_data.shape[0], 3)
