Skip to content

Commit

Permalink
fix(executors): fix the mro issue
Browse files Browse the repository at this point in the history
  • Loading branch information
nan-wang committed May 8, 2020
1 parent 96e9294 commit bda92d8
Show file tree
Hide file tree
Showing 11 changed files with 80 additions and 60 deletions.
100 changes: 57 additions & 43 deletions jina/executors/encoders/frameworks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,15 @@

from . import BaseNumericEncoder
from ..decorators import batching, as_ndarray
from ..frameworks import BaseOnnxExecutor, BasePaddleExecutor, BaseTorchExecutor
from ..frameworks import BaseOnnxExecutor, BasePaddleExecutor, BaseTorchExecutor, BaseTFExecutor, BaseFrameworkExecutor
from ...helper import is_url


class BaseOnnxEncoder(BaseOnnxExecutor):
class BaseFrameworkEncoder(BaseFrameworkExecutor, BaseNumericEncoder):
pass


class BaseOnnxEncoder(BaseOnnxExecutor, BaseFrameworkEncoder):

def __init__(self, output_feature: str, model_path: str = None, *args, **kwargs):
"""
Expand Down Expand Up @@ -56,7 +60,55 @@ def _append_outputs(input_fn, outputs_name_to_append, output_fn):
onnx.save(model, output_fn)


class BaseCVPaddlehubEncoder(BasePaddleExecutor, BaseNumericEncoder):
class BaseTorchEncoder(BaseTorchExecutor, BaseFrameworkEncoder):
""""
:class:`BaseTorchEncoder` implements the common part for :class:`ImageTorchEncoder` and :class:`VideoTorchEncoder`.
..warning::
:class:`BaseTorchEncoder` is not intented to be used to do the real encoding.
"""

def __init__(self,
model_name: str = '',
channel_axis: int = 1,
*args, **kwargs):
super().__init__(*args, **kwargs)
self.model_name = model_name
self.channel_axis = channel_axis
self._default_channel_axis = 1

@batching
@as_ndarray
def encode(self, data: 'np.ndarray', *args, **kwargs) -> 'np.ndarray':
import numpy as np
if self.channel_axis != self._default_channel_axis:
data = np.moveaxis(data, self.channel_axis, self._default_channel_axis)
import torch
_input = torch.from_numpy(data.astype('float32'))
if self.on_gpu:
_input = _input.cuda()
_feature = self._get_features(_input).detach()
if self.on_gpu:
_feature = _feature.cpu()
_feature = _feature.numpy()
return self._get_pooling(_feature)

def _get_features(self, data):
raise NotImplementedError

def _get_pooling(self, feature_map):
return feature_map


class BaseTFEncoder(BaseTFExecutor, BaseFrameworkEncoder):
pass


class BasePaddlehubEncoder(BasePaddleExecutor, BaseFrameworkEncoder):
pass


class BaseCVPaddlehubEncoder(BasePaddlehubEncoder):
"""
:class:`BaseCVPaddlehubEncoder` implements the common parts for :class:`ImagePaddlehubEncoder` and
:class:`VideoPaddlehubEncoder`.
Expand All @@ -66,8 +118,8 @@ class BaseCVPaddlehubEncoder(BasePaddleExecutor, BaseNumericEncoder):
"""

def __init__(self,
model_name: str,
output_feature: str,
model_name: str = None,
output_feature: str = None,
pool_strategy: str = None,
channel_axis: int = -3,
*args,
Expand Down Expand Up @@ -119,41 +171,3 @@ def get_pooling(self, data: 'np.ndarray', axis=None) -> 'np.ndarray':
return getattr(np, self.pool_strategy)(data, axis=_reduce_axis)


class BaseTorchEncoder(BaseTorchExecutor):
""""
:class:`BaseTorchEncoder` implements the common part for :class:`ImageTorchEncoder` and :class:`VideoTorchEncoder`.
..warning::
:class:`BaseTorchEncoder` is not intented to be used to do the real encoding.
"""

def __init__(self,
model_name: str,
channel_axis: int = 1,
*args, **kwargs):
super().__init__(*args, **kwargs)
self.model_name = model_name
self.channel_axis = channel_axis
self._default_channel_axis = 1

@batching
@as_ndarray
def encode(self, data: 'np.ndarray', *args, **kwargs) -> 'np.ndarray':
import numpy as np
if self.channel_axis != self._default_channel_axis:
data = np.moveaxis(data, self.channel_axis, self._default_channel_axis)
import torch
_input = torch.from_numpy(data.astype('float32'))
if self.on_gpu:
_input = _input.cuda()
_feature = self._get_features(_input).detach()
if self.on_gpu:
_feature = _feature.cpu()
_feature = _feature.numpy()
return self._get_pooling(_feature)

def _get_features(self, data):
raise NotImplementedError

def _get_pooling(self, feature_map):
return feature_map
3 changes: 2 additions & 1 deletion jina/executors/encoders/image/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@

import numpy as np

from .. import BaseImageEncoder
from ..frameworks import BaseOnnxEncoder
from ...decorators import batching, as_ndarray


class OnnxImageEncoder(BaseOnnxEncoder):
class OnnxImageEncoder(BaseImageEncoder, BaseOnnxEncoder):
"""
:class:`OnnxImageEncoder` encodes data from a ndarray, potentially B x (Channel x Height x Width) into a
ndarray of `B x D`.
Expand Down
3 changes: 2 additions & 1 deletion jina/executors/encoders/image/paddlehub.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
__copyright__ = "Copyright (c) 2020 Jina AI Limited. All rights reserved."
__license__ = "Apache-2.0"

