Switch from beta to GA gRPC API.
PiperOrigin-RevId: 206612637
netfs authored and tensorflower-gardener committed Jul 30, 2018
1 parent 1e74469 commit aa35cfd
Showing 7 changed files with 230 additions and 293 deletions.
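The client-side change is the same in every file touched below: the beta pair of implementations.insecure_channel(host, port) and beta_create_PredictionService_stub is replaced by the GA grpc.insecure_channel('host:port') plus the generated PredictionServiceStub class. A minimal sketch of the new pattern, assembled from the diffs below (the server address and model name are illustrative, not part of this commit):

import grpc

from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc

# GA API: the channel takes one 'host:port' target string instead of a
# separate (host, port) pair, and the stub class lives in *_pb2_grpc.
channel = grpc.insecure_channel('localhost:8500')
stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)

request = predict_pb2.PredictRequest()
request.model_spec.name = 'half_plus_two'  # hypothetical model name
request.inputs['x'].float_val.append(2.0)
result = stub.Predict(request, 5.0)  # 5-second timeout
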
5 changes: 4 additions & 1 deletion tensorflow_serving/apis/BUILD
@@ -178,7 +178,10 @@ serving_proto_library(

 py_library(
     name = "prediction_service_proto_py_pb2",
-    srcs = ["prediction_service_pb2.py"],
+    srcs = [
+        "prediction_service_pb2.py",
+        "prediction_service_pb2_grpc.py",
+    ],
     srcs_version = "PY2AND3",
     deps = [
         ":classification_proto_py_pb2",
313 changes: 59 additions & 254 deletions tensorflow_serving/apis/prediction_service_pb2.py

Large diffs are not rendered by default.

139 changes: 139 additions & 0 deletions tensorflow_serving/apis/prediction_service_pb2_grpc.py
@@ -0,0 +1,139 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
# source: tensorflow_serving/apis/prediction_service.proto
# To regenerate run
# python -m grpc.tools.protoc --python_out=. --grpc_python_out=. -I. tensorflow_serving/apis/prediction_service.proto
import grpc

from tensorflow_serving.apis import classification_pb2 as tensorflow__serving_dot_apis_dot_classification__pb2
from tensorflow_serving.apis import get_model_metadata_pb2 as tensorflow__serving_dot_apis_dot_get__model__metadata__pb2
from tensorflow_serving.apis import inference_pb2 as tensorflow__serving_dot_apis_dot_inference__pb2
from tensorflow_serving.apis import predict_pb2 as tensorflow__serving_dot_apis_dot_predict__pb2
from tensorflow_serving.apis import regression_pb2 as tensorflow__serving_dot_apis_dot_regression__pb2


class PredictionServiceStub(object):
  """open source marker; do not remove
  PredictionService provides access to machine-learned models loaded by
  model_servers.
  """

  def __init__(self, channel):
    """Constructor.
    Args:
      channel: A grpc.Channel.
    """
    self.Classify = channel.unary_unary(
        '/tensorflow.serving.PredictionService/Classify',
        request_serializer=tensorflow__serving_dot_apis_dot_classification__pb2.ClassificationRequest.SerializeToString,
        response_deserializer=tensorflow__serving_dot_apis_dot_classification__pb2.ClassificationResponse.FromString,
        )
    self.Regress = channel.unary_unary(
        '/tensorflow.serving.PredictionService/Regress',
        request_serializer=tensorflow__serving_dot_apis_dot_regression__pb2.RegressionRequest.SerializeToString,
        response_deserializer=tensorflow__serving_dot_apis_dot_regression__pb2.RegressionResponse.FromString,
        )
    self.Predict = channel.unary_unary(
        '/tensorflow.serving.PredictionService/Predict',
        request_serializer=tensorflow__serving_dot_apis_dot_predict__pb2.PredictRequest.SerializeToString,
        response_deserializer=tensorflow__serving_dot_apis_dot_predict__pb2.PredictResponse.FromString,
        )
    self.MultiInference = channel.unary_unary(
        '/tensorflow.serving.PredictionService/MultiInference',
        request_serializer=tensorflow__serving_dot_apis_dot_inference__pb2.MultiInferenceRequest.SerializeToString,
        response_deserializer=tensorflow__serving_dot_apis_dot_inference__pb2.MultiInferenceResponse.FromString,
        )
    self.GetModelMetadata = channel.unary_unary(
        '/tensorflow.serving.PredictionService/GetModelMetadata',
        request_serializer=tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.GetModelMetadataRequest.SerializeToString,
        response_deserializer=tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.GetModelMetadataResponse.FromString,
        )


class PredictionServiceServicer(object):
  """open source marker; do not remove
  PredictionService provides access to machine-learned models loaded by
  model_servers.
  """

  def Classify(self, request, context):
    """Classify.
    """
    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
    context.set_details('Method not implemented!')
    raise NotImplementedError('Method not implemented!')

  def Regress(self, request, context):
    """Regress.
    """
    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
    context.set_details('Method not implemented!')
    raise NotImplementedError('Method not implemented!')

  def Predict(self, request, context):
    """Predict -- provides access to loaded TensorFlow model.
    """
    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
    context.set_details('Method not implemented!')
    raise NotImplementedError('Method not implemented!')

  def MultiInference(self, request, context):
    """MultiInference API for multi-headed models.
    """
    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
    context.set_details('Method not implemented!')
    raise NotImplementedError('Method not implemented!')

  def GetModelMetadata(self, request, context):
    """GetModelMetadata - provides access to metadata for loaded models.
    """
    context.set_code(grpc.StatusCode.UNIMPLEMENTED)
    context.set_details('Method not implemented!')
    raise NotImplementedError('Method not implemented!')


def add_PredictionServiceServicer_to_server(servicer, server):
  rpc_method_handlers = {
      'Classify': grpc.unary_unary_rpc_method_handler(
          servicer.Classify,
          request_deserializer=tensorflow__serving_dot_apis_dot_classification__pb2.ClassificationRequest.FromString,
          response_serializer=tensorflow__serving_dot_apis_dot_classification__pb2.ClassificationResponse.SerializeToString,
      ),
      'Regress': grpc.unary_unary_rpc_method_handler(
          servicer.Regress,
          request_deserializer=tensorflow__serving_dot_apis_dot_regression__pb2.RegressionRequest.FromString,
          response_serializer=tensorflow__serving_dot_apis_dot_regression__pb2.RegressionResponse.SerializeToString,
      ),
      'Predict': grpc.unary_unary_rpc_method_handler(
          servicer.Predict,
          request_deserializer=tensorflow__serving_dot_apis_dot_predict__pb2.PredictRequest.FromString,
          response_serializer=tensorflow__serving_dot_apis_dot_predict__pb2.PredictResponse.SerializeToString,
      ),
      'MultiInference': grpc.unary_unary_rpc_method_handler(
          servicer.MultiInference,
          request_deserializer=tensorflow__serving_dot_apis_dot_inference__pb2.MultiInferenceRequest.FromString,
          response_serializer=tensorflow__serving_dot_apis_dot_inference__pb2.MultiInferenceResponse.SerializeToString,
      ),
      'GetModelMetadata': grpc.unary_unary_rpc_method_handler(
          servicer.GetModelMetadata,
          request_deserializer=tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.GetModelMetadataRequest.FromString,
          response_serializer=tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.GetModelMetadataResponse.SerializeToString,
      ),
  }
  generic_handler = grpc.method_handlers_generic_handler(
      'tensorflow.serving.PredictionService', rpc_method_handlers)
  server.add_generic_rpc_handlers((generic_handler,))
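The generated module also carries the server half of the GA API. For context, a minimal sketch of registering a servicer with a gRPC server (the servicer subclass, worker count, and port are illustrative assumptions, not part of this commit; grpc.server and add_insecure_port are standard grpc-python APIs):

from concurrent import futures

import grpc

from tensorflow_serving.apis import get_model_metadata_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc


class FakePredictionService(
    prediction_service_pb2_grpc.PredictionServiceServicer):
  """Hypothetical servicer; unoverridden methods keep the UNIMPLEMENTED default."""

  def GetModelMetadata(self, request, context):
    # A real server would inspect its loaded models; an empty response is
    # returned here only to show the override point.
    return get_model_metadata_pb2.GetModelMetadataResponse()


server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
prediction_service_pb2_grpc.add_PredictionServiceServicer_to_server(
    FakePredictionService(), server)
server.add_insecure_port('[::]:8500')  # port is illustrative
server.start()
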
9 changes: 4 additions & 5 deletions tensorflow_serving/example/inception_client.py
@@ -22,11 +22,11 @@

 # This is a placeholder for a Google-internal import.
 
-from grpc.beta import implementations
+import grpc
 import tensorflow as tf
 
 from tensorflow_serving.apis import predict_pb2
-from tensorflow_serving.apis import prediction_service_pb2
+from tensorflow_serving.apis import prediction_service_pb2_grpc
 
 
 tf.app.flags.DEFINE_string('server', 'localhost:9000',
@@ -36,9 +36,8 @@


 def main(_):
-  host, port = FLAGS.server.split(':')
-  channel = implementations.insecure_channel(host, int(port))
-  stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
+  channel = grpc.insecure_channel(FLAGS.server)
+  stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
   # Send request
   with open(FLAGS.image, 'rb') as f:
     # See prediction_service.proto for gRPC request/response details.
9 changes: 4 additions & 5 deletions tensorflow_serving/example/mnist_client.py
@@ -32,12 +32,12 @@

 # This is a placeholder for a Google-internal import.
 
-from grpc.beta import implementations
+import grpc
 import numpy
 import tensorflow as tf
 
 from tensorflow_serving.apis import predict_pb2
-from tensorflow_serving.apis import prediction_service_pb2
+from tensorflow_serving.apis import prediction_service_pb2_grpc
 import mnist_input_data


@@ -137,9 +137,8 @@ def do_inference(hostport, work_dir, concurrency, num_tests):
     IOError: An error occurred processing test data set.
   """
   test_data_set = mnist_input_data.read_data_sets(work_dir).test
-  host, port = hostport.split(':')
-  channel = implementations.insecure_channel(host, int(port))
-  stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
+  channel = grpc.insecure_channel(hostport)
+  stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
   result_counter = _ResultCounter(num_tests, concurrency)
   for _ in range(num_tests):
     request = predict_pb2.PredictRequest()
39 changes: 16 additions & 23 deletions tensorflow_serving/model_servers/tensorflow_model_server_test.py
@@ -29,9 +29,6 @@
 # This is a placeholder for a Google-internal import.
 
 import grpc
-from grpc.beta import implementations
-from grpc.beta import interfaces as beta_interfaces
-from grpc.framework.interfaces.face import face
 import tensorflow as tf
 
 from tensorflow.core.framework import types_pb2
@@ -42,7 +39,7 @@
 from tensorflow_serving.apis import inference_pb2
 from tensorflow_serving.apis import model_service_pb2_grpc
 from tensorflow_serving.apis import predict_pb2
-from tensorflow_serving.apis import prediction_service_pb2
+from tensorflow_serving.apis import prediction_service_pb2_grpc
 from tensorflow_serving.apis import regression_pb2
 
 FLAGS = flags.FLAGS
@@ -70,12 +67,12 @@ def WaitForServerReady(port):

   try:
     # Send empty request to missing model
-    channel = implementations.insecure_channel('localhost', port)
-    stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
+    channel = grpc.insecure_channel('localhost:{}'.format(port))
+    stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
     stub.Predict(request, RPC_TIMEOUT)
-  except face.AbortionError as error:
+  except grpc.RpcError as error:
     # Missing model error will have details containing 'Servable'
-    if 'Servable' in error.details:
+    if 'Servable' in error.details():
       print 'Server is ready'
       break

@@ -199,9 +196,8 @@ def VerifyPredictRequest(self,
     if specify_output:
       request.output_filter.append('y')
     # Send request
-    host, port = model_server_address.split(':')
-    channel = implementations.insecure_channel(host, int(port))
-    stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
+    channel = grpc.insecure_channel(model_server_address)
+    stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
     result = stub.Predict(request, RPC_TIMEOUT)  # 5 secs timeout
     # Verify response
     self.assertTrue('y' in result.outputs)
@@ -313,9 +309,8 @@ def testClassify(self):
     example.features.feature['x'].float_list.value.extend([2.0])
 
     # Send request
-    host, port = model_server_address.split(':')
-    channel = implementations.insecure_channel(host, int(port))
-    stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
+    channel = grpc.insecure_channel(model_server_address)
+    stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
     result = stub.Classify(request, RPC_TIMEOUT)  # 5 secs timeout
     # Verify response
     self.assertEquals(1, len(result.result.classifications))
@@ -345,9 +340,8 @@ def testRegress(self):
     example.features.feature['x'].float_list.value.extend([2.0])
 
     # Send request
-    host, port = model_server_address.split(':')
-    channel = implementations.insecure_channel(host, int(port))
-    stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
+    channel = grpc.insecure_channel(model_server_address)
+    stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
     result = stub.Regress(request, RPC_TIMEOUT)  # 5 secs timeout
     # Verify response
     self.assertEquals(1, len(result.result.regressions))
@@ -381,9 +375,8 @@ def testMultiInference(self):
     example.features.feature['x'].float_list.value.extend([2.0])
 
     # Send request
-    host, port = model_server_address.split(':')
-    channel = implementations.insecure_channel(host, int(port))
-    stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
+    channel = grpc.insecure_channel(model_server_address)
+    stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
     result = stub.MultiInference(request, RPC_TIMEOUT)  # 5 secs timeout
 
     # Verify response
@@ -451,13 +444,13 @@ def _TestBadModel(self):
     model_server_address = self.RunServer(PickUnusedPort(), 'default',
                                           model_path,
                                           wait_for_server_ready=False)
-    with self.assertRaises(face.AbortionError) as error:
+    with self.assertRaises(grpc.RpcError) as ectxt:
       self.VerifyPredictRequest(
           model_server_address, expected_output=3.0,
           expected_version=self._GetModelVersion(model_path),
           signature_name='')
-    self.assertIs(beta_interfaces.StatusCode.FAILED_PRECONDITION,
-                  error.exception.code)
+    self.assertIs(grpc.StatusCode.FAILED_PRECONDITION,
+                  ectxt.exception.code())

def _TestBadModelUpconvertedSavedModel(self):
"""Test Predict against a bad upconverted SavedModel model export."""
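Beyond stub creation, the test updates above show the GA error surface: grpc.RpcError replaces face.AbortionError, and code and details become methods on the raised exception rather than attributes. A minimal sketch of the new handling (the address and model name are illustrative assumptions):

import grpc

from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc

channel = grpc.insecure_channel('localhost:8500')  # address is illustrative
stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)

request = predict_pb2.PredictRequest()
request.model_spec.name = 'no_such_model'  # hypothetical, to provoke an error

try:
  stub.Predict(request, 5.0)
except grpc.RpcError as rpc_error:
  # GA API: code() and details() are methods, not attributes.
  status = rpc_error.code()      # a grpc.StatusCode, e.g. FAILED_PRECONDITION
  message = rpc_error.details()  # human-readable text, e.g. mentions 'Servable'
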
9 changes: 4 additions & 5 deletions
@@ -19,13 +19,13 @@

 # This is a placeholder for a Google-internal import.
 
-from grpc.beta import implementations
+import grpc
 import tensorflow as tf
 
 from tensorflow.core.framework import types_pb2
 from tensorflow.python.platform import flags
 from tensorflow_serving.apis import predict_pb2
-from tensorflow_serving.apis import prediction_service_pb2
+from tensorflow_serving.apis import prediction_service_pb2_grpc
 
 
 tf.app.flags.DEFINE_string('server', 'localhost:8500',
@@ -41,9 +41,8 @@ def main(_):
   request.inputs['x'].float_val.append(2.0)
   request.output_filter.append('y')
   # Send request
-  host, port = FLAGS.server.split(':')
-  channel = implementations.insecure_channel(host, int(port))
-  stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
+  channel = grpc.insecure_channel(FLAGS.server)
+  stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
   print stub.Predict(request, 5.0)  # 5 secs timeout


