Skip to content

Commit

Permalink
fix: rename the classes
Browse files Browse the repository at this point in the history
  • Loading branch information
nan-wang committed May 5, 2020
1 parent 0052ccd commit a40b185
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 10 deletions.
25 changes: 21 additions & 4 deletions jina/executors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,20 +560,37 @@ def __call__(self, req_type, *args, **kwargs):
raise NoDriverForRequest(req_type)


class BaseFramewordExecutor(BaseExecutor):
class BaseFrameworkExecutor(BaseExecutor):
"""
:class:`BaseFrameworkExecutor` is the base class for the executors using other frameworks internally, including
`tensorflow`, `pytorch`, `onnx`, and, `paddlepaddle`.
..notes:
The derived classes must implement `build_model()` and `set_device()` methods.
"""
def post_init(self):
super().post_init()
self.build_model()
self.set_device()

def build_model(self):
"""
Build the model with the framework set by `self._backend`.
"""
raise NotImplementedError

def set_device(self):
"""
Set the device on which the model will be executed.
..notes:
In the case of using GPUs, we only use the first gpu from the visible gpus. To specify which gpu to use,
please use the environment variable `CUDA_VISIBLE_DEVICES`.
"""
raise NotImplementedError


class BaseTorchExecutor(BaseFramewordExecutor):
class BaseTorchExecutor(BaseFrameworkExecutor):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._backend = 'pytorch'
Expand All @@ -591,7 +608,7 @@ def set_device(self):
self.model.set_providers(self._device)


class BaseTfExecutor(BaseFramewordExecutor):
class BaseTFExecutor(BaseFrameworkExecutor):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._backend = 'tensorflow'
Expand All @@ -601,7 +618,7 @@ def set_device(self):
tf.config.experimental.set_visible_devices(self._device)


class BasePaddleExecutor(BaseFramewordExecutor):
class BasePaddleExecutor(BaseFrameworkExecutor):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._backend = 'paddlepaddle'
Expand Down
4 changes: 2 additions & 2 deletions jina/executors/encoders/image/tfkeras.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

from .. import BaseImageEncoder
from ...decorators import batching, as_ndarray
from ... import BaseTfExecutor
from ... import BaseTFExecutor


class KerasImageEncoder(BaseImageEncoder, BaseTfExecutor):
class KerasImageEncoder(BaseImageEncoder, BaseTFExecutor):
"""
:class:`KerasImageEncoder` encodes data from a ndarray, potentially B x (Channel x Height x Width) into a
ndarray of `B x D`.
Expand Down
6 changes: 3 additions & 3 deletions jina/executors/encoders/nlp/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
from .. import BaseTextEncoder
from ..helper import reduce_mean, reduce_max, reduce_min, reduce_cls
from ...decorators import batching, as_ndarray
from ... import BaseFramewordExecutor, BaseTfExecutor, BaseTorchExecutor
from ... import BaseFrameworkExecutor, BaseTFExecutor, BaseTorchExecutor


class BaseTransformerEncoder(BaseFramewordExecutor):
class BaseTransformerEncoder(BaseFrameworkExecutor):
"""
:class:`TransformerTextEncoder` encodes data from an array of string in size `B` into an ndarray in size `B x D`.
"""
Expand Down Expand Up @@ -138,7 +138,7 @@ def _build_model(self):
raise NotImplementedError


class TransformerTFEncoder(BaseTfExecutor, BaseTransformerEncoder):
class TransformerTFEncoder(BaseTFExecutor, BaseTransformerEncoder):
"""
Internally, TransformerTFEncoder wraps the tensorflow-version of transformers from huggingface.
"""
Expand Down
3 changes: 2 additions & 1 deletion tests/executors/encoders/image/test_paddlehub.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
class MyTestCase(ImageTestCase):
def _get_encoder(self):
self.target_output_dim = 2048
return ImagePaddlehubEncoder()
self.input_dim = 224
return ImagePaddlehubEncoder(on_gpu=True)


if __name__ == '__main__':
Expand Down
1 change: 1 addition & 0 deletions tests/executors/encoders/image/test_tfkeras.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
class MyTestCase(ImageTestCase):
def _get_encoder(self):
self.target_output_dim = 1280
self.input_dim = 224
return KerasImageEncoder(channel_axis=1)


Expand Down

0 comments on commit a40b185

Please sign in to comment.