Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Handling grpc cancellation edge-case: Cancelling at step START #7325

Merged
merged 9 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 56 additions & 1 deletion qa/L0_request_cancellation/grpc_cancellation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import asyncio
import queue
import re
import time
import unittest
from functools import partial
Expand Down Expand Up @@ -62,6 +63,7 @@ def setUp(self):
self._callback = partial(callback, self._user_data)
self._prepare_request()
self._start_time = time.time() # seconds
self.test_duration_delta = 0.5

def tearDown(self):
self._end_time = time.time() # seconds
Expand All @@ -75,7 +77,7 @@ def _prepare_request(self):
self._inputs[0].set_data_from_numpy(np.array([[10]], dtype=np.int32))

def _assert_max_duration(self):
max_duration = self._model_delay * 0.5 # seconds
max_duration = self._model_delay * self.test_duration_delta # seconds
duration = self._end_time - self._start_time # seconds
self.assertLess(
duration,
Expand Down Expand Up @@ -136,6 +138,59 @@ async def requests_generator():
async for result, error in responses_iterator:
self._callback(result, error)

def test_grpc_async_infer_cancellation_at_step_start(self):
    """Cancel an async infer request while it is still at step START and
    verify the server recovers by creating a new request handler.

    Relies on the server running with TRITONSERVER_DELAY_GRPC_PROCESS set
    (done by test.sh for this test case) and on verbose server logging
    being captured in a well-known log file.
    """
    # Cancellation at START takes longer than the other cases, so widen
    # the duration bound enforced by tearDown().
    self.test_duration_delta = 4.5
    server_log_name = "grpc_cancellation_test.test_grpc_async_infer_cancellation_at_step_start.server.log"

    def snapshot_log():
        # Read the current contents of the server log from disk.
        with open(server_log_name, "r") as log_file:
            return log_file.read()

    def occurrences(pattern, text):
        # Count regex matches of `pattern` within the captured log text.
        return len(re.findall(pattern, text))

    # Baseline: the server is expected to have logged exactly two
    # "new request handler" entries before we issue any request.
    prev_new_req_handl_count = occurrences(
        "New request handler for ModelInferHandler", snapshot_log()
    )
    self.assertEqual(
        prev_new_req_handl_count,
        2,
        "Expected 2 request handler for ModelInferHandler log entries, but got {}".format(
            prev_new_req_handl_count
        ),
    )

    future = self._client.async_infer(
        model_name=self._model_name,
        inputs=self._inputs,
        callback=self._callback,
        outputs=self._outputs,
    )
    time.sleep(2)  # ensure the inference request reached server
    future.cancel()
    # ensures TRITONSERVER_DELAY_GRPC_PROCESS delay passed on the server
    time.sleep(self._model_delay * 2)

    server_log = snapshot_log()
    cancel_at_start_count = occurrences(
        r"Cancellation notification received for ModelInferHandler, rpc_ok=1, context \d+, \d+ step START",
        server_log,
    )
    cur_new_req_handl_count = occurrences(
        "New request handler for ModelInferHandler", server_log
    )
    self.assertEqual(
        cancel_at_start_count,
        2,
        "Expected 2 cancellation at step START log entries, but got {}".format(
            cancel_at_start_count
        ),
    )
    # A fresh handler must have been spawned after the cancelled one.
    self.assertGreater(
        cur_new_req_handl_count,
        prev_new_req_handl_count,
        "gRPC Cancellation on step START Test Failed: New request handler for ModelInferHandler was not created",
    )


# Allow running this test file directly; test.sh normally invokes it
# one test case at a time with the matching server configuration.
if __name__ == "__main__":
    unittest.main()
9 changes: 8 additions & 1 deletion qa/L0_request_cancellation/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,13 @@ mkdir -p models/custom_identity_int32/1 && (cd models/custom_identity_int32 && \
echo 'instance_group [{ kind: KIND_CPU }]' >> config.pbtxt && \
echo -e 'parameters [{ key: "execute_delay_ms" \n value: { string_value: "10000" } }]' >> config.pbtxt)

for TEST_CASE in "test_grpc_async_infer" "test_grpc_stream_infer" "test_aio_grpc_async_infer" "test_aio_grpc_stream_infer"; do
for TEST_CASE in "test_grpc_async_infer" "test_grpc_stream_infer" "test_aio_grpc_async_infer" "test_aio_grpc_stream_infer" "test_grpc_async_infer_cancellation_at_step_start"; do

TEST_LOG="./grpc_cancellation_test.$TEST_CASE.log"
SERVER_LOG="grpc_cancellation_test.$TEST_CASE.server.log"
if [ "$TEST_CASE" == "test_grpc_async_infer_cancellation_at_step_start" ]; then
export TRITONSERVER_DELAY_GRPC_PROCESS=5000
fi

SERVER_ARGS="--model-repository=`pwd`/models --log-verbose=1"
run_server
Expand All @@ -108,6 +111,10 @@ for TEST_CASE in "test_grpc_async_infer" "test_grpc_stream_infer" "test_aio_grpc

kill $SERVER_PID
wait $SERVER_PID

if [ "$TEST_CASE" == "test_grpc_async_infer_cancellation_at_step_start" ]; then
unset TRITONSERVER_DELAY_GRPC_PROCESS
fi
done

#
Expand Down
26 changes: 26 additions & 0 deletions src/grpc/infer_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -691,9 +691,35 @@ ModelInferHandler::Process(InferHandler::State* state, bool rpc_ok)
// Need to protect the state transitions for these cases.
std::lock_guard<std::recursive_mutex> lock(state->step_mtx_);

if (state->delay_process_ms_ != 0) {
// Will delay the Process execution by the specified time.
// This can be used to test the flow when cancellation request
// issued for the request, which is still at START step.
LOG_INFO << "Delaying the write of the response by "
<< state->delay_process_ms_ << " ms...";
std::this_thread::sleep_for(
std::chrono::milliseconds(state->delay_process_ms_));
}

// Handle notification for cancellation which can be raised
// asynchronously if detected on the network.
if (state->IsGrpcContextCancelled()) {
if (rpc_ok && (state->step_ == Steps::START) &&
(state->context_->step_ != Steps::CANCELLED)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question: what does the second clause here imply?

!= Steps::Cancelled ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To avoid calling StartNewRequest twice: first we fall into HandleCancellation and go through this block, which returns true for resume, so we enter the if (state->IsGrpcContextCancelled()) block a second time — but by then state->context_->step_ is already CANCELLED.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Late to the game, but what is the reasoning of not moving the original "StartNewRequest() if at START" to before handling the cancellation? Although I think other code needs to be moved around as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not 100% aware of all underlying processes, meaning state->step_ and state->context_->step_ combinations. This change helps to address the bug with known symptoms. Refactoring if the Process logic needs proper time and testing IMHO

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kthui thoughts? If feasible, this can be done as follow-up and by someone else. Want to make sure if there is room for improvement.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think there is definitely room for improvement/refactoring, i.e. I think the if (shutdown) { ... } could also be moved into the if (state->step_ == Steps::START) { ... } else ... block, so all procedures for Steps::START would be inside the if (state->step_ == Steps::START) { ... } block, but it can be done as a follow-up later.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Jira ticket: DLIS-6831

#ifdef TRITON_ENABLE_TRACING
// Can't create trace as we don't know the model to be requested,
// track timestamps in 'state'
state->trace_timestamps_.emplace_back(std::make_pair(
"GRPC_WAITREAD_END", TraceManager::CaptureTimestamp()));
#endif // TRITON_ENABLE_TRACING
// Need to create a new request object here explicitly for step START,
// because we will never leave this if body. Refer to PR 7325.
// This is a special case for ModelInferHandler, since we have 2 threads,
// and each of them can process cancellation. ModelStreamInfer has only 1
// thread, and cancellation at step START was not reproducible in a
// single thread scenario.
StartNewRequest();
}
bool resume = state->context_->HandleCancellation(state, rpc_ok, Name());
return resume;
}
Expand Down
6 changes: 6 additions & 0 deletions src/grpc/infer_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -1024,6 +1024,11 @@ class InferHandlerState {
if (cstr != nullptr) {
delay_complete_ms_ = atoi(cstr);
}
const char* pstr = getenv("TRITONSERVER_DELAY_GRPC_PROCESS");
delay_process_ms_ = 0;
if (pstr != nullptr) {
delay_process_ms_ = atoi(pstr);
}

response_queue_.reset(new ResponseQueue<ResponseType>());
Reset(context, start_step);
Expand Down Expand Up @@ -1120,6 +1125,7 @@ class InferHandlerState {
// For testing and debugging
int delay_response_ms_;
int delay_complete_ms_;
int delay_process_ms_;

// For inference requests the allocator payload, unused for other
// requests.
Expand Down
Loading