from .. import BaseImageEncoder
from ..frameworks import BaseCVPaddlehubEncoder


class ImagePaddlehubEncoder(BaseCVPaddlehubEncoder):
class ImagePaddlehubEncoder(BaseImageEncoder, BaseCVPaddlehubEncoder):
"""
:class:`ImagePaddlehubEncoder` encodes data from a ndarray, potentially B x (Channel x Height x Width) into a
ndarray of `B x D`.
Expand Down
4 changes: 2 additions & 2 deletions jina/executors/encoders/image/tfkeras.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

from .. import BaseImageEncoder
from ...decorators import batching, as_ndarray
from ...frameworks import BaseTFExecutor
from ..frameworks import BaseTFEncoder


class KerasImageEncoder(BaseTFExecutor, BaseImageEncoder):
class KerasImageEncoder(BaseImageEncoder, BaseTFEncoder):
"""
:class:`KerasImageEncoder` encodes data from a ndarray, potentially B x (Channel x Height x Width) into a
ndarray of `B x D`.
Expand Down
3 changes: 2 additions & 1 deletion jina/executors/encoders/image/torchvision.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@

import numpy as np

from .. import BaseImageEncoder
from ..frameworks import BaseTorchEncoder


class ImageTorchEncoder(BaseTorchEncoder):
class ImageTorchEncoder(BaseImageEncoder, BaseTorchEncoder):
"""
:class:`ImageTorchEncoder` encodes data from a ndarray, potentially B x (Channel x Height x Width) into a
ndarray of `B x D`.
Expand Down
4 changes: 2 additions & 2 deletions jina/executors/encoders/nlp/farm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

from .. import BaseTextEncoder
from ...decorators import batching, as_ndarray
from ...frameworks import BaseTorchExecutor
from ..frameworks import BaseTorchEncoder


class FarmTextEncoder(BaseTorchExecutor, BaseTextEncoder):
class FarmTextEncoder(BaseTextEncoder, BaseTorchEncoder):
"""FARM-based text encoder: (Framework for Adapting Representation Models)
https://github.com/deepset-ai/FARM
Expand Down
4 changes: 2 additions & 2 deletions jina/executors/encoders/nlp/flair.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@

from .. import BaseTextEncoder
from ...decorators import batching, as_ndarray
from ...frameworks import BaseTorchExecutor
from ..frameworks import BaseTorchEncoder


