Fix gRPC streaming non-decoupled segfault if sending response and final flag separately #7265

Merged
merged 5 commits on Jun 6, 2024

62 changes: 59 additions & 3 deletions qa/L0_grpc_state_cleanup/cleanup_test.py
@@ -1,6 +1,6 @@
#!/usr/bin/env python3

# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
@@ -67,9 +67,10 @@ class CleanUpTest(tu.TestResultCollector):
def setUp(self):
self.decoupled_model_name_ = "repeat_int32"
self.identity_model_name_ = "custom_zero_1_float32"
self.repeat_non_decoupled_model_name = "repeat_int32_non_decoupled"

def _prepare_inputs_and_outputs(self, kind):
if kind == "decoupled_streaming":
if kind in ("decoupled_streaming", "non_decoupled_streaming"):
self.inputs_ = []
self.inputs_.append(grpcclient.InferInput("IN", [1], "INT32"))
self.inputs_.append(grpcclient.InferInput("DELAY", [1], "UINT32"))
@@ -79,7 +80,7 @@ def _prepare_inputs_and_outputs(self, kind):
self.outputs_.append(grpcclient.InferRequestedOutput("OUT"))
self.outputs_.append(grpcclient.InferRequestedOutput("IDX"))
self.requested_outputs_ = self.outputs_
elif kind == "simple" or kind == "streaming":
elif kind in ("simple", "streaming"):
Contributor:

Just curious, what is streaming compared to decoupled_streaming and non_decoupled_streaming?

Contributor Author:

The decoupled_streaming and non_decoupled_streaming kinds are for the repeat_int32 and repeat_int32_non_decoupled models on the repeat backend. The simple and streaming kinds are for the custom_zero_1_float32 model on the identity backend.

I think we could do some refactoring on how the inputs are prepared, e.g. for the repeat backend we could have a function that returns inputs that can be passed directly to client.async_stream_infer() (a rough sketch follows below).
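
A minimal sketch of that helper, assuming the same IN/DELAY/WAIT input signature the repeat backend models use in this test; the function name and parameters are hypothetical and not part of this PR:

import numpy as np
import tritonclient.grpc as grpcclient


def prepare_repeat_inputs(values, delays_ms, wait_ms=0):
    # Build the IN/DELAY/WAIT inputs expected by the repeat backend so they can
    # be passed directly to client.async_stream_infer(inputs=...).
    values = np.asarray(values, dtype=np.int32)
    delays = np.asarray(delays_ms, dtype=np.uint32)
    inputs = [
        grpcclient.InferInput("IN", list(values.shape), "INT32"),
        grpcclient.InferInput("DELAY", list(delays.shape), "UINT32"),
        grpcclient.InferInput("WAIT", [1], "UINT32"),
    ]
    inputs[0].set_data_from_numpy(values)
    inputs[1].set_data_from_numpy(delays)
    inputs[2].set_data_from_numpy(np.array([wait_ms], dtype=np.uint32))
    return inputs

The new test below could then build its inputs with something like prepare_repeat_inputs(values=np.arange(100, 104), delays_ms=[0] * 4).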

self.inputs_ = []
self.inputs_.append(grpcclient.InferInput("INPUT0", [1, 1], "FP32"))

@@ -563,6 +564,61 @@ def test_decoupled_infer_complete(self):
server_log = f.read()
self.assertNotIn("Should not print this", server_log)

def test_non_decoupled_streaming_multi_response(self):
# Test that a non-decoupled streaming infer which produces more than one
# response returns only the first response.
response_count = 4
expected_response_count = 1
expected_response_index = 0

# Prepare input data
self._prepare_inputs_and_outputs("non_decoupled_streaming")
# Initialize data for IN
data_offset = 100
input_data = np.arange(
start=data_offset, stop=data_offset + response_count, dtype=np.int32
)
self.inputs_[0].set_shape([response_count])
self.inputs_[0].set_data_from_numpy(input_data)
# Initialize data for DELAY
delay_data = np.zeros([response_count], dtype=np.uint32)
self.inputs_[1].set_shape([response_count])
self.inputs_[1].set_data_from_numpy(delay_data)
# Initialize data for WAIT
wait_data = np.array([0], dtype=np.uint32)
self.inputs_[2].set_data_from_numpy(wait_data)

# Infer
user_data = UserData()
with grpcclient.InferenceServerClient(
url="localhost:8001", verbose=True
) as client:
# Establish stream
client.start_stream(
callback=partial(callback, user_data), stream_timeout=16
)
# Send a request
client.async_stream_infer(
model_name=self.repeat_non_decoupled_model_name,
inputs=self.inputs_,
request_id="0",
outputs=self.requested_outputs_,
)
# Wait for all results and stop stream
client.stop_stream()

# Check infer output
actual_response_count = 0
while not user_data._response_queue.empty():
actual_response_count += 1
data_item = user_data._response_queue.get()
if type(data_item) == InferenceServerException:
raise data_item
else:
response_idx = data_item.as_numpy("IDX")[0]
self.assertEqual(response_idx, expected_response_index)
self.assertEqual(actual_response_count, expected_response_count)


if __name__ == "__main__":
CleanUpTest.SERVER_PID = os.environ.get("SERVER_PID", CleanUpTest.SERVER_PID)
11 changes: 9 additions & 2 deletions qa/L0_grpc_state_cleanup/test.sh
@@ -1,5 +1,5 @@
#!/bin/bash
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
# Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
@@ -73,6 +73,12 @@ rm -fr ./models/custom_zero_1_float32 && \
echo "{ key: \"execute_delay_ms\"; value: { string_value: \"1000\" }}" >> config.pbtxt && \
echo "]" >> config.pbtxt)

rm -rf models/repeat_int32_non_decoupled && \
cp -r models/repeat_int32 models/repeat_int32_non_decoupled && \
(cd models/repeat_int32_non_decoupled && \
sed -i "/model_transaction_policy/,+2d" config.pbtxt && \
sed -i "s/repeat_int32/repeat_int32_non_decoupled/" config.pbtxt)

for i in test_simple_infer \
test_simple_infer_cancellation \
test_simple_infer_timeout \
@@ -81,7 +87,8 @@ for i in test_simple_infer \
test_streaming_cancellation \
test_decoupled_infer \
test_decoupled_cancellation \
test_decoupled_timeout; do
test_decoupled_timeout \
test_non_decoupled_streaming_multi_response; do
SERVER_LOG="./inference_server.$i.log"
SERVER_ARGS="--model-repository=`pwd`/models --log-verbose=2"
run_server
20 changes: 6 additions & 14 deletions src/grpc/stream_infer_handler.cc
@@ -1,4 +1,4 @@
// Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
@@ -564,7 +564,8 @@ ModelStreamInferHandler::StreamInferResponseComplete(
LOG_VERBOSE(1) << "ModelStreamInferHandler::StreamInferComplete, context "
<< state->context_->unique_id_ << ", " << state->unique_id_
<< " step " << state->step_ << ", callback index "
<< state->cb_count_ << ", flags " << flags;
<< state->cb_count_ << ", flags " << flags
<< ", response is nullptr " << (iresponse == nullptr);

#ifdef TRITON_ENABLE_TRACING
if (state->cb_count_ == 1) {
@@ -573,19 +574,8 @@
}
#endif // TRITON_ENABLE_TRACING

// Log appropriate errors
bool is_complete =
state->complete_ || (flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL) != 0;
if (!state->is_decoupled_) {
if (!is_complete) {
LOG_ERROR << "[INTERNAL] ModelStreamInfer received a response without "
"FINAL flag for a model with one-to-one transaction";
}
if (iresponse == nullptr) {
LOG_ERROR << "[INTERNAL] ModelStreamInfer received a null response for a "
"model with one-to-one transaction";
}
}

// If receiving the final callback then erase the state from the inflight
// state data structure to prevent cancellation being called on the request.
@@ -745,7 +735,9 @@ ModelStreamInferHandler::StreamInferResponseComplete(
}
} else {
state->step_ = Steps::WRITEREADY;
state->context_->WriteResponseIfReady(state);
if (is_complete) {
state->context_->WriteResponseIfReady(state);
}
kthui marked this conversation as resolved.
}

state->complete_ = is_complete;
Expand Down