Skip to content

Commit

Permalink
fix(encoder): add batching mode
Browse files Browse the repository at this point in the history
  • Loading branch information
nan-wang committed Apr 6, 2020
1 parent 93382b0 commit f6ca182
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 1 deletion.
3 changes: 3 additions & 0 deletions jina/executors/encoders/image/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import re
from .. import BaseImageEncoder
from ...decorators import batching, as_ndarray


class OnnxImageEncoder(BaseImageEncoder):
Expand Down Expand Up @@ -58,6 +59,8 @@ def _append_outputs(input_fn, outputs_name_to_append, output_fn):
model.graph.output.append(feature_map)
onnx.save(model, output_fn)

@batching
@as_ndarray
def encode(self, data: 'np.ndarray', *args, **kwargs) -> 'np.ndarray':
"""
Expand Down
3 changes: 3 additions & 0 deletions jina/executors/encoders/nlp/char.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np

from .. import BaseTextEncoder
from ...decorators import batching, as_ndarray


class OneHotTextEncoder(BaseTextEncoder):
Expand Down Expand Up @@ -28,6 +29,8 @@ def post_init(self):
self.embeddings = np.eye(self.dim) * self.on_value + \
(np.ones((self.dim, self.dim)) - np.eye(self.dim)) * self.off_value

@batching
@as_ndarray
def encode(self, data: 'np.ndarray', *args, **kwargs) -> 'np.ndarray':
"""
Expand Down
3 changes: 3 additions & 0 deletions jina/executors/encoders/nlp/flair.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np

from .. import BaseTextEncoder
from ...decorators import batching, as_ndarray


class FlairTextEncoder(BaseTextEncoder):
Expand Down Expand Up @@ -59,6 +60,8 @@ def post_init(self):
self.model = DocumentPoolEmbeddings(embeddings_list, pooling=self.pooling_strategy)
self.logger.info('initialize flair encoder with embeddings: {}'.format(self.embeddings))

@batching
@as_ndarray
def encode(self, data: 'np.ndarray', *args, **kwargs) -> 'np.ndarray':
"""
Expand Down
1 change: 0 additions & 1 deletion jina/executors/encoders/numeric/pca.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import numpy as np
import os
from ...decorators import batching, require_train

from .. import BaseNumericEncoder
Expand Down

0 comments on commit f6ca182

Please sign in to comment.