diff --git a/Dockerfile.QA b/Dockerfile.QA index 2c43f735a5..b381abfaaf 100644 --- a/Dockerfile.QA +++ b/Dockerfile.QA @@ -267,6 +267,12 @@ RUN cp -r qa/L0_decoupled/models qa/L0_decoupled/python_models/ && \ cp /workspace/tritonbuild/python/examples/decoupled/square_config.pbtxt \ qa/L0_decoupled/python_models/square_int32/. +RUN mkdir -p qa/L0_decoupled_grpc_error && \ + cp -r qa/L0_decoupled/. qa/L0_decoupled_grpc_error + +RUN mkdir -p qa/L0_grpc_error_state_cleanup && \ + cp -r qa/L0_grpc_state_cleanup/. qa/L0_grpc_error_state_cleanup + RUN mkdir -p qa/L0_repoagent_checksum/models/identity_int32/1 && \ cp tritonbuild/identity/install/backends/identity/libtriton_identity.so \ qa/L0_repoagent_checksum/models/identity_int32/1/. diff --git a/docs/customization_guide/inference_protocols.md b/docs/customization_guide/inference_protocols.md index 592f26e7d1..a241f097da 100644 --- a/docs/customization_guide/inference_protocols.md +++ b/docs/customization_guide/inference_protocols.md @@ -115,6 +115,16 @@ These options can be used to configure the KeepAlive settings: For client-side documentation, see [Client-Side GRPC KeepAlive](https://github.com/triton-inference-server/client/blob/main/README.md#grpc-keepalive). +#### GRPC Status Codes + +Triton implements GRPC error handling for streaming requests when a specific flag is enabled through headers. Upon encountering an error, Triton returns the appropriate GRPC error code and subsequently closes the stream. + +* `triton_grpc_error` : The header value needs to be set to true while starting the stream. + +GRPC status codes can be used for better visibility and monitoring. For more details, see [gRPC Status Codes](https://grpc.io/docs/guides/status-codes/) + +For client-side documentation, see [Client-Side GRPC Status Codes](https://github.com/triton-inference-server/client/tree/main#GRPC-Status-Codes) + ### Limit Endpoint Access (BETA) Triton users may want to restrict access to protocols or APIs that are diff --git a/qa/L0_backend_python/lifecycle/lifecycle_test.py b/qa/L0_backend_python/lifecycle/lifecycle_test.py index 883f6d20b6..d6eb2a8f53 100755 --- a/qa/L0_backend_python/lifecycle/lifecycle_test.py +++ b/qa/L0_backend_python/lifecycle/lifecycle_test.py @@ -35,6 +35,7 @@ sys.path.append("../../common") import queue +import threading import time import unittest from functools import partial @@ -241,6 +242,135 @@ def test_infer_pymodel_error(self): initial_metrics_value, ) + # Test grpc stream behavior when triton_grpc_error is set to true. + # Expected to close stream and return GRPC error when model returns error. 
+ def test_triton_grpc_error_error_on(self): + model_name = "execute_grpc_error" + shape = [2, 2] + number_of_requests = 2 + user_data = UserData() + triton_client = grpcclient.InferenceServerClient(f"{_tritonserver_ipaddr}:8001") + metadata = {"triton_grpc_error": "true"} + triton_client.start_stream( + callback=partial(callback, user_data), headers=metadata + ) + stream_end = False + for i in range(number_of_requests): + input_data = np.random.randn(*shape).astype(np.float32) + inputs = [ + grpcclient.InferInput( + "IN", input_data.shape, np_to_triton_dtype(input_data.dtype) + ) + ] + inputs[0].set_data_from_numpy(input_data) + try: + triton_client.async_stream_infer(model_name=model_name, inputs=inputs) + result = user_data._completed_requests.get() + if type(result) == InferenceServerException: + # execute_grpc_error intentionally returns error with StatusCode.INTERNAL status on 2nd request + self.assertEqual(str(result.status()), "StatusCode.INTERNAL") + stream_end = True + else: + # Stream is not killed + output_data = result.as_numpy("OUT") + self.assertIsNotNone(output_data, "error: expected 'OUT'") + except Exception as e: + if stream_end == True: + # We expect the stream to have closed + self.assertTrue( + True, + "This should always pass as cancellation should succeed", + ) + else: + self.assertFalse( + True, "Unexpected Stream killed without Error from CORE" + ) + + # Test grpc stream behavior when triton_grpc_error is set to true in multiple open streams. + # Expected to close stream and return GRPC error when model returns error. + def test_triton_grpc_error_multithreaded(self): + thread1 = threading.Thread(target=self.test_triton_grpc_error_error_on) + thread2 = threading.Thread(target=self.test_triton_grpc_error_error_on) + # Start the threads + thread1.start() + thread2.start() + # Wait for both threads to finish + thread1.join() + thread2.join() + + # Test grpc stream behavior when triton_grpc_error is set to true and subsequent stream is cancelled. + # Expected cancellation is successful. 
+ def test_triton_grpc_error_cancel(self): + model_name = "execute_grpc_error" + shape = [2, 2] + number_of_requests = 1 + user_data = UserData() + triton_server_url = "localhost:8001" # Replace with your Triton server address + stream_end = False + triton_client = grpcclient.InferenceServerClient(triton_server_url) + + metadata = {"triton_grpc_error": "true"} + + triton_client.start_stream( + callback=partial(callback, user_data), headers=metadata + ) + + for i in range(number_of_requests): + input_data = np.random.randn(*shape).astype(np.float32) + inputs = [ + grpcclient.InferInput( + "IN", input_data.shape, np_to_triton_dtype(input_data.dtype) + ) + ] + inputs[0].set_data_from_numpy(input_data) + try: + triton_client.async_stream_infer(model_name=model_name, inputs=inputs) + result = user_data._completed_requests.get() + if type(result) == InferenceServerException: + stream_end = True + if i == 0: + triton_client.stop_stream(cancel_requests=True) + except Exception as e: + if stream_end == True: + # We expect the stream to have closed + self.assertTrue( + True, + "This should always pass as cancellation should succeed", + ) + else: + self.assertFalse( + True, "Unexpected Stream killed without Error from CORE" + ) + self.assertTrue( + True, + "This should always pass as cancellation should succeed without any exception", + ) + + # Test grpc stream behavior when triton_grpc_error is set to false + # and subsequent stream is NOT closed when error is reported from CORE + def test_triton_grpc_error_error_off(self): + model_name = "execute_grpc_error" + shape = [2, 2] + number_of_requests = 4 + response_counter = 0 + user_data = UserData() + triton_client = grpcclient.InferenceServerClient(f"{_tritonserver_ipaddr}:8001") + triton_client.start_stream(callback=partial(callback, user_data)) + for i in range(number_of_requests): + input_data = np.random.randn(*shape).astype(np.float32) + inputs = [ + grpcclient.InferInput( + "IN", input_data.shape, np_to_triton_dtype(input_data.dtype) + ) + ] + inputs[0].set_data_from_numpy(input_data) + triton_client.async_stream_infer(model_name=model_name, inputs=inputs) + _ = user_data._completed_requests.get() + response_counter += 1 + # we expect response_counter == number_of_requests, + # which indicates that after the first reported grpc error stream did NOT close and mode != triton_grpc_error + self.assertEqual(response_counter, number_of_requests) + if __name__ == "__main__": unittest.main() diff --git a/qa/L0_backend_python/lifecycle/test.sh b/qa/L0_backend_python/lifecycle/test.sh index dba4581ddd..59b846f56b 100755 --- a/qa/L0_backend_python/lifecycle/test.sh +++ b/qa/L0_backend_python/lifecycle/test.sh @@ -52,6 +52,14 @@ cp ../../python_models/execute_error/config.pbtxt ./models/execute_error/ sed -i "s/^max_batch_size:.*/max_batch_size: 8/" config.pbtxt && \ echo "dynamic_batching { preferred_batch_size: [8], max_queue_delay_microseconds: 12000000 }" >> config.pbtxt) +mkdir -p models/execute_grpc_error/1/ +cp ../../python_models/execute_grpc_error/model.py ./models/execute_grpc_error/1/ +cp ../../python_models/execute_grpc_error/config.pbtxt ./models/execute_grpc_error/ +(cd models/execute_grpc_error && \ + sed -i "s/^name:.*/name: \"execute_grpc_error\"/" config.pbtxt && \ + sed -i "s/^max_batch_size:.*/max_batch_size: 8/" config.pbtxt && \ + echo "dynamic_batching { preferred_batch_size: [8], max_queue_delay_microseconds: 1200000 }" >> config.pbtxt) + mkdir -p models/execute_return_error/1/ cp ../../python_models/execute_return_error/model.py 
./models/execute_return_error/1/ cp ../../python_models/execute_return_error/config.pbtxt ./models/execute_return_error/ diff --git a/qa/L0_decoupled/decoupled_test.py b/qa/L0_decoupled/decoupled_test.py index 1f76f4845b..d7bc59f5c7 100755 --- a/qa/L0_decoupled/decoupled_test.py +++ b/qa/L0_decoupled/decoupled_test.py @@ -116,7 +116,13 @@ def _stream_infer_with_params( url="localhost:8001", verbose=True ) as triton_client: # Establish stream - triton_client.start_stream(callback=partial(callback, user_data)) + if "TRITONSERVER_GRPC_STATUS_FLAG" in os.environ: + metadata = {"triton_grpc_error": "true"} + triton_client.start_stream( + callback=partial(callback, user_data), headers=metadata + ) + else: + triton_client.start_stream(callback=partial(callback, user_data)) # Send specified many requests in parallel for i in range(request_count): time.sleep((request_delay / 1000)) @@ -175,7 +181,13 @@ def _stream_infer( url="localhost:8001", verbose=True ) as triton_client: # Establish stream - triton_client.start_stream(callback=partial(callback, user_data)) + if "TRITONSERVER_GRPC_STATUS_FLAG" in os.environ: + metadata = {"triton_grpc_error": "true"} + triton_client.start_stream( + callback=partial(callback, user_data), headers=metadata + ) + else: + triton_client.start_stream(callback=partial(callback, user_data)) # Send specified many requests in parallel for i in range(request_count): time.sleep((request_delay / 1000)) diff --git a/qa/L0_decoupled/test.sh b/qa/L0_decoupled/test.sh index 98ad134d8b..22c37dff49 100755 --- a/qa/L0_decoupled/test.sh +++ b/qa/L0_decoupled/test.sh @@ -176,4 +176,4 @@ else echo -e "\n***\n*** Test Failed\n***" fi -exit $RET +exit $RET \ No newline at end of file diff --git a/qa/L0_grpc_state_cleanup/cleanup_test.py b/qa/L0_grpc_state_cleanup/cleanup_test.py index 431eeb1720..f7507747e9 100755 --- a/qa/L0_grpc_state_cleanup/cleanup_test.py +++ b/qa/L0_grpc_state_cleanup/cleanup_test.py @@ -161,9 +161,17 @@ def _stream_infer_with_params( url="localhost:8001", verbose=True ) as triton_client: # Establish stream - triton_client.start_stream( - callback=partial(callback, user_data), stream_timeout=stream_timeout - ) + if "TRITONSERVER_GRPC_STATUS_FLAG" in os.environ: + metadata = {"triton_grpc_error": "true"} + triton_client.start_stream( + callback=partial(callback, user_data), + stream_timeout=stream_timeout, + headers=metadata, + ) + else: + triton_client.start_stream( + callback=partial(callback, user_data), stream_timeout=stream_timeout + ) # Send specified many requests in parallel for i in range(request_count): time.sleep((request_delay / 1000)) @@ -229,9 +237,17 @@ def _stream_infer( url="localhost:8001", verbose=True ) as triton_client: # Establish stream - triton_client.start_stream( - callback=partial(callback, user_data), stream_timeout=stream_timeout - ) + if "TRITONSERVER_GRPC_STATUS_FLAG" in os.environ: + metadata = {"triton_grpc_error": "true"} + triton_client.start_stream( + callback=partial(callback, user_data), + stream_timeout=stream_timeout, + headers=metadata, + ) + else: + triton_client.start_stream( + callback=partial(callback, user_data), stream_timeout=stream_timeout + ) # Send specified many requests in parallel for i in range(request_count): time.sleep((request_delay / 1000)) @@ -608,9 +624,17 @@ def test_non_decoupled_streaming_multi_response(self): url="localhost:8001", verbose=True ) as client: # Establish stream - client.start_stream( - callback=partial(callback, user_data), stream_timeout=16 - ) + if "TRITONSERVER_GRPC_STATUS_FLAG" in 
os.environ: + metadata = {"triton_grpc_error": "true"} + client.start_stream( + callback=partial(callback, user_data), + stream_timeout=16, + headers=metadata, + ) + else: + client.start_stream( + callback=partial(callback, user_data), stream_timeout=16 + ) # Send a request client.async_stream_infer( model_name=self.repeat_non_decoupled_model_name, diff --git a/qa/python_models/execute_grpc_error/config.pbtxt b/qa/python_models/execute_grpc_error/config.pbtxt new file mode 100644 index 0000000000..70e247148a --- /dev/null +++ b/qa/python_models/execute_grpc_error/config.pbtxt @@ -0,0 +1,51 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +backend: "python" +max_batch_size: 64 + +input [ + { + name: "IN" + data_type: TYPE_FP32 + dims: [ -1 ] + } +] + +output [ + { + name: "OUT" + data_type: TYPE_FP32 + dims: [ -1 ] + } +] + +instance_group [ + { + count: 1 + kind : KIND_CPU + } +] diff --git a/qa/python_models/execute_grpc_error/model.py b/qa/python_models/execute_grpc_error/model.py new file mode 100644 index 0000000000..d5087a49ec --- /dev/null +++ b/qa/python_models/execute_grpc_error/model.py @@ -0,0 +1,52 @@ +# Copyright 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import triton_python_backend_utils as pb_utils + + +class TritonPythonModel: + def __init__(self): + # Maintain total inference count, so as to return error on 2nd request, all of this to simulate model failure + self.inf_count = 1 + + def execute(self, requests): + """This function is called on inference request.""" + responses = [] + + # Generate the error for the second request + for request in requests: + input_tensor = pb_utils.get_input_tensor_by_name(request, "IN") + out_tensor = pb_utils.Tensor("OUT", input_tensor.as_numpy()) + if self.inf_count % 2: + # Every odd request is success + responses.append(pb_utils.InferenceResponse([out_tensor])) + else: + # Every even request is failure + error = pb_utils.TritonError("An error occurred during execution") + responses.append(pb_utils.InferenceResponse([out_tensor], error)) + self.inf_count += 1 + + return responses diff --git a/src/grpc/grpc_utils.h b/src/grpc/grpc_utils.h index 898e4acb4f..032dec3ad9 100644 --- a/src/grpc/grpc_utils.h +++ b/src/grpc/grpc_utils.h @@ -76,6 +76,46 @@ typedef enum { PARTIAL_COMPLETION } Steps; +typedef enum { + // No error from CORE seen yet + NONE, + // Error from CORE encountered, waiting to be picked up by completion queue to + // initiate cancellation + ERROR_ENCOUNTERED, + // Error from CORE encountered, stream closed + // This state is added to avoid double cancellation + ERROR_HANDLING_COMPLETE +} TritonGRPCErrorSteps; + +class gRPCErrorTracker { + public: + // True if set by user via header + // Can be accessed without a lock, as set only once in startstream + std::atomic triton_grpc_error_; + + // Indicates the state of triton_grpc_error, only relevant if special + // triton_grpc_error feature set to true by client + TritonGRPCErrorSteps grpc_stream_error_state_; + + // Constructor + gRPCErrorTracker() + : triton_grpc_error_(false), + grpc_stream_error_state_(TritonGRPCErrorSteps::NONE) + { + } + // Changes the state of grpc_stream_error_state_ to ERROR_HANDLING_COMPLETE, + // indicating we have closed the stream and initiated the cancel flow + void MarkGRPCErrorHandlingComplete(); + + // Returns true ONLY when GRPC_ERROR from CORE is waiting to be processed. 
+ bool CheckAndUpdateGRPCError(); + + // Marks error after it has been responded to + void MarkGRPCErrorEncountered(); + + // Checks if error already responded to in triton_grpc_error mode + bool GRPCErrorEncountered(); +}; // Debugging helper std::ostream& operator<<(std::ostream& out, const Steps& step); @@ -183,5 +223,4 @@ TRITONSERVER_Error* ParseClassificationParams( void ReadFile(const std::string& filename, std::string& data); - }}} // namespace triton::server::grpc diff --git a/src/grpc/infer_handler.h b/src/grpc/infer_handler.h index a976ccff02..51307d4ae0 100644 --- a/src/grpc/infer_handler.h +++ b/src/grpc/infer_handler.h @@ -646,6 +646,7 @@ class InferHandlerState { { ctx_.reset(new ::grpc::ServerContext()); responder_.reset(new ServerResponderType(ctx_.get())); + gRPCErrorTracker_ = std::make_unique(); } void SetCompressionLevel(grpc_compression_level compression_level) @@ -672,7 +673,6 @@ class InferHandlerState { gRPCErrorTracker_->CheckAndUpdateGRPCError()) : false; } - // Increments the ongoing request counter void IncrementRequestCounter() { ongoing_requests_++; } @@ -714,6 +714,37 @@ class InferHandlerState { return false; } + // Extracts headers from GRPC request and updates state + void ExtractStateFromHeaders(InferHandlerStateType* state) + { + const auto& metadata = state->context_->ctx_->client_metadata(); + std::string triton_grpc_error_key = "triton_grpc_error"; + + auto it = metadata.find( + {triton_grpc_error_key.data(), triton_grpc_error_key.size()}); + + if (it != metadata.end()) { + if (it->second == "true") { + LOG_VERBOSE(2) + << "GRPC: triton_grpc_error mode detected in new grpc stream"; + state->context_->gRPCErrorTracker_->triton_grpc_error_ = true; + } + } + } + + void WriteGRPCErrorResponse(InferHandlerStateType* state) + { + std::lock_guard lock(state->context_->mu_); + // Check if Error not responded previously + // Avoid closing connection twice on multiple errors from core + if (!state->context_->gRPCErrorTracker_->GRPCErrorEncountered()) { + state->step_ = Steps::COMPLETE; + state->context_->responder_->Finish(state->status_, state); + // Mark error for this stream + state->context_->gRPCErrorTracker_->MarkGRPCErrorEncountered(); + } + } + const std::string DebugString(InferHandlerStateType* state) { std::string debug_string(""); @@ -797,6 +828,7 @@ class InferHandlerState { bool HandleCancellation( InferHandlerStateType* state, bool rpc_ok, const std::string& name) { + // Check to avoid early exit in case of triton_grpc_error if (!IsCancelled()) { LOG_ERROR << "[INTERNAL] HandleCancellation called even when the context was " @@ -820,7 +852,6 @@ class InferHandlerState { IssueRequestCancellation(); // Mark the context as cancelled state->context_->step_ = Steps::CANCELLED; - // The state returns true because the CancelExecution // call above would have raised alarm objects on all // pending inflight states objects. This state will @@ -1003,6 +1034,8 @@ class InferHandlerState { // Tracks whether the async notification has been delivered by // completion queue. bool received_notification_; + + std::unique_ptr gRPCErrorTracker_; }; // This constructor is used to build a wrapper state object @@ -1094,7 +1127,6 @@ class InferHandlerState { void MarkAsAsyncNotifyState() { async_notify_state_ = true; } bool IsAsyncNotifyState() { return async_notify_state_; } - // Needed in the response handle for classification outputs. 
TRITONSERVER_Server* tritonserver_; diff --git a/src/grpc/stream_infer_handler.cc b/src/grpc/stream_infer_handler.cc index 585f88d536..6651eca813 100644 --- a/src/grpc/stream_infer_handler.cc +++ b/src/grpc/stream_infer_handler.cc @@ -189,7 +189,7 @@ ModelStreamInferHandler::Process(InferHandler::State* state, bool rpc_ok) state->context_->responder_->Finish(status, state); return !finished; } - + state->context_->ExtractStateFromHeaders(state); } else if (state->step_ == Steps::READ) { TRITONSERVER_Error* err = nullptr; const inference::ModelInferRequest& request = state->request_; @@ -355,7 +355,6 @@ ModelStreamInferHandler::Process(InferHandler::State* state, bool rpc_ok) GrpcStatusUtil::Create(&status, err); TRITONSERVER_ErrorDelete(err); response->set_error_message(status.error_message()); - response->mutable_infer_response()->Clear(); // repopulate the id so that client knows which request failed. response->mutable_infer_response()->set_id(request.id()); @@ -596,7 +595,13 @@ ModelStreamInferHandler::StreamInferResponseComplete( void* userp) { State* state = reinterpret_cast(userp); - + // Ignore Response from CORE in case GRPC Strict as we dont care about + if (state->context_->gRPCErrorTracker_->triton_grpc_error_) { + std::lock_guard lock(state->context_->mu_); + if (state->context_->gRPCErrorTracker_->GRPCErrorEncountered()) { + return; + } + } // Increment the callback index uint32_t response_index = state->cb_count_++; @@ -671,14 +676,27 @@ ModelStreamInferHandler::StreamInferResponseComplete( } else { LOG_ERROR << "expected the response allocator to have added the response"; } - if (err != nullptr) { failed = true; ::grpc::Status status; + // Converts CORE errors to GRPC error codes GrpcStatusUtil::Create(&status, err); response->mutable_infer_response()->Clear(); response->set_error_message(status.error_message()); LOG_VERBOSE(1) << "Failed for ID: " << log_request_id << std::endl; + if (state->context_->gRPCErrorTracker_->triton_grpc_error_) { + state->status_ = status; + // Finish only once, if backend ignores cancellation + LOG_VERBOSE(1) << "GRPC streaming error detected with status: " + << status.error_code() << "Closing stream connection." + << std::endl; + state->context_->WriteGRPCErrorResponse(state); + TRITONSERVER_ErrorDelete(err); + LOG_TRITONSERVER_ERROR( + TRITONSERVER_InferenceResponseDelete(iresponse), + "deleting GRPC inference response"); + return; + } } TRITONSERVER_ErrorDelete(err); @@ -802,4 +820,42 @@ ModelStreamInferHandler::StreamInferResponseComplete( } } +// Changes the state of grpc_stream_error_state_ to ERROR_HANDLING_COMPLETE, +// indicating we have closed the stream and initiated the cancel flow +void +gRPCErrorTracker::MarkGRPCErrorHandlingComplete() +{ + grpc_stream_error_state_ = TritonGRPCErrorSteps::ERROR_HANDLING_COMPLETE; +} + +// Returns true ONLY when GRPC_ERROR from CORE is waiting to be processed. 
+bool
+gRPCErrorTracker::CheckAndUpdateGRPCError()
+{
+  if (grpc_stream_error_state_ == TritonGRPCErrorSteps::ERROR_ENCOUNTERED) {
+    // Change the state to ERROR_HANDLING_COMPLETE as the cancellation
+    // flow has now been initiated for this stream
+    MarkGRPCErrorHandlingComplete();
+    return true;
+  }
+  return false;
+}
+
+// Marks that the error from CORE has been responded to on this stream
+void
+gRPCErrorTracker::MarkGRPCErrorEncountered()
+{
+  grpc_stream_error_state_ = TritonGRPCErrorSteps::ERROR_ENCOUNTERED;
+}
+
+// Returns true if an error was already responded to in triton_grpc_error mode
+bool
+gRPCErrorTracker::GRPCErrorEncountered()
+{
+  if (grpc_stream_error_state_ == TritonGRPCErrorSteps::NONE) {
+    return false;
+  }
+  return true;
+}
+
 }}}  // namespace triton::server::grpc
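
Usage note: the sketch below is a minimal client-side example of the triton_grpc_error streaming mode exercised by the new lifecycle test. It assumes the `execute_grpc_error` test model added in this change is being served at `localhost:8001`; the `UserData`/`callback` helpers are the usual queue-based pattern from the tritonclient examples, reproduced here so the snippet is self-contained.

```python
import queue
from functools import partial

import numpy as np
import tritonclient.grpc as grpcclient
from tritonclient.utils import InferenceServerException, np_to_triton_dtype


class UserData:
    def __init__(self):
        self._completed_requests = queue.Queue()


def callback(user_data, result, error):
    # The stream callback receives either a result or an error, never both.
    user_data._completed_requests.put(error if error is not None else result)


user_data = UserData()
client = grpcclient.InferenceServerClient("localhost:8001")

# Opt in to GRPC error mode: on the first error from CORE the server returns
# the corresponding GRPC status code and closes the stream, instead of keeping
# the stream open and embedding the error message in the response.
client.start_stream(
    callback=partial(callback, user_data),
    headers={"triton_grpc_error": "true"},
)

input_data = np.random.randn(2, 2).astype(np.float32)
inputs = [
    grpcclient.InferInput("IN", input_data.shape, np_to_triton_dtype(input_data.dtype))
]
inputs[0].set_data_from_numpy(input_data)
client.async_stream_infer(model_name="execute_grpc_error", inputs=inputs)

result = user_data._completed_requests.get()
if isinstance(result, InferenceServerException):
    # With triton_grpc_error enabled the GRPC status code is preserved
    # (e.g. StatusCode.INTERNAL) and the stream is already closed.
    print("stream closed with", result.status())
else:
    print("OUT =", result.as_numpy("OUT"))

client.stop_stream()
```

Without the `triton_grpc_error` header the same error would be delivered as an in-band error message on an open stream, which is the behavior verified by `test_triton_grpc_error_error_off`.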