From 8ab618628ba492c409e37f87a2ae40961a457ecf Mon Sep 17 00:00:00 2001
From: krishung5
Date: Mon, 6 Nov 2023 18:26:53 -0800
Subject: [PATCH] Address comment

---
 src/infer_request.cc                          |  9 ++++-----
 src/infer_request.h                           |  3 +--
 src/pb_stub.cc                                | 20 ++++++++++++++++++++
 src/pb_stub.h                                 | 10 ++++++++++
 src/python_be.cc                              | 13 ++++++++-----
 src/resources/triton_python_backend_utils.py |  1 +
 6 files changed, 44 insertions(+), 12 deletions(-)

diff --git a/src/infer_request.cc b/src/infer_request.cc
index 6d21e54d..8f875565 100644
--- a/src/infer_request.cc
+++ b/src/infer_request.cc
@@ -44,14 +44,13 @@ InferRequest::InferRequest(
     const std::string& model_name, const int64_t model_version,
     const std::string& parameters, const uint32_t flags, const int32_t timeout,
     const intptr_t response_factory_address, const intptr_t request_address,
-    const PreferredMemory& preferred_memory, const InferenceTrace& trace,
-    const uint32_t& request_release_flags)
+    const PreferredMemory& preferred_memory, const InferenceTrace& trace)
     : request_id_(request_id), correlation_id_(correlation_id),
       inputs_(inputs), requested_output_names_(requested_output_names),
       model_name_(model_name), model_version_(model_version),
       parameters_(parameters), flags_(flags), timeout_(timeout),
       response_factory_address_(response_factory_address),
       request_address_(request_address), preferred_memory_(preferred_memory),
-      trace_(trace), request_release_flags_(request_release_flags)
+      trace_(trace), request_release_flags_(TRITONSERVER_REQUEST_RELEASE_ALL)
 {
   for (auto& input : inputs) {
     if (!input) {
@@ -74,7 +73,7 @@ InferRequest::InferRequest(
 #ifdef TRITON_PB_STUB
   pb_cancel_ =
       std::make_shared<PbCancel>(response_factory_address_, request_address_);
-  response_sender_ = std::make_shared<ResponseSender>(
+  response_sender_ = Stub::GetOrCreateInstance()->GetResponseSender(
       request_address_, response_factory_address_,
       Stub::GetOrCreateInstance()->SharedMemory(), pb_cancel_);
 #endif
@@ -400,7 +399,7 @@ InferRequest::InferRequest(
 #ifdef TRITON_PB_STUB
   pb_cancel_ =
       std::make_shared<PbCancel>(response_factory_address_, request_address_);
-  response_sender_ = std::make_shared<ResponseSender>(
+  response_sender_ = Stub::GetOrCreateInstance()->GetResponseSender(
       request_address_, response_factory_address_,
       Stub::GetOrCreateInstance()->SharedMemory(), pb_cancel_);
 #endif
diff --git a/src/infer_request.h b/src/infer_request.h
index 8205d245..3d81c5d2 100644
--- a/src/infer_request.h
+++ b/src/infer_request.h
@@ -88,8 +88,7 @@ class InferRequest {
       const intptr_t request_address = 0,
       const PreferredMemory& preferred_memory =
           PreferredMemory(PreferredMemory::DEFAULT, 0),
-      const InferenceTrace& trace = InferenceTrace(),
-      const uint32_t& request_release_flags = TRITONSERVER_REQUEST_RELEASE_ALL);
+      const InferenceTrace& trace = InferenceTrace());
 
   const std::vector<std::shared_ptr<PbTensor>>& Inputs();
   const std::string& RequestId();
diff --git a/src/pb_stub.cc b/src/pb_stub.cc
index 3d473101..1de7f975 100644
--- a/src/pb_stub.cc
+++ b/src/pb_stub.cc
@@ -1386,6 +1386,26 @@ Stub::ProcessBLSResponseDecoupled(std::unique_ptr<IPCMessage>& ipc_message)
   }
 }
 
+std::shared_ptr<ResponseSender>
+Stub::GetResponseSender(
+    intptr_t request_address, intptr_t response_factory_address,
+    std::unique_ptr<SharedMemoryManager>& shm_pool,
+    const std::shared_ptr<PbCancel>& pb_cancel)
+{
+  std::lock_guard<std::mutex> lock(response_sender_map_mu_);
+  if (response_sender_map_.find(request_address) !=
+      response_sender_map_.end()) {
+    return response_sender_map_[request_address];
+  } else {
+    auto response_sender = std::make_shared<ResponseSender>(
+        request_address, response_factory_address, shm_pool, pb_cancel);
+    response_sender_map_.insert(
+        std::pair<intptr_t, std::shared_ptr<ResponseSender>>(
+            request_address, response_sender));
+    return response_sender;
+  }
+}
+
 std::unique_ptr<Logger> Logger::log_instance_;
 
 std::unique_ptr<Logger>&
diff --git a/src/pb_stub.h b/src/pb_stub.h
index 12b47abc..031dfa14 100644
--- a/src/pb_stub.h
+++ b/src/pb_stub.h
@@ -357,6 +357,13 @@ class Stub {
   /// Get the CUDA memory pool address from the parent process.
   void GetCUDAMemoryPoolAddress(std::unique_ptr<IPCMessage>& ipc_message);
 
+  /// Get the response sender associated with the request, or create a new one
+  /// if it does not exist in the map.
+  std::shared_ptr<ResponseSender> GetResponseSender(
+      intptr_t request_address, intptr_t response_factory_address,
+      std::unique_ptr<SharedMemoryManager>& shm_pool,
+      const std::shared_ptr<PbCancel>& pb_cancel);
+
  private:
   bi::interprocess_mutex* stub_mutex_;
   bi::interprocess_condition* stub_cond_;
@@ -395,6 +402,9 @@ class Stub {
       response_iterator_map_;
   std::mutex dlpack_proxy_stream_pool_mu_;
   std::unordered_map<int, cudaStream_t> dlpack_proxy_stream_pool_;
+  std::mutex response_sender_map_mu_;
+  std::unordered_map<intptr_t, std::shared_ptr<ResponseSender>>
+      response_sender_map_;
 };
 
 template <typename T>
diff --git a/src/python_be.cc b/src/python_be.cc
index 8f1266c9..8da35d4b 100644
--- a/src/python_be.cc
+++ b/src/python_be.cc
@@ -396,16 +396,14 @@ ModelInstanceState::SaveRequestsToSharedMemory(
         model_state->Name(), model_state->Version(), parameters_string, flags,
         0 /* BLS request timeout*/, reinterpret_cast<intptr_t>(factory_ptr),
         reinterpret_cast<intptr_t>(request),
-        PreferredMemory(PreferredMemory::DEFAULT, 0), trace,
-        TRITONSERVER_REQUEST_RELEASE_ALL /* request release flags */);
+        PreferredMemory(PreferredMemory::DEFAULT, 0), trace);
   } else {
     infer_request = std::make_unique<InferRequest>(
         id, correlation_id, pb_input_tensors, requested_output_names,
         model_state->Name(), model_state->Version(), parameters_string, flags,
         0 /* BLS request timeout*/, 0 /* response_factory_address */,
         reinterpret_cast<intptr_t>(request),
-        PreferredMemory(PreferredMemory::DEFAULT, 0), trace,
-        TRITONSERVER_REQUEST_RELEASE_ALL /* request release flags */);
+        PreferredMemory(PreferredMemory::DEFAULT, 0), trace);
   }
 
   RETURN_IF_EXCEPTION(infer_request->SaveToSharedMemory(Stub()->ShmPool()));
@@ -1366,6 +1364,11 @@ ModelInstanceState::ProcessRequestsDecoupled(
           TRITONSERVER_ERROR_INTERNAL, error->String().c_str());
     }
 
+    // Reset the release flags for all the requests.
+    for (auto& infer_request : pb_infer_requests) {
+      infer_request->SetReleaseFlags(TRITONSERVER_REQUEST_RELEASE_ALL);
+    }
+
     return TRITONSERVER_ErrorNew(
         TRITONSERVER_ERROR_INTERNAL, "Failed to process the requests.");
   }
@@ -2502,7 +2505,7 @@ TRITONBACKEND_ModelInstanceExecute(
           (std::string("Failed to release request: ") + pb_exception.what())
               .c_str());
       if (request_release_flags[r] == TRITONSERVER_REQUEST_RELEASE_RESCHEDULE) {
-        // If error occurs during request reschedule, release the request with
+        // If error occurs during request rescheduling, release the request with
         // `TRITONSERVER_REQUEST_RELEASE_ALL` flag.
         LOG_IF_ERROR(
             TRITONBACKEND_RequestRelease(
diff --git a/src/resources/triton_python_backend_utils.py b/src/resources/triton_python_backend_utils.py
index 47c20593..936ed8d3 100644
--- a/src/resources/triton_python_backend_utils.py
+++ b/src/resources/triton_python_backend_utils.py
@@ -606,4 +606,5 @@ def set_model_transaction_policy(self, transaction_policy_dict):
 TRITONSERVER_REQUEST_FLAG_SEQUENCE_START = 1
 TRITONSERVER_REQUEST_FLAG_SEQUENCE_END = 2
 TRITONSERVER_RESPONSE_COMPLETE_FINAL = 1
+TRITONSERVER_REQUEST_RELEASE_ALL = 1
 TRITONSERVER_REQUEST_RELEASE_RESCHEDULE = 2
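
Note on the GetResponseSender() change above: the stub now hands out one ResponseSender per
request address instead of constructing a new sender inside every InferRequest constructor, so
multiple InferRequest objects built for the same underlying Triton request share a single sender.
The snippet below is a minimal, self-contained sketch of that get-or-create caching pattern using
only standard-library types; the names Sender and SenderCache are illustrative stand-ins and are
not part of the Triton codebase.

#include <cstdint>
#include <memory>
#include <mutex>
#include <unordered_map>

// Illustrative stand-in for ResponseSender.
struct Sender {
  explicit Sender(std::intptr_t addr) : request_address(addr) {}
  std::intptr_t request_address;
};

class SenderCache {
 public:
  // Return the sender already associated with `request_address`, or create and
  // remember a new one. The mutex keeps concurrent lookups safe, mirroring
  // response_sender_map_mu_ in the patch.
  std::shared_ptr<Sender> GetOrCreate(std::intptr_t request_address)
  {
    std::lock_guard<std::mutex> lock(mu_);
    auto it = map_.find(request_address);
    if (it != map_.end()) {
      return it->second;
    }
    auto sender = std::make_shared<Sender>(request_address);
    map_.emplace(request_address, sender);
    return sender;
  }

 private:
  std::mutex mu_;
  std::unordered_map<std::intptr_t, std::shared_ptr<Sender>> map_;
};

int main()
{
  SenderCache cache;
  auto a = cache.GetOrCreate(0x1000);
  auto b = cache.GetOrCreate(0x1000);
  // a and b point to the same Sender object, just as two InferRequest objects
  // created for the same request now share one ResponseSender.
  return (a == b) ? 0 : 1;
}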