Skip to content

Commit

Permalink
feat(executors): add typing
Browse files Browse the repository at this point in the history
  • Loading branch information
nan-wang committed May 19, 2020
1 parent ff89b38 commit 15f9f1f
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 34 deletions.
57 changes: 47 additions & 10 deletions jina/executors/clients.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from . import BaseExecutor
import grpc
from typing import Dict


class BaseClientExecutor(BaseExecutor):
Expand All @@ -24,6 +25,27 @@ class BaseTFServingClientExecutor(BaseClientExecutor):
:class:`BaseTFServingClientExecutor` is the base class for the executors that wrap up a tf serving client. For the
sake of generality, this implementation has the dependency on :mod:`tensorflow_serving`.
To implement your own executor with `tfserving`,
.. highlight:: python
.. code-block:: python
class MyAwesomeTFServingClientEncoder(BaseTFServingClientExecutor, BaseEncoder):
def encode(self, data: Any, *args, **kwargs) -> Any:
_req = self.get_request(data)
return self.get_response(_req)
def get_input(self, data):
input_1 = data[:, 0]
input_2 = data[:, 1:]
return {
'my_input_1': input_1.reshape(-1, 1).astype(np.float32),
'my_input_2': input_2.astype(np.float32)
}
def get_output(self, response):
return np.array(response.result().outputs['output_feature'].float_val)
"""
def __init__(self, service_name, signature_name='serving_default', *args, **kwargs):
"""
Expand All @@ -45,23 +67,41 @@ def post_init(self):
self._stub = prediction_service_pb2_grpc.PredictionServiceStub(self._channel)

def get_request(self, data):
    """
    Construct the gRPC request to the tf server.

    Builds a default ``PredictRequest`` (model/signature names filled in),
    converts ``data`` into the model's named input tensors via
    :meth:`get_input`, and copies those tensors into the request.

    :param data: the raw input to be sent to the tf server
    :return: the fully populated gRPC request, ready to be sent
    """
    request = self.get_default_request()
    input_dict = self.get_input(data)
    return self.fill_request(request, input_dict)

def get_input(self, data):
def get_input(self, data) -> Dict:
    """
    Convert the input data into a dict with the model's input feature names
    as the keys and the input tensors as the values.

    Subclasses must override this; the mapping is model-specific.

    :param data: the raw input to be converted
    :return: a dict of ``{feature_name: tensor}`` entries
    :raises NotImplementedError: always, in this base class
    """
    raise NotImplementedError

def get_response(self, response):
if response.exception():
self.logger.error('exception raised in encoding: {}'.format(response.exception))
def get_response(self, request: 'predict_pb2.PredictRequest'):
    """
    Send the gRPC request to the tf server and postprocess the response.

    :param request: the filled-in ``PredictRequest`` to send
    :return: whatever :meth:`get_output` extracts from the server response
    :raises ValueError: when the remote call raised an exception
    """
    # Predict.future() issues the RPC asynchronously and returns a grpc.Future.
    _response = self._stub.Predict.future(request, self.timeout)
    _exc = _response.exception()
    if _exc:
        # BUG FIX: `.exception` is a method — the original formatted the bound
        # method object into the log message instead of the actual exception.
        self.logger.error('exception raised in encoding: {}'.format(_exc))
        raise ValueError('remote prediction failed: {}'.format(_exc))
    return self.get_output(_response)

def get_output(self, response):
def get_output(self, response: 'grpc.Future'):
    """
    Postprocess the response from the tf server.

    Subclasses must override this to extract the model-specific output
    tensor(s) from the completed call.

    :param response: the object returned by ``Predict.future(...)``
        (a ``grpc.Future``; the original ``grpc.UnaryUnaryMultiCallable``
        annotation described the callable, not its result)
    :raises NotImplementedError: always, in this base class
    """
    raise NotImplementedError

def get_default_request(self):
def get_default_request(self) -> 'predict_pb2.PredictRequest':
"""
Construct the default gRPC request to the tf server.
"""
from tensorflow_serving.apis import predict_pb2
request = predict_pb2.PredictRequest()
request.model_spec.name = self.service_name
Expand All @@ -74,6 +114,3 @@ def fill_request(request, data_dict):
for k, v in data_dict.items():
request.inputs[k].CopyFrom(tf.make_tensor_proto(v))
return request

def callback(self, response):
pass
25 changes: 3 additions & 22 deletions jina/executors/encoders/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,37 +5,18 @@
from . import BaseEncoder


class BaseTFServingEncoder(BaseTFServingClientExecutor, BaseEncoder):
class BaseTFServingClientEncoder(BaseTFServingClientExecutor, BaseEncoder):
"""
:class:`BaseTFServingEncoder` is the base class for the encoders that wrap up a tf serving client. The client call
the gRPC port of the tf server.
To implement your own executor with `tfserving`,
.. highlight:: python
.. code-block:: python
class MyAwesomeTFServingEncoder(BaseTFServingEncoder):
def get_input(self, data):
input_1 = data[:, 0]
input_2 = data[:, 1:]
return {
'my_input_1': input_1.reshape(-1, 1).astype(np.float32),
'my_input_2': input_2.astype(np.float32)
}
def get_output(self, response):
return np.array(response.result().outputs['output_feature'].float_val)
"""
def encode(self, data: Any, *args, **kwargs) -> Any:
    """
    Encode ``data`` by delegating to the tf server.

    Builds the gRPC request via :meth:`get_request`, then sends it and
    postprocesses the reply via :meth:`get_response`.

    :param data: the raw input to be encoded
    :return: the encoding result extracted by :meth:`get_output`
    """
    # NOTE(review): the scraped diff interleaved the pre-commit body
    # (explicit Predict.future call) with the post-commit one-liner;
    # only the post-commit implementation is kept here.
    _req = self.get_request(data)
    return self.get_response(_req)


class UnaryTFServingEncoder(BaseTFServingEncoder):
class UnaryTFServingClientEncoder(BaseTFServingClientEncoder):
"""
:class:`UnaryTFServingEncoder` is an encoder that wraps up a tf serving client. This client covers the simplest
case, in which both the request and the response have a single data field.
Expand Down
4 changes: 2 additions & 2 deletions tests/executors/encoders/clients.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import unittest

from tests import JinaTestCase
from jina.executors.encoders.clients import UnaryTFServingEncoder
from jina.executors.encoders.clients import UnaryTFServingClientEncoder


class MyTestCase(JinaTestCase):
@unittest.skip('add grpc mocking for this test')
def test_something(self):
encoder = UnaryTFServingEncoder(
encoder = UnaryTFServingClientEncoder(
host='0.0.0.0', port='8500', service_name='mnist',
input_name='images', output_name='scores',
signature_name='predict_images')
Expand Down

0 comments on commit 15f9f1f

Please sign in to comment.