Allow non-decoupled model to send response and FINAL flag separately
GuanLuo committed Aug 7, 2023
1 parent 36d80fe commit 5f21ff1
Showing 6 changed files with 154 additions and 70 deletions.
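
For context, this change lets a non-decoupled (single-response) model deliver its one inference response and the TRITONSERVER_RESPONSE_COMPLETE_FINAL flag in two separate completion callbacks; the gRPC, HTTP, and SageMaker frontends below merge them so the client still receives exactly one response. The following is a minimal, illustrative sketch of a model emitting the response and the flag separately, assuming the Python backend's response-sender API is usable in non-decoupled mode; the tensor names and values are hypothetical and not taken from this commit.

# Illustrative sketch only: a non-decoupled Python-backend model whose
# execute() sends the single response and the FINAL flag in two calls.
# Tensor names/values are hypothetical and not part of this commit.
import numpy as np
import triton_python_backend_utils as pb_utils


class TritonPythonModel:
    def execute(self, requests):
        for request in requests:
            sender = request.get_response_sender()
            out = pb_utils.Tensor("OUT", np.array([1], dtype=np.int32))
            # First callback: the single response, without the FINAL flag.
            sender.send(pb_utils.InferenceResponse(output_tensors=[out]))
            # Second callback: only the FINAL flag, with no response payload.
            sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
        # All responses go through the sender, so nothing is returned here.
        return None
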
52 changes: 52 additions & 0 deletions qa/L0_decoupled/decoupled_test.py
@@ -591,5 +591,57 @@ def test_wrong_shape(self):
)


class NonDecoupledTest(tu.TestResultCollector):
def setUp(self):
self.model_name_ = "repeat_int32"
self.input_data = {
"IN": np.array([1], dtype=np.int32),
"DELAY": np.array([0], dtype=np.uint32),
"WAIT": np.array([0], dtype=np.uint32),
}

def test_grpc(self):
inputs = [
grpcclient.InferInput("IN", [1], "INT32").set_data_from_numpy(
self.input_data["IN"]
),
grpcclient.InferInput("DELAY", [1], "UINT32").set_data_from_numpy(
self.input_data["DELAY"]
),
grpcclient.InferInput("WAIT", [1], "UINT32").set_data_from_numpy(
self.input_data["WAIT"]
),
]

triton_client = grpcclient.InferenceServerClient(
url="localhost:8001", verbose=True
)
# Expect the inference to succeed
res = triton_client.infer(model_name=self.model_name_, inputs=inputs)
self.assertEqual(1, res.as_numpy("OUT")[0])
self.assertEqual(0, res.as_numpy("IDX")[0])

def test_http(self):
inputs = [
httpclient.InferInput("IN", [1], "INT32").set_data_from_numpy(
self.input_data["IN"]
),
httpclient.InferInput("DELAY", [1], "UINT32").set_data_from_numpy(
self.input_data["DELAY"]
),
httpclient.InferInput("WAIT", [1], "UINT32").set_data_from_numpy(
self.input_data["WAIT"]
),
]

triton_client = httpclient.InferenceServerClient(
url="localhost:8000", verbose=True
)
# Expect the inference to succeed
res = triton_client.infer(model_name=self.model_name_, inputs=inputs)
self.assertEqual(1, res.as_numpy("OUT")[0])
self.assertEqual(0, res.as_numpy("IDX")[0])


if __name__ == "__main__":
unittest.main()
43 changes: 42 additions & 1 deletion qa/L0_decoupled/test.sh
@@ -1,5 +1,5 @@
#!/bin/bash
# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
# Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
@@ -129,6 +129,47 @@ for trial in $TRIALS; do
wait $SERVER_PID
done

# Test that the server frontend can merge the responses of a non-decoupled model
# that sends the inference response and the COMPLETE flag separately. In other
# words, from the client's perspective there will still be only one response.
NON_DECOUPLED_DIR=`pwd`/non_decoupled_models
rm -rf ${NON_DECOUPLED_DIR} && mkdir -p ${NON_DECOUPLED_DIR}
cp -r `pwd`/models/repeat_int32 ${NON_DECOUPLED_DIR}/. && \
(cd ${NON_DECOUPLED_DIR}/repeat_int32 && \
sed -i "s/decoupled: True/decoupled: False/" config.pbtxt)

SERVER_ARGS="--model-repository=${NON_DECOUPLED_DIR}"
SERVER_LOG="./non_decoupled_inference_server.log"

run_server
if [ "$SERVER_PID" == "0" ]; then
echo -e "\n***\n*** Failed to start $SERVER\n***"
cat $SERVER_LOG
exit 1
fi

CLIENT_LOG=`pwd`/non_decoupled_client.log
echo "Test: NonDecoupledTest" >>$CLIENT_LOG
set +e
python $DECOUPLED_TEST NonDecoupledTest >>$CLIENT_LOG 2>&1
if [ $? -ne 0 ]; then
echo -e "\n***\n*** Test NonDecoupledTest Failed\n***" >>$CLIENT_LOG
echo -e "\n***\n*** Test NonDecoupledTest Failed\n***"
RET=1
else
check_test_results $TEST_RESULT_FILE 2
if [ $? -ne 0 ]; then
cat $CLIENT_LOG
echo -e "\n***\n*** Test Result Verification Failed\n***"
RET=1
fi
fi

set -e

kill $SERVER_PID
wait $SERVER_PID

if [ $RET -eq 0 ]; then
echo -e "\n***\n*** Test Passed\n***"
else
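
For readers reproducing the setup outside of CI, a rough Python equivalent of the repository preparation in test.sh above; the paths are assumptions based on the script's working directory.

# Rough Python equivalent of the test.sh setup above: copy the decoupled
# repeat_int32 model and flip its config to non-decoupled. Paths assume the
# same working directory the shell script uses.
import re
import shutil
from pathlib import Path

src = Path("models/repeat_int32")
repo = Path("non_decoupled_models")
shutil.rmtree(repo, ignore_errors=True)
dst = repo / "repeat_int32"
shutil.copytree(src, dst)

cfg = dst / "config.pbtxt"
cfg.write_text(re.sub(r"decoupled: True", "decoupled: False", cfg.read_text()))
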
23 changes: 11 additions & 12 deletions src/grpc/infer_handler.cc
@@ -965,18 +965,14 @@ ModelInferHandler::InferResponseComplete(
{
State* state = reinterpret_cast<State*>(userp);

// Increment the callback index
state->cb_count_++;
// Increment the callback index if received valid 'iresponse'
if (iresponse != nullptr) {
state->cb_count_++;
}

LOG_VERBOSE(1) << "ModelInferHandler::InferResponseComplete, "
<< state->unique_id_ << " step " << state->step_;

// Defer to the callback with the final response
if ((flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL) == 0) {
LOG_ERROR << "[INTERNAL] ModelInfer received a response without FINAL flag";
return;
}

#ifdef TRITON_ENABLE_TRACING
state->trace_timestamps_.emplace_back(std::make_pair(
"INFER_RESPONSE_COMPLETE", TraceManager::CaptureTimestamp()));
@@ -1004,10 +1000,7 @@ ModelInferHandler::InferResponseComplete(
"expected a single response, got " +
std::to_string(state->cb_count_))
.c_str());
} else if (iresponse == nullptr) {
err = TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INTERNAL, "received an unexpected null response");
} else {
} else if (iresponse != nullptr) {
err = InferResponseCompleteCommon<inference::ModelInferResponse>(
state->tritonserver_, iresponse, *response, state->alloc_payload_);
}
@@ -1024,6 +1017,12 @@
TRITONSERVER_InferenceResponseDelete(iresponse),
"deleting GRPC inference response");

// Defer sending the response until the FINAL flag is seen or
// there is an error
if (status.ok() && (flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL) == 0) {
return;
}

#ifdef TRITON_ENABLE_TRACING
state->trace_timestamps_.emplace_back(
std::make_pair("GRPC_SEND_START", TraceManager::CaptureTimestamp()));
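
The HTTP and SageMaker handlers further down follow the same revised pattern as this gRPC handler: count only non-null responses, report an error if the count is not exactly one, and defer the reply until the FINAL flag (or an error) is seen. Below is a rough Python rendering of that shared callback flow; the constant value and the finalize/send_reply helpers are hypothetical stand-ins, not the real server internals.

# Illustrative Python rendering of the revised completion-callback flow.
TRITONSERVER_RESPONSE_COMPLETE_FINAL = 1  # stand-in for the tritonserver flag


def finalize(response):
    # Placeholder for serializing the response into the wire format.
    return response


def send_reply(payload, error):
    # Placeholder for handing the reply back to the network layer.
    print("reply:", payload if error is None else error)


class CompletionState:
    def __init__(self):
        self.response_count = 0
        self.payload = None
        self.error = None

    def on_response_complete(self, response, flags):
        # Count only real responses; a bare FINAL-flag callback has no payload.
        if response is not None:
            self.response_count += 1

        if self.response_count != 1:
            self.error = f"expected a single response, got {self.response_count}"
        elif response is not None:
            self.payload = finalize(response)

        # Defer the reply until the FINAL flag is seen or an error occurred.
        if self.error is None and not (flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL):
            return
        send_reply(self.payload, self.error)


# Two-callback sequence this commit now accepts for non-decoupled models:
state = CompletionState()
state.on_response_complete(response={"OUT": [1]}, flags=0)
state.on_response_complete(response=None, flags=TRITONSERVER_RESPONSE_COMPLETE_FINAL)
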
47 changes: 20 additions & 27 deletions src/http_server.cc
@@ -3057,7 +3057,7 @@ HTTPAPIServer::InferRequestClass::InferRequestClass(
TRITONSERVER_Server* server, evhtp_request_t* req,
DataCompressor::Type response_compression_type)
: server_(server), req_(req),
response_compression_type_(response_compression_type), response_count_(0)
response_compression_type_(response_compression_type)
{
evhtp_connection_t* htpconn = evhtp_request_get_connection(req);
thread_ = htpconn->thread;
@@ -3097,25 +3097,19 @@ HTTPAPIServer::InferRequestClass::InferResponseComplete(
HTTPAPIServer::InferRequestClass* infer_request =
reinterpret_cast<HTTPAPIServer::InferRequestClass*>(userp);

auto response_count = infer_request->IncrementResponseCount();

// Defer to the callback with the final response
if ((flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL) == 0) {
LOG_ERROR << "[INTERNAL] received a response without FINAL flag";
return;
if (response != nullptr) {
++infer_request->response_count_;
}

TRITONSERVER_Error* err = nullptr;
if (response_count != 0) {
err = TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INTERNAL, std::string(
"expected a single response, got " +
std::to_string(response_count + 1))
.c_str());
} else if (response == nullptr) {
if (infer_request->response_count_ != 1) {
err = TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INTERNAL, "received an unexpected null response");
} else {
TRITONSERVER_ERROR_INTERNAL,
std::string(
"expected a single response, got " +
std::to_string(infer_request->response_count_))
.c_str());
} else if (response != nullptr) {
err = infer_request->FinalizeResponse(response);
}

@@ -3126,17 +3120,23 @@
}
#endif // TRITON_ENABLE_TRACING

LOG_TRITONSERVER_ERROR(
TRITONSERVER_InferenceResponseDelete(response),
"deleting inference response");

  // Defer sending the response until the FINAL flag is seen or
  // there is an error
if ((err == nullptr) && (flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL) == 0) {
return;
}

if (err == nullptr) {
evthr_defer(infer_request->thread_, OKReplyCallback, infer_request);
} else {
EVBufferAddErrorJson(infer_request->req_->buffer_out, err);
TRITONSERVER_ErrorDelete(err);
evthr_defer(infer_request->thread_, BADReplyCallback, infer_request);
}

LOG_TRITONSERVER_ERROR(
TRITONSERVER_InferenceResponseDelete(response),
"deleting inference response");
}

TRITONSERVER_Error*
@@ -3439,13 +3439,6 @@ HTTPAPIServer::InferRequestClass::SetResponseHeader(
}
}

uint32_t
HTTPAPIServer::InferRequestClass::IncrementResponseCount()
{
return response_count_++;
}


void
HTTPAPIServer::Handle(evhtp_request_t* req)
{
19 changes: 9 additions & 10 deletions src/http_server.h
@@ -1,4 +1,4 @@
// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
@@ -207,8 +207,6 @@ class HTTPAPIServer : public HTTPServer {
virtual void SetResponseHeader(
const bool has_binary_data, const size_t header_length);

uint32_t IncrementResponseCount();

#ifdef TRITON_ENABLE_TRACING
std::shared_ptr<TraceManager::Trace> trace_;
#endif // TRITON_ENABLE_TRACING
@@ -220,15 +218,16 @@ class HTTPAPIServer : public HTTPServer {
// lifetime of the request.
std::list<std::vector<char>> serialized_data_;

protected:
TRITONSERVER_Server* server_;
evhtp_request_t* req_;
evthr_t* thread_;
// Counter to keep track of number of responses generated.
std::atomic<uint32_t> response_count_{0};

DataCompressor::Type response_compression_type_;
protected:
TRITONSERVER_Server* server_{nullptr};
evhtp_request_t* req_{nullptr};
evthr_t* thread_{nullptr};

// Counter to keep track of number of responses generated.
std::atomic<uint32_t> response_count_;
DataCompressor::Type response_compression_type_{
DataCompressor::Type::IDENTITY};
};

protected:
40 changes: 20 additions & 20 deletions src/sagemaker_server.cc
@@ -1,4 +1,4 @@
// Copyright 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
@@ -361,25 +361,19 @@ SagemakerAPIServer::SagemakeInferRequestClass::InferResponseComplete(
SagemakerAPIServer::SagemakeInferRequestClass* infer_request =
reinterpret_cast<SagemakerAPIServer::SagemakeInferRequestClass*>(userp);

auto response_count = infer_request->IncrementResponseCount();

// Defer to the callback with the final response
if ((flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL) == 0) {
LOG_ERROR << "[INTERNAL] received a response without FINAL flag";
return;
if (response != nullptr) {
++infer_request->response_count_;
}

TRITONSERVER_Error* err = nullptr;
if (response_count != 0) {
err = TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INTERNAL, std::string(
"expected a single response, got " +
std::to_string(response_count + 1))
.c_str());
} else if (response == nullptr) {
if (infer_request->response_count_ != 1) {
err = TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INTERNAL, "received an unexpected null response");
} else {
TRITONSERVER_ERROR_INTERNAL,
std::string(
"expected a single response, got " +
std::to_string(infer_request->response_count_))
.c_str());
} else if (response != nullptr) {
err = infer_request->FinalizeResponse(response);
}

Expand All @@ -390,6 +384,16 @@ SagemakerAPIServer::SagemakeInferRequestClass::InferResponseComplete(
}
#endif // TRITON_ENABLE_TRACING

LOG_TRITONSERVER_ERROR(
TRITONSERVER_InferenceResponseDelete(response),
"deleting inference response");

  // Defer sending the response until the FINAL flag is seen or
  // there is an error
if ((err == nullptr) && (flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL) == 0) {
return;
}

if (err == nullptr) {
evthr_defer(infer_request->thread_, OKReplyCallback, infer_request);
} else {
@@ -404,10 +408,6 @@
}
TRITONSERVER_ErrorDelete(err);
}

LOG_TRITONSERVER_ERROR(
TRITONSERVER_InferenceResponseDelete(response),
"deleting inference response");
}

void
