// Patch overview: unified diff touching src/pb_stub.cc and src/pb_stub.h.
//
// What the patch does, as far as the visible hunks show:
//
// src/pb_stub.cc
//  * Stub::Initialize: assigns py::none() to async_event_loop_.
//  * Stub::ProcessRequestsDecoupled: passes the value returned by the model's
//    execute() to asyncio.iscoroutine(). A coroutine is handed to
//    RunCoroutine(); any other value goes through the pre-existing check that
//    throws PythonBackendException ("... the execute function must return
//    None.").
//  * Stub::GetAsyncEventLoop (new): when the stored loop passes the isinstance
//    check (presumably "is None"; the template argument is missing in this
//    copy), calls asyncio.new_event_loop(), stores it, and starts a
//    threading.Thread(target=loop.run_forever, daemon=True). Returns the
//    stored loop.
//  * Stub::RunCoroutine (new): submits the coroutine to that loop with
//    asyncio.run_coroutine_threadsafe(). While holding
//    async_event_futures_mu_, it launches a std::async(std::launch::async)
//    task, moves the returned future into a shared_ptr, and emplaces that
//    shared_ptr into async_event_futures_. The task acquires the GIL, calls
//    py_future.result(), reports PythonBackendException and
//    py::error_already_set through LOG_ERROR, resets py_future to None, and
//    then erases its own shared_ptr from async_event_futures_ under the
//    mutex. RunCoroutine itself always returns py::none().
//  * Stub::Finalize: sets finalizing_. If the loop check passes, logs an error
//    when async_event_futures_ is non-empty and calls stop() on the loop; this
//    precedes the existing "Call finalize if exists" handling.
//  * Stub::~Stub: inside the gil_scoped_acquire scope, clears
//    async_event_futures_ and resets async_event_loop_ to None before
//    model_instance_ is reset.
//
// src/pb_stub.h
//  * Copyright year range changed from 2021-2023 to 2021-2024.
//  * Three #include lines added (header names are missing in this copy).
//  * Declares GetAsyncEventLoop() and RunCoroutine(py::object), and adds the
//    members async_event_loop_, async_event_futures_ and
//    async_event_futures_mu_.
//
// NOTE(review): this copy looks like it went through an extraction step that
// collapsed line breaks and dropped text inside angle brackets (see
// "py::isinstance(execute_return)", ".cast()", the bare "#include" tokens,
// "std::shared_ptr> shared_future(new std::future())"). The final hunk also
// stops mid-declaration. Restore the patch from its upstream source before
// trying to apply it; the notes below describe the logic only.
//
// NOTE(review): Finalize() calls stop() directly on the loop, while
// GetAsyncEventLoop() runs that loop via run_forever() in a separate Python
// thread. The asyncio documentation says loop methods called from another
// thread should go through call_soon_threadsafe(); confirm whether
// loop.call_soon_threadsafe(loop.stop) is needed here.
//
// NOTE(review): Finalize() reads async_event_futures_.empty() and .size()
// without taking async_event_futures_mu_, while the std::async tasks erase
// from the same set under that mutex. Looks like an unsynchronized read;
// confirm whether any task can still be running at that point.
//
// NOTE(review): each std::async task captures the shared_ptr that owns the
// future of that same task. After the erase() the captured copy appears to be
// the last owner, so the task state may keep itself alive; verify the
// lifetime and where that future is finally destroyed.
diff --git a/src/pb_stub.cc b/src/pb_stub.cc index a9a910a1..13ce7d7a 100644 --- a/src/pb_stub.cc +++ b/src/pb_stub.cc @@ -533,6 +533,8 @@ Stub::Initialize(bi::managed_external_buffer::handle_t map_handle) c_python_backend_utils.attr("InferenceResponse")); c_python_backend_utils.attr("shared_memory") = py::cast(shm_pool_.get()); + async_event_loop_ = py::none(); + py::object TritonPythonModel = sys.attr("TritonPythonModel"); deserialize_bytes_ = python_backend_utils.attr("deserialize_bytes_tensor"); serialize_bytes_ = python_backend_utils.attr("serialize_byte_tensor"); @@ -690,11 +692,18 @@ Stub::ProcessRequestsDecoupled(RequestBatch* request_batch_shm_ptr) py::object execute_return = model_instance_.attr("execute")(py_request_list); - if (!py::isinstance(execute_return)) { - throw PythonBackendException( - "Python model '" + name_ + - "' is using the decoupled mode and the execute function must " - "return None."); + bool is_coroutine = py::module::import("asyncio") + .attr("iscoroutine")(execute_return) + .cast(); + if (is_coroutine) { + RunCoroutine(execute_return); + } else { + if (!py::isinstance(execute_return)) { + throw PythonBackendException( + "Python model '" + name_ + + "' is using the decoupled mode and the execute function must " + "return None."); + } } } } @@ -870,6 +879,60 @@ Stub::ProcessRequests(RequestBatch* request_batch_shm_ptr) } } +py::object +Stub::GetAsyncEventLoop() +{ + if (py::isinstance(async_event_loop_)) { + // Create the event loop if not already. 
+ async_event_loop_ = py::module_::import("asyncio").attr("new_event_loop")(); + py::object py_thread = + py::module_::import("threading") + .attr("Thread")( + "target"_a = async_event_loop_.attr("run_forever"), + "daemon"_a = true); + py_thread.attr("start")(); + } + return async_event_loop_; +} + +py::object +Stub::RunCoroutine(py::object coroutine) +{ + py::object loop = GetAsyncEventLoop(); + py::object py_future = py::module_::import("asyncio").attr( + "run_coroutine_threadsafe")(coroutine, loop); + + { + std::lock_guard lock(async_event_futures_mu_); + + std::shared_ptr> shared_future(new std::future()); + std::future c_future = std::async( + std::launch::async, [this, shared_future, py_future]() mutable { + { + py::gil_scoped_acquire gil_acquire; + try { + py_future.attr("result")(); + } + catch (const PythonBackendException& pb_exception) { + LOG_ERROR << pb_exception.what(); + } + catch (const py::error_already_set& error) { + LOG_ERROR << error.what(); + } + py_future = py::none(); + } + { + std::lock_guard lock(async_event_futures_mu_); + async_event_futures_.erase(shared_future); + } + }); + *shared_future = std::move(c_future); + async_event_futures_.emplace(std::move(shared_future)); + } + + return py::none(); +} + void Stub::UpdateHealth() { @@ -881,6 +944,14 @@ void Stub::Finalize() { finalizing_ = true; + // Stop async event loop if created. + if (!py::isinstance(async_event_loop_)) { + if (!async_event_futures_.empty()) { + LOG_ERROR << "Finalizing stub with " << async_event_futures_.size() + << " ongoing coroutines"; + } + async_event_loop_.attr("stop")(); + } // Call finalize if exists. 
 if (initialized_ && py::hasattr(model_instance_, "finalize")) { try { @@ -943,6 +1014,8 @@ Stub::~Stub() { py::gil_scoped_acquire acquire; + async_event_futures_.clear(); + async_event_loop_ = py::none(); model_instance_ = py::none(); } stub_instance_.reset(); diff --git a/src/pb_stub.h b/src/pb_stub.h index a51f25f5..1b11c439 100644 --- a/src/pb_stub.h +++ b/src/pb_stub.h @@ -1,4 +1,4 @@ -// Copyright 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -31,6 +31,9 @@ #include #include +#include +#include +#include #include "infer_request.h" #include "infer_response.h" @@ -255,6 +258,10 @@ class Stub { void ProcessRequestsDecoupled(RequestBatch* request_batch_shm_ptr); + py::object GetAsyncEventLoop(); + + py::object RunCoroutine(py::object coroutine); + /// Get the memory manager message queue std::unique_ptr>& MemoryManagerQueue(); @@ -363,6 +370,9 @@ class Stub { py::object model_instance_; py::object deserialize_bytes_; py::object serialize_bytes_; + py::object async_event_loop_; + std::unordered_set>> async_event_futures_; + std::mutex async_event_futures_mu_; std::unique_ptr> stub_message_queue_; std::unique_ptr>