Skip to content

Commit

Permalink
fix: Handling grpc cancellation edge-case: Cancelling at step START (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
oandreeva-nv authored Jun 6, 2024
1 parent 1f68c0d commit 42742a3
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 2 deletions.
57 changes: 56 additions & 1 deletion qa/L0_request_cancellation/grpc_cancellation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import asyncio
import queue
import re
import time
import unittest
from functools import partial
Expand Down Expand Up @@ -62,6 +63,7 @@ def setUp(self):
self._callback = partial(callback, self._user_data)
self._prepare_request()
self._start_time = time.time() # seconds
self.test_duration_delta = 0.5

def tearDown(self):
self._end_time = time.time() # seconds
Expand All @@ -75,7 +77,7 @@ def _prepare_request(self):
self._inputs[0].set_data_from_numpy(np.array([[10]], dtype=np.int32))

def _assert_max_duration(self):
max_duration = self._model_delay * 0.5 # seconds
max_duration = self._model_delay * self.test_duration_delta # seconds
duration = self._end_time - self._start_time # seconds
self.assertLess(
duration,
Expand Down Expand Up @@ -136,6 +138,59 @@ async def requests_generator():
async for result, error in responses_iterator:
self._callback(result, error)

def test_grpc_async_infer_cancellation_at_step_start(self):
    # Longer-running test: widen the wall-clock budget checked by
    # tearDown()'s duration assertion.
    self.test_duration_delta = 4.5
    server_log_name = "grpc_cancellation_test.test_grpc_async_infer_cancellation_at_step_start.server.log"

    def read_server_log():
        # Re-read the entire server log from disk each time it is needed.
        with open(server_log_name, "r") as log_file:
            return log_file.read()

    def count_new_handlers(log_text):
        # How many times the server spawned a fresh ModelInferHandler.
        return len(re.findall("New request handler for ModelInferHandler", log_text))

    # Baseline: the server is expected to have created exactly two
    # handlers before this test issues any request.
    prev_new_req_handl_count = count_new_handlers(read_server_log())
    self.assertEqual(
        prev_new_req_handl_count,
        2,
        "Expected 2 request handler for ModelInferHandler log entries, but got {}".format(
            prev_new_req_handl_count
        ),
    )

    # Fire an async inference and cancel it while the server-side
    # TRITONSERVER_DELAY_GRPC_PROCESS delay keeps it at step START.
    future = self._client.async_infer(
        model_name=self._model_name,
        inputs=self._inputs,
        callback=self._callback,
        outputs=self._outputs,
    )
    time.sleep(2)  # ensure the inference request reached server
    future.cancel()
    # ensures TRITONSERVER_DELAY_GRPC_PROCESS delay passed on the server
    time.sleep(self._model_delay * 2)

    server_log = read_server_log()
    cancel_at_start_count = len(
        re.findall(
            r"Cancellation notification received for ModelInferHandler, rpc_ok=1, context \d+, \d+ step START",
            server_log,
        )
    )
    cur_new_req_handl_count = count_new_handlers(server_log)
    self.assertEqual(
        cancel_at_start_count,
        2,
        "Expected 2 cancellation at step START log entries, but got {}".format(
            cancel_at_start_count
        ),
    )
    # A replacement handler must have been created after the
    # cancellation-at-START path ran (see server-side StartNewRequest()).
    self.assertGreater(
        cur_new_req_handl_count,
        prev_new_req_handl_count,
        "gRPC Cancellation on step START Test Failed: New request handler for ModelInferHandler was not created",
    )


# Standard unittest entry point when the file is run as a script.
if __name__ == "__main__":
    unittest.main()
9 changes: 8 additions & 1 deletion qa/L0_request_cancellation/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,13 @@ mkdir -p models/custom_identity_int32/1 && (cd models/custom_identity_int32 && \
echo 'instance_group [{ kind: KIND_CPU }]' >> config.pbtxt && \
echo -e 'parameters [{ key: "execute_delay_ms" \n value: { string_value: "10000" } }]' >> config.pbtxt)

for TEST_CASE in "test_grpc_async_infer" "test_grpc_stream_infer" "test_aio_grpc_async_infer" "test_aio_grpc_stream_infer"; do
for TEST_CASE in "test_grpc_async_infer" "test_grpc_stream_infer" "test_aio_grpc_async_infer" "test_aio_grpc_stream_infer" "test_grpc_async_infer_cancellation_at_step_start"; do

TEST_LOG="./grpc_cancellation_test.$TEST_CASE.log"
SERVER_LOG="grpc_cancellation_test.$TEST_CASE.server.log"
if [ "$TEST_CASE" == "test_grpc_async_infer_cancellation_at_step_start" ]; then
export TRITONSERVER_DELAY_GRPC_PROCESS=5000
fi

SERVER_ARGS="--model-repository=`pwd`/models --log-verbose=1"
run_server
Expand All @@ -108,6 +111,10 @@ for TEST_CASE in "test_grpc_async_infer" "test_grpc_stream_infer" "test_aio_grpc

kill $SERVER_PID
wait $SERVER_PID

if [ "$TEST_CASE" == "test_grpc_async_infer_cancellation_at_step_start" ]; then
unset TRITONSERVER_DELAY_GRPC_PROCESS
fi
done

#
Expand Down
26 changes: 26 additions & 0 deletions src/grpc/infer_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -691,9 +691,35 @@ ModelInferHandler::Process(InferHandler::State* state, bool rpc_ok)
// Need to protect the state transitions for these cases.
std::lock_guard<std::recursive_mutex> lock(state->step_mtx_);

if (state->delay_process_ms_ != 0) {
// Will delay the Process execution by the specified time.
// This can be used to test the flow when cancellation request
// issued for the request, which is still at START step.
LOG_INFO << "Delaying the write of the response by "
<< state->delay_process_ms_ << " ms...";
std::this_thread::sleep_for(
std::chrono::milliseconds(state->delay_process_ms_));
}

// Handle notification for cancellation which can be raised
// asynchronously if detected on the network.
if (state->IsGrpcContextCancelled()) {
if (rpc_ok && (state->step_ == Steps::START) &&
(state->context_->step_ != Steps::CANCELLED)) {
#ifdef TRITON_ENABLE_TRACING
// Can't create trace as we don't know the model to be requested,
// track timestamps in 'state'
state->trace_timestamps_.emplace_back(std::make_pair(
"GRPC_WAITREAD_END", TraceManager::CaptureTimestamp()));
#endif // TRITON_ENABLE_TRACING
// Need to create a new request object here explicitly for step START,
// because we will never leave this if body. Refer to PR 7325.
// This is a special case for ModelInferHandler, since we have 2 threads,
// and each of them can process cancellation. ModelStreamInfer has only 1
// thread, and cancellation at step START was not reproducible in a
// single thread scenario.
StartNewRequest();
}
bool resume = state->context_->HandleCancellation(state, rpc_ok, Name());
return resume;
}
Expand Down
6 changes: 6 additions & 0 deletions src/grpc/infer_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -1024,6 +1024,11 @@ class InferHandlerState {
if (cstr != nullptr) {
delay_complete_ms_ = atoi(cstr);
}
const char* pstr = getenv("TRITONSERVER_DELAY_GRPC_PROCESS");
delay_process_ms_ = 0;
if (pstr != nullptr) {
delay_process_ms_ = atoi(pstr);
}

response_queue_.reset(new ResponseQueue<ResponseType>());
Reset(context, start_step);
Expand Down Expand Up @@ -1120,6 +1125,7 @@ class InferHandlerState {
// For testing and debugging
int delay_response_ms_;
int delay_complete_ms_;
int delay_process_ms_;

// For inference requests the allocator payload, unused for other
// requests.
Expand Down

0 comments on commit 42742a3

Please sign in to comment.