Skip to content

Commit

Permalink
Fix gRPC frontend race condition (#7110)
Browse files Browse the repository at this point in the history
* Fix state complete_ race condition

* Add delay and error checking to StreamInferResponseComplete

* Add test for gRPC decoupled infer complete flag
  • Loading branch information
kthui authored and mc-nv committed Apr 18, 2024
1 parent 03efe23 commit 538ba51
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 9 deletions.
9 changes: 9 additions & 0 deletions qa/L0_grpc_state_cleanup/cleanup_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,15 @@ def test_decoupled_infer_with_params_shutdownserver(self):
infer_helper_map=[False, True],
)

def test_decoupled_infer_complete(self):
    # Regression test for a race where the Process() thread could release
    # the state object while the StreamInferResponseComplete() thread was
    # still accessing it.
    self._decoupled_infer(request_count=1, repeat_count=1, stream_timeout=16)
    # The server emits a distinctive error line if the race is detected,
    # so the server log must not contain it.
    log_path = os.environ["SERVER_LOG"]
    with open(log_path) as log_file:
        log_contents = log_file.read()
    self.assertNotIn("Should not print this", log_contents)


if __name__ == "__main__":
CleanUpTest.SERVER_PID = os.environ.get("SERVER_PID", CleanUpTest.SERVER_PID)
Expand Down
34 changes: 34 additions & 0 deletions qa/L0_grpc_state_cleanup/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,40 @@ for i in test_simple_infer_shutdownserver \
set -e
done

# Run the decoupled-infer-complete race-condition test against a live server.
TEST_NAME=test_decoupled_infer_complete
# Delay (ms) injected into the gRPC completion path to widen the race window
# between the Process() thread and StreamInferResponseComplete().
export TRITONSERVER_DELAY_GRPC_COMPLETE=2000

SERVER_LOG="./inference_server.$TEST_NAME.log"
SERVER_ARGS="--model-repository=`pwd`/models --log-verbose=2"
run_server
# run_server sets SERVER_PID to 0 when the server fails to come up.
if [ "$SERVER_PID" == "0" ]; then
echo -e "\n***\n*** Failed to start $SERVER\n***"
cat $SERVER_LOG
exit 1
fi

echo "Test: $TEST_NAME" >>$CLIENT_LOG

# Disable exit-on-error so a failing test records RET=1 instead of aborting
# the script before cleanup/verification runs.
set +e

# Pass the server log path to the Python test so it can scan for the
# race-detection error message.
SERVER_LOG=$SERVER_LOG python $CLEANUP_TEST CleanUpTest.$TEST_NAME >>$CLIENT_LOG 2>&1
if [ $? -ne 0 ]; then
cat $CLIENT_LOG
echo -e "\n***\n*** Test $TEST_NAME Failed\n***"
RET=1
fi

# Shut the server down and wait for it to exit before inspecting its log.
kill $SERVER_PID
wait $SERVER_PID

# Verify every inference state object was released (no state leaks).
check_state_release $SERVER_LOG
if [ $? -ne 0 ]; then
cat $SERVER_LOG
echo -e "\n***\n*** State Verification Failed for $TEST_NAME\n***"
RET=1
fi

# Restore exit-on-error for the remainder of the script.
set -e