class FlairTextEncoder(BaseTorchExecutor, BaseTextEncoder):
class FlairTextEncoder(BaseTextEncoder, BaseTorchEncoder):
"""
:class:`FlairTextEncoder` encodes data from an array of string in size `B` into a ndarray in size `B x D`.
Internally, :class:`FlairTextEncoder` wraps the DocumentPoolEmbeddings from Flair.
Expand Down
4 changes: 2 additions & 2 deletions jina/executors/encoders/nlp/paddlehub.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

from .. import BaseTextEncoder
from ...decorators import batching, as_ndarray
from ...frameworks import BasePaddleExecutor
from ..frameworks import BasePaddlehubEncoder


class TextPaddlehubEncoder(BasePaddleExecutor, BaseTextEncoder):
class TextPaddlehubEncoder(BaseTextEncoder, BasePaddlehubEncoder):
"""
:class:`TextPaddlehubEncoder` encodes data from an array of string in size `B` into a ndarray in size `B x D`.
Internally, :class:`TextPaddlehubEncoder` wraps the Ernie module from paddlehub.
Expand Down
8 changes: 4 additions & 4 deletions jina/executors/encoders/nlp/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@

from ..helper import reduce_mean, reduce_max, reduce_min, reduce_cls
from ...decorators import batching, as_ndarray
from ...frameworks import BaseFrameworkExecutor, BaseTorchExecutor, BaseTFExecutor
from ..frameworks import BaseFrameworkEncoder, BaseTFEncoder, BaseTorchEncoder


class BaseTransformerEncoder(BaseFrameworkExecutor):
class BaseTransformerEncoder(BaseFrameworkEncoder):
"""
:class:`TransformerTextEncoder` encodes data from an array of string in size `B` into an ndarray in size `B x D`.
"""
Expand Down Expand Up @@ -174,7 +174,7 @@ def get_tensor_func(self):
raise NotImplementedError


class TransformerTFEncoder(BaseTransformerEncoder, BaseTFExecutor):
class TransformerTFEncoder(BaseTransformerEncoder, BaseTFEncoder):
"""
Internally, TransformerTFEncoder wraps the tensorflow-version of transformers from huggingface.
"""
Expand Down Expand Up @@ -209,7 +209,7 @@ def get_tensor_func(self):
return tf.constant


class TransformerTorchEncoder(BaseTransformerEncoder, BaseTorchExecutor):
class TransformerTorchEncoder(BaseTransformerEncoder, BaseTorchEncoder):
"""
Internally, TransformerTorchEncoder wraps the pytorch-version of transformers from huggingface.
"""
Expand Down
4 changes: 3 additions & 1 deletion jina/executors/encoders/video/paddlehub.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
__copyright__ = "Copyright (c) 2020 Jina AI Limited. All rights reserved."
__license__ = "Apache-2.0"


from .. import BaseVideoEncoder
from ..frameworks import BaseCVPaddlehubEncoder


class VideoPaddlehubEncoder(BaseCVPaddlehubEncoder):
class VideoPaddlehubEncoder(BaseVideoEncoder, BaseCVPaddlehubEncoder):
"""
:class:`VideoPaddlehubEncoder` encodes data from a ndarray, potentially B x T x (Channel x Height x Width) into a
ndarray of `B x D`.
Expand Down
3 changes: 2 additions & 1 deletion jina/executors/encoders/video/torchvision.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
__copyright__ = "Copyright (c) 2020 Jina AI Limited. All rights reserved."
__license__ = "Apache-2.0"

from .. import BaseVideoEncoder
from ..frameworks import BaseTorchEncoder


class VideoTorchEncoder(BaseTorchEncoder):
class VideoTorchEncoder(BaseVideoEncoder, BaseTorchEncoder):
"""
:class:`VideoTorchEncoder` encodes data from a ndarray, potentially B x T x (Channel x Height x Width) into an
ndarray of `B x D`.
Expand Down

0 comments on commit bda92d8

Please sign in to comment.