diff --git a/qa/L0_request_cancellation/grpc_cancellation_test.py b/qa/L0_request_cancellation/grpc_cancellation_test.py index fadaa291e8..4b103e21e1 100755 --- a/qa/L0_request_cancellation/grpc_cancellation_test.py +++ b/qa/L0_request_cancellation/grpc_cancellation_test.py @@ -28,6 +28,7 @@ import asyncio import queue +import re import time import unittest from functools import partial @@ -62,6 +63,7 @@ def setUp(self): self._callback = partial(callback, self._user_data) self._prepare_request() self._start_time = time.time() # seconds + self.test_duration_delta = 0.5 def tearDown(self): self._end_time = time.time() # seconds @@ -75,7 +77,7 @@ def _prepare_request(self): self._inputs[0].set_data_from_numpy(np.array([[10]], dtype=np.int32)) def _assert_max_duration(self): - max_duration = self._model_delay * 0.5 # seconds + max_duration = self._model_delay * self.test_duration_delta # seconds duration = self._end_time - self._start_time # seconds self.assertLess( duration, @@ -136,6 +138,59 @@ async def requests_generator(): async for result, error in responses_iterator: self._callback(result, error) + def test_grpc_async_infer_cancellation_at_step_start(self): + # Longer-running test: raise the factor _assert_max_duration applies to the model delay + self.test_duration_delta = 4.5 + server_log_name = "grpc_cancellation_test.test_grpc_async_infer_cancellation_at_step_start.server.log" + with open(server_log_name, "r") as f: + server_log = f.read() + + prev_new_req_handl_count = len( + re.findall("New request handler for ModelInferHandler", server_log) + ) + self.assertEqual( + prev_new_req_handl_count, + 2, + "Expected 2 request handler for ModelInferHandler log entries, but got {}".format( + prev_new_req_handl_count + ), + ) + future = self._client.async_infer( + model_name=self._model_name, + inputs=self._inputs, + callback=self._callback, + outputs=self._outputs, + ) + time.sleep(2) # give the inference request time to reach the server + future.cancel() + # wait until the server-side TRITONSERVER_DELAY_GRPC_PROCESS delay has elapsed +
time.sleep(self._model_delay * 2) + + with open(server_log_name, "r") as f: + server_log = f.read() + + cancel_at_start_count = len( + re.findall( + r"Cancellation notification received for ModelInferHandler, rpc_ok=1, context \d+, \d+ step START", + server_log, + ) + ) + cur_new_req_handl_count = len( + re.findall("New request handler for ModelInferHandler", server_log) + ) + self.assertEqual( + cancel_at_start_count, + 2, + "Expected 2 cancellation at step START log entries, but got {}".format( + cancel_at_start_count + ), + ) + self.assertGreater( + cur_new_req_handl_count, + prev_new_req_handl_count, + "gRPC Cancellation on step START Test Failed: New request handler for ModelInferHandler was not created", + ) + if __name__ == "__main__": unittest.main() diff --git a/qa/L0_request_cancellation/test.sh b/qa/L0_request_cancellation/test.sh index b43089e3e7..0c9ab74086 100755 --- a/qa/L0_request_cancellation/test.sh +++ b/qa/L0_request_cancellation/test.sh @@ -78,10 +78,13 @@ mkdir -p models/custom_identity_int32/1 && (cd models/custom_identity_int32 && \ echo 'instance_group [{ kind: KIND_CPU }]' >> config.pbtxt && \ echo -e 'parameters [{ key: "execute_delay_ms" \n value: { string_value: "10000" } }]' >> config.pbtxt) -for TEST_CASE in "test_grpc_async_infer" "test_grpc_stream_infer" "test_aio_grpc_async_infer" "test_aio_grpc_stream_infer"; do +for TEST_CASE in "test_grpc_async_infer" "test_grpc_stream_infer" "test_aio_grpc_async_infer" "test_aio_grpc_stream_infer" "test_grpc_async_infer_cancellation_at_step_start"; do TEST_LOG="./grpc_cancellation_test.$TEST_CASE.log" SERVER_LOG="grpc_cancellation_test.$TEST_CASE.server.log" + if [ "$TEST_CASE" == "test_grpc_async_infer_cancellation_at_step_start" ]; then + export TRITONSERVER_DELAY_GRPC_PROCESS=5000 + fi SERVER_ARGS="--model-repository=`pwd`/models --log-verbose=1" run_server @@ -108,6 +111,10 @@ for TEST_CASE in "test_grpc_async_infer" "test_grpc_stream_infer" "test_aio_grpc kill $SERVER_PID wait $SERVER_PID
+ + if [ "$TEST_CASE" == "test_grpc_async_infer_cancellation_at_step_start" ]; then + unset TRITONSERVER_DELAY_GRPC_PROCESS + fi done # diff --git a/src/grpc/infer_handler.cc b/src/grpc/infer_handler.cc index 4cd16cee16..35659f4900 100644 --- a/src/grpc/infer_handler.cc +++ b/src/grpc/infer_handler.cc @@ -691,9 +691,35 @@ ModelInferHandler::Process(InferHandler::State* state, bool rpc_ok) // Need to protect the state transitions for these cases. std::lock_guard lock(state->step_mtx_); + if (state->delay_process_ms_ != 0) { + // Will delay the Process execution by the specified time. + // This can be used to test the flow when a cancellation request + // is issued for a request that is still at the START step. + LOG_INFO << "Delaying the Process execution by " + << state->delay_process_ms_ << " ms..."; + std::this_thread::sleep_for( + std::chrono::milliseconds(state->delay_process_ms_)); + } + // Handle notification for cancellation which can be raised // asynchronously if detected on the network. if (state->IsGrpcContextCancelled()) { + if (rpc_ok && (state->step_ == Steps::START) && + (state->context_->step_ != Steps::CANCELLED)) { +#ifdef TRITON_ENABLE_TRACING + // Can't create trace as we don't know the model to be requested, + // track timestamps in 'state' + state->trace_timestamps_.emplace_back(std::make_pair( + "GRPC_WAITREAD_END", TraceManager::CaptureTimestamp())); +#endif // TRITON_ENABLE_TRACING + // Need to create a new request object here explicitly for step START, + // because we will never leave this if body. Refer to PR 7325. + // This is a special case for ModelInferHandler, since we have 2 threads, + // and each of them can process cancellation. ModelStreamInfer has only 1 + // thread, and cancellation at step START was not reproducible in a + // single thread scenario.
+ StartNewRequest(); + } bool resume = state->context_->HandleCancellation(state, rpc_ok, Name()); return resume; } diff --git a/src/grpc/infer_handler.h b/src/grpc/infer_handler.h index 1e25e72c7f..6ef03807a2 100644 --- a/src/grpc/infer_handler.h +++ b/src/grpc/infer_handler.h @@ -1024,6 +1024,11 @@ class InferHandlerState { if (cstr != nullptr) { delay_complete_ms_ = atoi(cstr); } + const char* pstr = getenv("TRITONSERVER_DELAY_GRPC_PROCESS"); + delay_process_ms_ = 0; + if (pstr != nullptr) { + delay_process_ms_ = atoi(pstr); + } response_queue_.reset(new ResponseQueue()); Reset(context, start_step); @@ -1120,6 +1125,7 @@ class InferHandlerState { // For testing and debugging int delay_response_ms_; int delay_complete_ms_; + int delay_process_ms_; // For inference requests the allocator payload, unused for other // requests.