
Commit

Address comment
krishung5 committed Nov 7, 2023
1 parent 482d4bb commit 8ab6186
Showing 6 changed files with 44 additions and 12 deletions.
9 changes: 4 additions & 5 deletions src/infer_request.cc
@@ -44,14 +44,13 @@ InferRequest::InferRequest(
const std::string& model_name, const int64_t model_version,
const std::string& parameters, const uint32_t flags, const int32_t timeout,
const intptr_t response_factory_address, const intptr_t request_address,
- const PreferredMemory& preferred_memory, const InferenceTrace& trace,
- const uint32_t& request_release_flags)
+ const PreferredMemory& preferred_memory, const InferenceTrace& trace)
: request_id_(request_id), correlation_id_(correlation_id), inputs_(inputs),
requested_output_names_(requested_output_names), model_name_(model_name),
model_version_(model_version), parameters_(parameters), flags_(flags),
timeout_(timeout), response_factory_address_(response_factory_address),
request_address_(request_address), preferred_memory_(preferred_memory),
- trace_(trace), request_release_flags_(request_release_flags)
+ trace_(trace), request_release_flags_(TRITONSERVER_REQUEST_RELEASE_ALL)
{
for (auto& input : inputs) {
if (!input) {
@@ -74,7 +73,7 @@ InferRequest::InferRequest(
#ifdef TRITON_PB_STUB
pb_cancel_ =
std::make_shared<PbCancel>(response_factory_address_, request_address_);
- response_sender_ = std::make_shared<ResponseSender>(
+ response_sender_ = Stub::GetOrCreateInstance()->GetResponseSender(
request_address_, response_factory_address_,
Stub::GetOrCreateInstance()->SharedMemory(), pb_cancel_);
#endif
@@ -400,7 +399,7 @@ InferRequest::InferRequest(
#ifdef TRITON_PB_STUB
pb_cancel_ =
std::make_shared<PbCancel>(response_factory_address_, request_address_);
- response_sender_ = std::make_shared<ResponseSender>(
+ response_sender_ = Stub::GetOrCreateInstance()->GetResponseSender(
request_address_, response_factory_address_,
Stub::GetOrCreateInstance()->SharedMemory(), pb_cancel_);
#endif
3 changes: 1 addition & 2 deletions src/infer_request.h
@@ -88,8 +88,7 @@ class InferRequest {
const intptr_t request_address = 0,
const PreferredMemory& preferred_memory =
PreferredMemory(PreferredMemory::DEFAULT, 0),
- const InferenceTrace& trace = InferenceTrace(),
- const uint32_t& request_release_flags = TRITONSERVER_REQUEST_RELEASE_ALL);
+ const InferenceTrace& trace = InferenceTrace());

const std::vector<std::shared_ptr<PbTensor>>& Inputs();
const std::string& RequestId();
20 changes: 20 additions & 0 deletions src/pb_stub.cc
@@ -1386,6 +1386,26 @@ Stub::ProcessBLSResponseDecoupled(std::unique_ptr<IPCMessage>& ipc_message)
}
}

+ std::shared_ptr<ResponseSender>
+ Stub::GetResponseSender(
+ intptr_t request_address, intptr_t response_factory_address,
+ std::unique_ptr<SharedMemoryManager>& shm_pool,
+ const std::shared_ptr<PbCancel>& pb_cancel)
+ {
+ std::lock_guard<std::mutex> lock(response_sender_map_mu_);
+ if (response_sender_map_.find(request_address) !=
+ response_sender_map_.end()) {
+ return response_sender_map_[request_address];
+ } else {
+ auto response_sender = std::make_shared<ResponseSender>(
+ request_address, response_factory_address, shm_pool, pb_cancel);
+ response_sender_map_.insert(
+ std::pair<intptr_t, std::shared_ptr<ResponseSender>>(
+ request_address, response_sender));
+ return response_sender;
+ }
+ }

std::unique_ptr<Logger> Logger::log_instance_;

std::unique_ptr<Logger>&
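For readers skimming the diff, the new Stub::GetResponseSender above follows a mutex-guarded get-or-create pattern: one shared ResponseSender per request address, reused on subsequent lookups. Below is a minimal standalone sketch of that pattern; Sender and SenderCache are stand-in names for illustration only, not types from this backend.

#include <cstdint>
#include <iostream>
#include <memory>
#include <mutex>
#include <unordered_map>

// Stand-in for ResponseSender; remembers the request address it serves.
struct Sender {
  explicit Sender(intptr_t request_address) : request_address_(request_address) {}
  intptr_t request_address_;
};

class SenderCache {
 public:
  // Return the existing sender for this request address, or create one.
  std::shared_ptr<Sender> GetOrCreate(intptr_t request_address) {
    std::lock_guard<std::mutex> lock(mu_);
    auto it = map_.find(request_address);
    if (it != map_.end()) {
      return it->second;  // Reuse the sender created earlier for this request.
    }
    auto sender = std::make_shared<Sender>(request_address);
    map_.emplace(request_address, sender);
    return sender;
  }

 private:
  std::mutex mu_;
  std::unordered_map<intptr_t, std::shared_ptr<Sender>> map_;
};

int main() {
  SenderCache cache;
  auto a = cache.GetOrCreate(0x1234);
  auto b = cache.GetOrCreate(0x1234);
  std::cout << std::boolalpha << (a == b) << std::endl;  // true: same object
  return 0;
}

Keying the map on request_address is what lets the two InferRequest constructors in infer_request.cc above hand back the same ResponseSender for a given request instead of each constructing its own.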
10 changes: 10 additions & 0 deletions src/pb_stub.h
@@ -357,6 +357,13 @@ class Stub {
/// Get the CUDA memory pool address from the parent process.
void GetCUDAMemoryPoolAddress(std::unique_ptr<IPCMessage>& ipc_message);

+ /// Get the response sender associated with the request, or create a new one
+ /// if it does not exist in the map.
+ std::shared_ptr<ResponseSender> GetResponseSender(
+ intptr_t request_address, intptr_t response_factory_address,
+ std::unique_ptr<SharedMemoryManager>& shm_pool,
+ const std::shared_ptr<PbCancel>& pb_cancel);

private:
bi::interprocess_mutex* stub_mutex_;
bi::interprocess_condition* stub_cond_;
@@ -395,6 +402,9 @@
response_iterator_map_;
std::mutex dlpack_proxy_stream_pool_mu_;
std::unordered_map<int, cudaStream_t> dlpack_proxy_stream_pool_;
+ std::mutex response_sender_map_mu_;
+ std::unordered_map<intptr_t, std::shared_ptr<ResponseSender>>
+ response_sender_map_;
};

template <typename MessageType>
13 changes: 8 additions & 5 deletions src/python_be.cc
@@ -396,16 +396,14 @@ ModelInstanceState::SaveRequestsToSharedMemory(
model_state->Name(), model_state->Version(), parameters_string, flags,
0 /* BLS request timeout*/, reinterpret_cast<intptr_t>(factory_ptr),
reinterpret_cast<intptr_t>(request),
- PreferredMemory(PreferredMemory::DEFAULT, 0), trace,
- TRITONSERVER_REQUEST_RELEASE_ALL /* request release flags */);
+ PreferredMemory(PreferredMemory::DEFAULT, 0), trace);
} else {
infer_request = std::make_unique<InferRequest>(
id, correlation_id, pb_input_tensors, requested_output_names,
model_state->Name(), model_state->Version(), parameters_string, flags,
0 /* BLS request timeout*/, 0 /* response_factory_address */,
reinterpret_cast<intptr_t>(request),
- PreferredMemory(PreferredMemory::DEFAULT, 0), trace,
- TRITONSERVER_REQUEST_RELEASE_ALL /* request release flags */);
+ PreferredMemory(PreferredMemory::DEFAULT, 0), trace);
}

RETURN_IF_EXCEPTION(infer_request->SaveToSharedMemory(Stub()->ShmPool()));
@@ -1366,6 +1364,11 @@ ModelInstanceState::ProcessRequestsDecoupled(
TRITONSERVER_ERROR_INTERNAL, error->String().c_str());
}

+ // Reset the release flags for all the requests.
+ for (auto& infer_request : pb_infer_requests) {
+ infer_request->SetReleaseFlags(TRITONSERVER_REQUEST_RELEASE_ALL);
+ }

return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INTERNAL, "Failed to process the requests.");
}
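The loop added above enforces a simple error-path rule: if decoupled request processing fails, every request's release flags are reset to TRITONSERVER_REQUEST_RELEASE_ALL so none stays marked for rescheduling. Below is a rough, self-contained sketch of that rule using stand-in types (not backend code); the numeric flag values mirror the constants defined in triton_python_backend_utils.py further down.

#include <cstdint>
#include <iostream>
#include <vector>

// Values mirror TRITONSERVER_REQUEST_RELEASE_ALL / _RESCHEDULE.
constexpr uint32_t kReleaseAll = 1;
constexpr uint32_t kReleaseReschedule = 2;

// Stand-in for an inference request tracked by the backend.
struct Request {
  uint32_t release_flags = kReleaseAll;
};

// On a processing error, drop any reschedule flag the model may have set so
// every request is released normally (this mirrors the loop added above).
void ResetReleaseFlagsOnError(std::vector<Request>& requests) {
  for (auto& request : requests) {
    request.release_flags = kReleaseAll;
  }
}

int main() {
  std::vector<Request> requests(2);
  requests[1].release_flags = kReleaseReschedule;  // model asked to reschedule
  ResetReleaseFlagsOnError(requests);
  for (const auto& request : requests) {
    std::cout << request.release_flags << "\n";  // prints 1 for both requests
  }
  return 0;
}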
@@ -2502,7 +2505,7 @@ TRITONBACKEND_ModelInstanceExecute(
(std::string("Failed to release request: ") + pb_exception.what())
.c_str());
if (request_release_flags[r] == TRITONSERVER_REQUEST_RELEASE_RESCHEDULE) {
- // If error occurs during request reschedule, release the request with
+ // If error occurs during request rescheduling, release the request with
// `TRITONSERVER_REQUEST_RELEASE_ALL` flag.
LOG_IF_ERROR(
TRITONBACKEND_RequestRelease(
1 change: 1 addition & 0 deletions src/resources/triton_python_backend_utils.py
@@ -606,4 +606,5 @@ def set_model_transaction_policy(self, transaction_policy_dict):
TRITONSERVER_REQUEST_FLAG_SEQUENCE_START = 1
TRITONSERVER_REQUEST_FLAG_SEQUENCE_END = 2
TRITONSERVER_RESPONSE_COMPLETE_FINAL = 1
+ TRITONSERVER_REQUEST_RELEASE_ALL = 1
TRITONSERVER_REQUEST_RELEASE_RESCHEDULE = 2
