diff --git a/jina/executors/encoders/nlp/transformer.py b/jina/executors/encoders/nlp/transformer.py index 95cd5614e80c5..5ec29e7f8a720 100644 --- a/jina/executors/encoders/nlp/transformer.py +++ b/jina/executors/encoders/nlp/transformer.py @@ -57,42 +57,42 @@ def post_init(self): raise ValueError try: - import tensorflow as tf - from transformers import TFBertModel, TFOpenAIGPTModel, TFGPT2Model, TFXLNetModel, TFXLMModel, \ - TFDistilBertModel, TFRobertaModel, TFXLMRobertaModel - tf_model_dict = { - 'bert-base-uncased': TFBertModel, - 'openai-gpt': TFOpenAIGPTModel, - 'gpt2': TFGPT2Model, - 'xlnet-base-cased': TFXLNetModel, - 'xlm-mlm-enfr-1024': TFXLMModel, - 'distilbert-base-cased': TFDistilBertModel, - 'roberta-base': TFRobertaModel, - 'xlm-roberta-base': TFXLMRobertaModel, + import torch + from transformers import BertModel, OpenAIGPTModel, GPT2Model, XLNetModel, XLMModel, DistilBertModel, \ + RobertaModel, XLMRobertaModel + + model_dict = { + 'bert-base-uncased': BertModel, + 'openai-gpt': OpenAIGPTModel, + 'gpt2': GPT2Model, + 'xlnet-base-cased': XLNetModel, + 'xlm-mlm-enfr-1024': XLMModel, + 'distilbert-base-cased': DistilBertModel, + 'roberta-base': RobertaModel, + 'xlm-roberta-base': XLMRobertaModel, } - model_class = tf_model_dict[self.model_name] - self._tensor_func = tf.constant - self._sess_func = tf.GradientTape + model_class = model_dict[self.model_name] + self._tensor_func = torch.tensor + self._sess_func = torch.no_grad except: try: - import torch - from transformers import BertModel, OpenAIGPTModel, GPT2Model, XLNetModel, XLMModel, DistilBertModel, \ - RobertaModel, XLMRobertaModel - - model_dict = { - 'bert-base-uncased': BertModel, - 'openai-gpt': OpenAIGPTModel, - 'gpt2': GPT2Model, - 'xlnet-base-cased': XLNetModel, - 'xlm-mlm-enfr-1024': XLMModel, - 'distilbert-base-cased': DistilBertModel, - 'roberta-base': RobertaModel, - 'xlm-roberta-base': XLMRobertaModel, + import tensorflow as tf + from transformers import TFBertModel, TFOpenAIGPTModel, TFGPT2Model, TFXLNetModel, TFXLMModel, \ + TFDistilBertModel, TFRobertaModel, TFXLMRobertaModel + tf_model_dict = { + 'bert-base-uncased': TFBertModel, + 'openai-gpt': TFOpenAIGPTModel, + 'gpt2': TFGPT2Model, + 'xlnet-base-cased': TFXLNetModel, + 'xlm-mlm-enfr-1024': TFXLMModel, + 'distilbert-base-cased': TFDistilBertModel, + 'roberta-base': TFRobertaModel, + 'xlm-roberta-base': TFXLMRobertaModel, } - model_class = model_dict[self.model_name] - self._tensor_func = torch.tensor - self._sess_func = torch.no_grad + model_class = tf_model_dict[self.model_name] + self._tensor_func = tf.constant + self._sess_func = tf.GradientTape except: raise ModuleNotFoundError('Tensorflow or Pytorch is required!')