if [ $RET -eq 0 ]; then
echo -e "\n***\n*** Test Passed\n***"
Expand Down
9 changes: 8 additions & 1 deletion src/grpc/infer_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -1013,12 +1013,18 @@ class InferHandlerState {
const std::shared_ptr<Context>& context, Steps start_step = Steps::START)
: tritonserver_(tritonserver), async_notify_state_(false)
{
// For debugging and testing,
// For debugging and testing
const char* dstr = getenv("TRITONSERVER_DELAY_GRPC_RESPONSE");
delay_response_ms_ = 0;
if (dstr != nullptr) {
delay_response_ms_ = atoi(dstr);
}
const char* cstr = getenv("TRITONSERVER_DELAY_GRPC_COMPLETE");
delay_complete_ms_ = 0;
if (cstr != nullptr) {
delay_complete_ms_ = atoi(cstr);
}

response_queue_.reset(new ResponseQueue<ResponseType>());
Reset(context, start_step);
}
Expand Down Expand Up @@ -1113,6 +1119,7 @@ class InferHandlerState {

// For testing and debugging
int delay_response_ms_;
int delay_complete_ms_;

// For inference requests the allocator payload, unused for other
// requests.
Expand Down
37 changes: 29 additions & 8 deletions src/grpc/stream_infer_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -574,9 +574,10 @@ ModelStreamInferHandler::StreamInferResponseComplete(
#endif // TRITON_ENABLE_TRACING

// Log appropriate errors
state->complete_ = ((flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL) != 0);
bool is_complete =
state->complete_ || (flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL) != 0;
if (!state->is_decoupled_) {
if (!state->complete_) {
if (!is_complete) {
LOG_ERROR << "[INTERNAL] ModelStreamInfer received a response without "
"FINAL flag for a model with one-to-one transaction";
}
Expand All @@ -591,7 +592,7 @@ ModelStreamInferHandler::StreamInferResponseComplete(
// Also make sure that if this state was sent to gRPC async notification
// mechanism then the state is not removed as it would be needed for handling
// the cancellation if detected.
if (state->complete_ && (!state->IsAsyncNotifyState())) {
if (is_complete && (!state->IsAsyncNotifyState())) {
state->context_->EraseInflightState(state);
}

Expand All @@ -610,11 +611,12 @@ ModelStreamInferHandler::StreamInferResponseComplete(
// If this was the final callback for the state
// then cycle through the completion queue so
// that state object can be released.
if (state->complete_) {
if (is_complete) {
state->step_ = Steps::CANCELLED;
state->context_->PutTaskBackToQueue(state);
}

state->complete_ = is_complete;
return;
}

Expand Down Expand Up @@ -661,8 +663,7 @@ ModelStreamInferHandler::StreamInferResponseComplete(
// "empty" responses are not sent back to the client. Clients can
// opt-in to receiving these empty responses via request parameters.
// NOTE: The complete flag is the only flag used for this case at this time.
const bool empty_final =
(!iresponse && state->is_decoupled_ && state->complete_);
const bool empty_final = !iresponse && state->is_decoupled_ && is_complete;
const bool enable_empty_final =
state->parameters_.enable_empty_final_response_;

Expand Down Expand Up @@ -690,7 +691,24 @@ ModelStreamInferHandler::StreamInferResponseComplete(
infer_response.set_model_version(state->request_.model_version());
}
auto& params = *(infer_response.mutable_parameters());
params["triton_final_response"].set_bool_param(state->complete_);
params["triton_final_response"].set_bool_param(is_complete);
}

if (state->delay_complete_ms_ != 0) {
// Delay updating the state. This is useful for testing race condition with
// the thread that runs Process().
LOG_INFO << "Delaying the completion of reporting response / flag by "
<< state->delay_complete_ms_ << " ms...";
void* context_ptr_before_delay = (void*)state->context_.get();
std::this_thread::sleep_for(
std::chrono::milliseconds(state->delay_complete_ms_));
void* context_ptr_after_delay = (void*)state->context_.get();
if (context_ptr_before_delay != context_ptr_after_delay) {
LOG_ERROR << "Should not print this! The state context object has "
"changed after delay, pointer before: "
<< context_ptr_before_delay
<< ", pointer after: " << context_ptr_after_delay;
}
}

// Update states to signal that response/error is ready to write to stream
Expand All @@ -708,11 +726,12 @@ ModelStreamInferHandler::StreamInferResponseComplete(
// If this was the final callback for the state
// then cycle through the completion queue so
// that state object can be released.
if (state->complete_) {
if (is_complete) {
state->step_ = Steps::CANCELLED;
state->context_->PutTaskBackToQueue(state);
}

state->complete_ = is_complete;
return;
}

Expand All @@ -728,6 +747,8 @@ ModelStreamInferHandler::StreamInferResponseComplete(
state->step_ = Steps::WRITEREADY;
state->context_->WriteResponseIfReady(state);
}

state->complete_ = is_complete;
}
}

Expand Down

0 comments on commit 538ba51

Please sign in to comment.