
fix: fix transformers on gpu
nan-wang committed May 6, 2020
1 parent 8898288 commit 1ddf248
Showing 3 changed files with 40 additions and 17 deletions.
33 changes: 26 additions & 7 deletions jina/executors/encoders/nlp/transformer.py
@@ -95,21 +95,23 @@ def encode(self, data: 'np.ndarray', *args, **kwargs) -> 'np.ndarray':
             mask_ids = [0 if t == self.tokenizer.pad_token_id else 1 for t in token_ids]
             token_ids_batch.append(token_ids)
             mask_ids_batch.append(mask_ids)
-        token_ids_batch = self._tensor_func(token_ids_batch)
-        mask_ids_batch = self._tensor_func(mask_ids_batch)
+        token_ids_batch = self.array2tensor(token_ids_batch)
+        mask_ids_batch = self.array2tensor(mask_ids_batch)
         with self._sess_func():
             seq_output, *extra_output = self.model(token_ids_batch, attention_mask=mask_ids_batch)
+            _mask_ids_batch = self.tensor2array(mask_ids_batch)
+            _seq_output = self.tensor2array(seq_output)
             if self.pooling_strategy == 'cls':
                 if self.model_name in ('bert-base-uncased', 'roberta-base'):
-                    output = extra_output[0].numpy()
+                    output = self.tensor2array(extra_output[0])
                 else:
-                    output = reduce_cls(seq_output.numpy(), mask_ids_batch.numpy(), self.cls_pos)
+                    output = reduce_cls(_seq_output, _mask_ids_batch, self.cls_pos)
             elif self.pooling_strategy == 'mean':
-                output = reduce_mean(seq_output.numpy(), mask_ids_batch.numpy())
+                output = reduce_mean(_seq_output, _mask_ids_batch)
             elif self.pooling_strategy == 'max':
-                output = reduce_max(seq_output.numpy(), mask_ids_batch.numpy())
+                output = reduce_max(_seq_output, _mask_ids_batch)
             elif self.pooling_strategy == 'min':
-                output = reduce_min(seq_output.numpy(), mask_ids_batch.numpy())
+                output = reduce_min(_seq_output, _mask_ids_batch)
             else:
                 self.logger.error("pooling strategy not found: {}".format(self.pooling_strategy))
                 raise NotImplementedError
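
Note: the pooling branches now consume plain numpy arrays (_seq_output, _mask_ids_batch) that are converted from tensors exactly once, instead of calling .numpy() inside every branch — a call that fails outright on CUDA tensors. For orientation, here is a hedged numpy sketch of what a masked mean-pooling helper such as reduce_mean plausibly computes; the real helper lives elsewhere in jina and may differ in detail:

import numpy as np

def reduce_mean(seq_output, mask_ids):
    # seq_output: (batch, seq_len, hidden); mask_ids: (batch, seq_len),
    # 1 for real tokens and 0 for padding (as built in encode() above).
    mask = np.expand_dims(mask_ids, axis=-1)
    summed = (seq_output * mask).sum(axis=1)   # zero out padded positions
    counts = np.maximum(mask.sum(axis=1), 1)   # guard against all-pad rows
    return summed / counts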
@@ -137,6 +139,12 @@ def build_model(self):
     def _build_model(self):
         raise NotImplementedError
 
+    def array2tensor(self, array):
+        return self._tensor_func(array)
+
+    def tensor2array(self, tensor):
+        return tensor.numpy()
+
 
 class TransformerTFEncoder(BaseTFExecutor, BaseTransformerEncoder):
     """
@@ -193,3 +201,14 @@ def _build_model(self):
         self._sess_func = torch.no_grad
         if self.model_name in ('xlnet-base-cased', 'openai-gpt', 'gpt2', 'xlm-mlm-enfr-1024'):
             self.model.resize_token_embeddings(len(self.tokenizer))
+
+    def array2tensor(self, array):
+        tensor = super().array2tensor(array)
+        if self.on_gpu:
+            tensor = tensor.cuda()
+        return tensor
+
+    def tensor2array(self, tensor):
+        if self.on_gpu:
+            tensor = tensor.cpu()
+        return tensor.numpy()
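
Note: this override pair is the core of the fix — inputs are moved onto the GPU before the forward pass, and outputs are moved back to the CPU before .numpy(), which raises a TypeError when called on a CUDA tensor. A minimal, self-contained sketch of the same round-trip; TorchDeviceBridge is a hypothetical stand-in for the encoder, not jina code:

import numpy as np
import torch

class TorchDeviceBridge:
    def __init__(self, on_gpu):
        # Only take the CUDA path when a device is actually present.
        self.on_gpu = on_gpu and torch.cuda.is_available()

    def array2tensor(self, array):
        # Build on CPU, then move to the GPU so inputs live on the same
        # device as the model weights.
        tensor = torch.tensor(array)
        return tensor.cuda() if self.on_gpu else tensor

    def tensor2array(self, tensor):
        # .numpy() only works on CPU tensors; move off the GPU first.
        if self.on_gpu:
            tensor = tensor.cpu()
        return tensor.numpy()

bridge = TorchDeviceBridge(on_gpu=True)
t = bridge.array2tensor([[1, 2, 3]])
a = bridge.tensor2array(t)
assert isinstance(a, np.ndarray)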
8 changes: 4 additions & 4 deletions tests/executors/encoders/nlp/__init__.py
@@ -4,10 +4,10 @@
 import numpy as np
 
 from jina.executors import BaseExecutor
-from tests import JinaTestCase
+from tests.executors import ExecutorTestCase
 
 
-class NlpTestCase(JinaTestCase):
+class NlpTestCase(ExecutorTestCase):
     @property
     def workspace(self):
         return os.path.join(os.environ['TEST_WORKDIR'], 'test_tmp')
@@ -29,12 +29,12 @@ def input_dim(self, input_dim):
         self._input_dim = input_dim
 
     def get_encoder(self):
-        encoder = self._get_encoder()
+        encoder = self._get_encoder(self.metas)
         encoder.workspace = self.workspace
         self.add_tmpfile(encoder.workspace)
         return encoder
 
-    def _get_encoder(self):
+    def _get_encoder(self, metas):
         raise NotImplementedError
 
     @unittest.skipUnless('JINA_TEST_PRETRAINED' in os.environ, 'skip the pretrained test if not set')
16 changes: 10 additions & 6 deletions tests/executors/encoders/nlp/test_transformer.py
@@ -5,15 +5,19 @@


 class PytorchTestCase(NlpTestCase):
-    def _get_encoder(self):
-        encoder = TransformerTorchEncoder(model_name='bert-base-uncased', pooling_strategy='cls')
-        return encoder
+    def _get_encoder(self, metas):
+        return TransformerTorchEncoder(
+            model_name='bert-base-uncased',
+            pooling_strategy='cls',
+            metas=metas)
 
 
 class TfTestCase(NlpTestCase):
-    def _get_encoder(self):
-        encoder = TransformerTFEncoder(model_name='bert-base-uncased', pooling_strategy='cls')
-        return encoder
+    def _get_encoder(self, metas):
+        return TransformerTFEncoder(
+            model_name='bert-base-uncased',
+            pooling_strategy='cls',
+            metas=metas)
 
 
 if __name__ == '__main__':
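Note: threading metas through _get_encoder lets the test cases construct encoders with executor-level settings rather than bare defaults. A hypothetical example of what this enables — the dict shape and the 'on_gpu' key are assumptions, not confirmed by this diff:

# Hypothetical: exercise the GPU path via executor metas, mirroring how
# NlpTestCase.get_encoder() now calls _get_encoder(self.metas).
encoder = TransformerTorchEncoder(
    model_name='bert-base-uncased',
    pooling_strategy='cls',
    metas={'on_gpu': True})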
