From eb5180736eb2aed174d3f1cc603e27d5ce87ba27 Mon Sep 17 00:00:00 2001
From: Tanmay Verma
Date: Tue, 15 Aug 2023 12:46:04 -0700
Subject: [PATCH] Test preserve_ordering for oldest strategy sequence batcher
 (#6185) (#6191)

Co-authored-by: Ryan McCormick
---
 .../sequence_batcher_test.py              | 182 ++++++++++++++++++
 qa/L0_sequence_batcher/test.sh            |  81 ++++++++
 qa/python_models/sequence_py/config.pbtxt |  53 +++++
 qa/python_models/sequence_py/model.py     |  93 +++++++++
 4 files changed, 409 insertions(+)
 create mode 100644 qa/python_models/sequence_py/config.pbtxt
 create mode 100644 qa/python_models/sequence_py/model.py

diff --git a/qa/L0_sequence_batcher/sequence_batcher_test.py b/qa/L0_sequence_batcher/sequence_batcher_test.py
index 1c238d151d..3e6cfc032a 100755
--- a/qa/L0_sequence_batcher/sequence_batcher_test.py
+++ b/qa/L0_sequence_batcher/sequence_batcher_test.py
@@ -31,10 +31,12 @@
 sys.path.append("../common")
 
 import os
+import random
 import threading
 import time
 import unittest
 from builtins import str
+from functools import partial
 
 import numpy as np
 import sequence_util as su
@@ -3432,5 +3434,185 @@ def test_send_request_after_timeout(self):
             raise last_err
 
 
+class SequenceBatcherPreserveOrderingTest(su.SequenceBatcherTestUtil):
+    def setUp(self):
+        super().setUp()
+        # By default, find tritonserver on "localhost", but can be overridden
+        # with TRITONSERVER_IPADDR envvar
+        self.server_address_ = (
+            os.environ.get("TRITONSERVER_IPADDR", "localhost") + ":8001"
+        )
+
+        # Prepare input and expected output based on the model and
+        # the infer sequence sent for testing. If the test is to be extended
+        # for different sequence and model, then proper grouping should be added
+        self.model_name_ = "sequence_py"
+        self.tensor_data_ = np.ones(shape=[1, 1], dtype=np.int32)
+        self.inputs_ = [grpcclient.InferInput("INPUT0", [1, 1], "INT32")]
+        self.inputs_[0].set_data_from_numpy(self.tensor_data_)
+        self.triton_client = grpcclient.InferenceServerClient(self.server_address_)
+
+        # Atomic request ID for multi-threaded inference
+        self.request_id_lock = threading.Lock()
+        self.request_id = 1
+
+    def send_sequence(self, seq_id, seq_id_map, req_id_map):
+        if seq_id not in seq_id_map:
+            seq_id_map[seq_id] = []
+
+        start, middle, end = (True, False), (False, False), (False, True)
+        # Send sequence with 1 start, 1 middle, and 1 end request
+        seq_flags = [start, middle, end]
+        for start_flag, end_flag in seq_flags:
+            # Introduce random sleep to better interweave requests from different sequences
+            time.sleep(random.uniform(0.0, 1.0))
+
+            # Serialize sending requests to ensure ordered request IDs
+            with self.request_id_lock:
+                req_id = self.request_id
+                self.request_id += 1
+
+                # Store metadata to validate results later
+                req_id_map[req_id] = seq_id
+                seq_id_map[seq_id].append(req_id)
+
+                self.triton_client.async_stream_infer(
+                    self.model_name_,
+                    self.inputs_,
+                    sequence_id=seq_id,
+                    sequence_start=start_flag,
+                    sequence_end=end_flag,
+                    timeout=None,
+                    request_id=str(req_id),
+                )
+
+    def _test_sequence_ordering(self, preserve_ordering, decoupled):
+        # 1. Send a few grpc streaming sequence requests to the model.
+        # 2. With grpc streaming, the model should receive the requests in
+        #    the same order they are sent from client, and the client should
+        #    receive the responses in the same order sent back by the
+        #    model/server. With sequence scheduler, the requests for each
+        #    sequence should be routed to the same model instance, and no two
+        #    requests from the same sequence should get batched together.
+        # 3. With preserve_ordering=False, we may get the responses back in a different
+        #    order than the requests, but with grpc streaming we should still expect responses for each sequence to be ordered.
+        # 4. Assert that the sequence values are ordered, and that the response IDs per sequence are ordered
+        class SequenceResult:
+            def __init__(self, seq_id, result, request_id):
+                self.seq_id = seq_id
+                self.result = result
+                self.request_id = int(request_id)
+
+        def full_callback(sequence_dict, sequence_list, result, error):
+            # We expect no model errors for this test
+            if error:
+                self.assertTrue(False, error)
+
+            # Gather all the necessary metadata for validation
+            request_id = int(result.get_response().id)
+            sequence_id = request_id_map[request_id]
+            # Overall list of results in the order received, regardless of sequence ID
+            sequence_list.append(SequenceResult(sequence_id, result, request_id))
+            # Ordered results organized by their seq IDs
+            sequence_dict[sequence_id].append(result)
+
+        # Store ordered list in which responses are received by client
+        sequence_list = []
+        # Store mapping of sequence ID to response results
+        sequence_dict = {}
+        # Store mapping of sequence ID to request IDs and vice versa
+        sequence_id_map = {}
+        request_id_map = {}
+
+        # Start stream
+        seq_callback = partial(full_callback, sequence_dict, sequence_list)
+        self.triton_client.start_stream(callback=seq_callback)
+
+        # Send N sequences concurrently
+        threads = []
+        num_sequences = 10
+        for i in range(num_sequences):
+            # Sequence IDs are 1-indexed
+            sequence_id = i + 1
+            # Add a result list and callback for each sequence
+            sequence_dict[sequence_id] = []
+            threads.append(
+                threading.Thread(
+                    target=self.send_sequence,
+                    args=(sequence_id, sequence_id_map, request_id_map),
+                )
+            )
+
+        # Start all sequence threads
+        for t in threads:
+            t.start()
+
+        # Wait for threads to return
+        for t in threads:
+            t.join()
+
+        # Block until all requests are completed
+        self.triton_client.stop_stream()
+
+        # Make sure some inferences occurred and metadata was collected
+        self.assertGreater(len(sequence_dict), 0)
+        self.assertGreater(len(sequence_list), 0)
+
+        # Validate model results are sorted per sequence ID (model specific logic)
+        print(f"=== {preserve_ordering=} {decoupled=} ===")
+        print("Outputs per Sequence:")
+        for seq_id, sequence in sequence_dict.items():
+            seq_outputs = [
+                result.as_numpy("OUTPUT0").flatten().tolist() for result in sequence
+            ]
+            print(f"{seq_id}: {seq_outputs}")
+            self.assertEqual(seq_outputs, sorted(seq_outputs))
+
+        # Validate that request/response IDs for each response in a sequence are sorted.
+        # This should be true regardless of preserve_ordering or not.
+        print("Request IDs per Sequence:")
+        for seq_id in sequence_id_map:
+            per_seq_request_ids = sequence_id_map[seq_id]
+            print(f"{seq_id}: {per_seq_request_ids}")
+            self.assertEqual(per_seq_request_ids, sorted(per_seq_request_ids))
+
+        # Validate results are sorted in request order if preserve_ordering is True
+        if preserve_ordering:
+            request_ids = [s.request_id for s in sequence_list]
+            print(f"Request IDs overall:\n{request_ids}")
+            sequence_ids = [s.seq_id for s in sequence_list]
+            print(f"Sequence IDs overall:\n{sequence_ids}")
+            self.assertEqual(request_ids, sorted(request_ids))
+
+        # Assert some dynamic batching of requests was done
+        stats = self.triton_client.get_inference_statistics(
+            model_name=self.model_name_, headers={}, as_json=True
+        )
+        model_stats = stats["model_stats"][0]
+        self.assertEqual(model_stats["name"], self.model_name_)
+        self.assertLess(
+            int(model_stats["execution_count"]), int(model_stats["inference_count"])
+        )
+
+    def test_sequence_with_preserve_ordering(self):
+        self.model_name_ = "seqpy_preserve_ordering_nondecoupled"
+        self._test_sequence_ordering(preserve_ordering=True, decoupled=False)
+
+    def test_sequence_without_preserve_ordering(self):
+        self.model_name_ = "seqpy_no_preserve_ordering_nondecoupled"
+        self._test_sequence_ordering(preserve_ordering=False, decoupled=False)
+
+    # FIXME [DLIS-5280]: This may fail for decoupled models if writes to GRPC
+    # stream are done out of order in server, so disable test for now.
+    # def test_sequence_with_preserve_ordering_decoupled(self):
+    #     self.model_name_ = "seqpy_preserve_ordering_decoupled"
+    #     self._test_sequence_ordering(preserve_ordering=True, decoupled=True)
+
+    # FIXME [DLIS-5280]
+    # def test_sequence_without_preserve_ordering_decoupled(self):
+    #     self.model_name_ = "seqpy_no_preserve_ordering_decoupled"
+    #     self._test_sequence_ordering(preserve_ordering=False, decoupled=True)
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/qa/L0_sequence_batcher/test.sh b/qa/L0_sequence_batcher/test.sh
index 3dabfaba7a..0889a602e1 100755
--- a/qa/L0_sequence_batcher/test.sh
+++ b/qa/L0_sequence_batcher/test.sh
@@ -90,12 +90,14 @@ TF_VERSION=${TF_VERSION:=2}
 # On windows the paths invoked by the script (running in WSL) must use
 # /mnt/c when needed but the paths on the tritonserver command-line
 # must be C:/ style.
+WINDOWS=0
 if [[ "$(< /proc/sys/kernel/osrelease)" == *microsoft* ]]; then
     MODELDIR=${MODELDIR:=C:/models}
     DATADIR=${DATADIR:="/mnt/c/data/inferenceserver/${REPO_VERSION}"}
     BACKEND_DIR=${BACKEND_DIR:=C:/tritonserver/backends}
     SERVER=${SERVER:=/mnt/c/tritonserver/bin/tritonserver.exe}
     export WSLENV=$WSLENV:TRITONSERVER_DELAY_SCHEDULER:TRITONSERVER_BACKLOG_DELAY_SCHEDULER
+    WINDOWS=1
 else
     MODELDIR=${MODELDIR:=`pwd`}
     DATADIR=${DATADIR:="/data/inferenceserver/${REPO_VERSION}"}
@@ -800,6 +802,85 @@ if [ "$TEST_SYSTEM_SHARED_MEMORY" -ne 1 ] && [ "$TEST_CUDA_SHARED_MEMORY" -ne 1
     set -e
 fi
 
+### Start Preserve Ordering Tests ###
+
+# Test not currently supported on windows due to use of python backend models
+if [ ${WINDOWS} -ne 1 ]; then
+    # Test preserve ordering true/false and decoupled/non-decoupled
+    TEST_CASE=SequenceBatcherPreserveOrderingTest
+    MODEL_PATH=preserve_ordering_models
+    BASE_MODEL="../python_models/sequence_py"
+    rm -rf ${MODEL_PATH}
+
+    # FIXME [DLIS-5280]: This may fail for decoupled models if writes to GRPC
+    # stream are done out of order in server, so decoupled tests are disabled.
+    MODES="decoupled nondecoupled"
+    for mode in $MODES; do
+        NO_PRESERVE="${MODEL_PATH}/seqpy_no_preserve_ordering_${mode}"
+        mkdir -p ${NO_PRESERVE}/1
+        cp ${BASE_MODEL}/config.pbtxt ${NO_PRESERVE}
+        cp ${BASE_MODEL}/model.py ${NO_PRESERVE}/1
+
+        PRESERVE="${MODEL_PATH}/seqpy_preserve_ordering_${mode}"
+        cp -r ${NO_PRESERVE} ${PRESERVE}
+        sed -i "s/preserve_ordering: False/preserve_ordering: True/" ${PRESERVE}/config.pbtxt
+
+        if [ ${mode} == "decoupled" ]; then
+            echo -e "\nmodel_transaction_policy { decoupled: true }" >> ${NO_PRESERVE}/config.pbtxt
+            echo -e "\nmodel_transaction_policy { decoupled: true }" >> ${PRESERVE}/config.pbtxt
+        fi
+    done
+
+    SERVER_ARGS="--model-repository=$MODELDIR/$MODEL_PATH ${SERVER_ARGS_EXTRA}"
+    SERVER_LOG="./$TEST_CASE.$MODEL_PATH.server.log"
+
+    if [ "$TEST_VALGRIND" -eq 1 ]; then
+        LEAKCHECK_LOG="./$TEST_CASE.$MODEL_PATH.valgrind.log"
+        LEAKCHECK_ARGS="$LEAKCHECK_ARGS_BASE --log-file=$LEAKCHECK_LOG"
+        run_server_leakcheck
+    else
+        run_server
+    fi
+
+    if [ "$SERVER_PID" == "0" ]; then
+        echo -e "\n***\n*** Failed to start $SERVER\n***"
+        cat $SERVER_LOG
+        exit 1
+    fi
+
+    echo "Test: $TEST_CASE, repository $MODEL_PATH" >>$CLIENT_LOG
+
+    set +e
+    python3 $BATCHER_TEST $TEST_CASE >>$CLIENT_LOG 2>&1
+    if [ $? -ne 0 ]; then
+        echo -e "\n***\n*** Test $TEST_CASE Failed\n***" >>$CLIENT_LOG
+        echo -e "\n***\n*** Test $TEST_CASE Failed\n***"
+        RET=1
+    else
+        # 2 for preserve_ordering = True/False
+        check_test_results $TEST_RESULT_FILE 2
+        if [ $? -ne 0 ]; then
+            cat $CLIENT_LOG
+            echo -e "\n***\n*** Test Result Verification Failed\n***"
+            RET=1
+        fi
+    fi
+    set -e
+
+    kill_server
+
+    set +e
+    if [ "$TEST_VALGRIND" -eq 1 ]; then
+        python3 ../common/check_valgrind_log.py -f $LEAKCHECK_LOG
+        if [ $? -ne 0 ]; then
+            RET=1
+        fi
+    fi
+    set -e
+fi
+
+### End Preserve Ordering Tests ###
+
 if [ $RET -eq 0 ]; then
     echo -e "\n***\n*** Test Passed\n***"
 else
diff --git a/qa/python_models/sequence_py/config.pbtxt b/qa/python_models/sequence_py/config.pbtxt
new file mode 100644
index 0000000000..b58796058d
--- /dev/null
+++ b/qa/python_models/sequence_py/config.pbtxt
@@ -0,0 +1,53 @@
+# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+backend: "python"
+max_batch_size: 4
+
+input [
+  {
+    name: "INPUT0"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
+
+  }
+]
+output [
+  {
+    name: "OUTPUT0"
+    data_type: TYPE_INT32
+    dims: [ 1 ]
+  }
+]
+
+sequence_batching {
+  oldest {
+    max_candidate_sequences: 4
+    max_queue_delay_microseconds: 1000000
+    preserve_ordering: False
+  }
+  max_sequence_idle_microseconds: 10000000
+}
diff --git a/qa/python_models/sequence_py/model.py b/qa/python_models/sequence_py/model.py
new file mode 100644
index 0000000000..b375af3e30
--- /dev/null
+++ b/qa/python_models/sequence_py/model.py
@@ -0,0 +1,93 @@
+# Copyright 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import json
+
+import numpy as np
+import triton_python_backend_utils as pb_utils
+
+
+class TritonPythonModel:
+    def initialize(self, args):
+        self.model_config = json.loads(args["model_config"])
+        self.sequences = {}
+        self.decoupled = self.model_config.get("model_transaction_policy", {}).get(
+            "decoupled"
+        )
+
+    def get_next_sequence_output_tensor(self, request):
+        sid = request.correlation_id()
+        flags = request.flags()
+        if flags == pb_utils.TRITONSERVER_REQUEST_FLAG_SEQUENCE_START:
+            if sid in self.sequences:
+                raise pb_utils.TritonModelException(
+                    "Can't start a new sequence with existing ID"
+                )
+            self.sequences[sid] = [1]
+        else:
+            if sid not in self.sequences:
+                raise pb_utils.TritonModelException(
+                    "Need START flag for a sequence ID that doesn't already exist."
+                )
+
+            last = self.sequences[sid][-1]
+            self.sequences[sid].append(last + 1)
+
+        output = self.sequences[sid][-1]
+        output = np.array([output])
+        out_tensor = pb_utils.Tensor("OUTPUT0", output.astype(np.int32))
+        return out_tensor
+
+    def execute(self, requests):
+        if self.decoupled:
+            return self.execute_decoupled(requests)
+        else:
+            return self.execute_non_decoupled(requests)
+
+    def execute_non_decoupled(self, requests):
+        responses = []
+        for request in requests:
+            output_tensor = self.get_next_sequence_output_tensor(request)
+            response = pb_utils.InferenceResponse([output_tensor])
+            responses.append(response)
+        return responses
+
+    def execute_decoupled(self, requests):
+        for request in requests:
+            sender = request.get_response_sender()
+            output_tensor = self.get_next_sequence_output_tensor(request)
+
+            # Send 3 responses per request
+            for _ in range(3):
+                response = pb_utils.InferenceResponse([output_tensor])
+                sender.send(response)
+
+            sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
+
+        return None
+
+    def finalize(self):
+        print(f"Cleaning up. Final sequences stored: {self.sequences}")
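
Note (illustrative, not part of the patch): the sketch below shows how the new sequence_py model can be exercised by hand over a gRPC stream, using only client calls that already appear in the test above. It assumes a tritonserver is already running on localhost:8001 with one of the model directories generated by test.sh (the model name below is the non-decoupled preserve_ordering variant created there).

import numpy as np
import tritonclient.grpc as grpcclient

received = []

def callback(result, error):
    # Record stream errors and the order in which response IDs arrive
    if error is not None:
        print("stream error:", error)
    else:
        received.append(result.get_response().id)

client = grpcclient.InferenceServerClient("localhost:8001")
inputs = [grpcclient.InferInput("INPUT0", [1, 1], "INT32")]
inputs[0].set_data_from_numpy(np.ones(shape=[1, 1], dtype=np.int32))

client.start_stream(callback=callback)
request_id = 0
for seq_id in (1, 2):
    # Each sequence is one start, one middle, and one end request
    for start, end in [(True, False), (False, False), (False, True)]:
        request_id += 1
        client.async_stream_infer(
            "seqpy_preserve_ordering_nondecoupled",
            inputs,
            sequence_id=seq_id,
            sequence_start=start,
            sequence_end=end,
            request_id=str(request_id),
        )
client.stop_stream()  # block until all in-flight requests have completed

# With preserve_ordering: True the IDs should print in send order; with False
# only the ordering within each sequence is guaranteed by the gRPC stream.
print(received)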