From bdf227c0bab1e9a86862f24b9cd235c986d3b97f Mon Sep 17 00:00:00 2001 From: GuanLuo <41310872+GuanLuo@users.noreply.github.com> Date: Fri, 6 Oct 2023 08:41:21 -0700 Subject: [PATCH 1/2] Add basic generate endpoints for LLM tasks (#6366) * PoC of parsing request prompt and converting to Triton infer request * Remove extra trace * Add generate endpoint * Enable streaming version * Fix bug * Fix up * Add basic testing. Cherry pick from #6369 * format * Address comment. Fix build * Minor cleanup * cleanup syntax * Wrap error in SSE format * Fix up * Restrict number of response on non-streaming generate * Address comment on implementation. * Re-enable trace on generate endpoint * Add more comprehensive llm endpoint tests (#6377) * Add security policy (#6376) * Start adding some more comprehensive tests * Fix test case * Add response error testing * Complete test placeholder * Address comment * Address comments * Fix code check --------- Co-authored-by: dyastremsky <58150256+dyastremsky@users.noreply.github.com> Co-authored-by: GuanLuo * Address comment * Address comment * Address comment * Fix typo --------- Co-authored-by: Ryan McCormick Co-authored-by: dyastremsky <58150256+dyastremsky@users.noreply.github.com> --- qa/L0_http/generate_endpoint_test.py | 361 +++++ .../generate_models/mock_llm/1/model.py | 104 ++ .../generate_models/mock_llm/config.pbtxt | 60 + qa/L0_http/test.sh | 46 +- src/common.h | 4 +- src/http_server.cc | 1375 +++++++++++++---- src/http_server.h | 155 +- src/test/CMakeLists.txt | 1 + 8 files changed, 1813 insertions(+), 293 deletions(-) create mode 100755 qa/L0_http/generate_endpoint_test.py create mode 100644 qa/L0_http/generate_models/mock_llm/1/model.py create mode 100644 qa/L0_http/generate_models/mock_llm/config.pbtxt diff --git a/qa/L0_http/generate_endpoint_test.py b/qa/L0_http/generate_endpoint_test.py new file mode 100755 index 0000000000..31c987804a --- /dev/null +++ b/qa/L0_http/generate_endpoint_test.py @@ -0,0 +1,361 @@ +#!/usr/bin/python3 +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
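+
+# The tests below exercise the new generate endpoints end to end. As a rough
+# sketch of the request flow they rely on (the "mock_llm" model and the
+# localhost:8000 address come from this test setup, not from the endpoint
+# definition itself):
+#
+#   import requests
+#
+#   # /generate returns a single JSON object built from the model outputs.
+#   r = requests.post(
+#       "http://localhost:8000/v2/models/mock_llm/generate",
+#       json={"PROMPT": "hello world", "STREAM": False},
+#   )
+#   print(r.json()["TEXT"])
+#
+#   # /generate_stream returns Server-Sent Events, one "data: {...}" line
+#   # per model response.
+#   r = requests.post(
+#       "http://localhost:8000/v2/models/mock_llm/generate_stream",
+#       json={"PROMPT": "hello world", "STREAM": True},
+#       headers={"Accept": "text/event-stream"},
+#       stream=True,
+#   )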
+ +import sys + +sys.path.append("../common") + +import json +import threading +import time +import unittest + +import requests +import sseclient +import test_util as tu + + +class GenerateEndpointTest(tu.TestResultCollector): + def setUp(self): + self._model_name = "mock_llm" + + def _get_infer_url(self, model_name, route): + return f"http://localhost:8000/v2/models/{model_name}/{route}" + + def generate_stream(self, model_name, inputs, stream=False): + headers = {"Accept": "text/event-stream"} + url = self._get_infer_url(model_name, "generate_stream") + # stream=True used to indicate response can be iterated over, which + # should be the common setting for generate_stream. + # For correctness test case, stream=False so that we can re-examine + # the response content. + return requests.post( + url, + data=inputs if isinstance(inputs, str) else json.dumps(inputs), + headers=headers, + stream=stream, + ) + + def generate(self, model_name, inputs): + url = self._get_infer_url(model_name, "generate") + return requests.post( + url, data=inputs if isinstance(inputs, str) else json.dumps(inputs) + ) + + def generate_expect_failure(self, model_name, inputs, msg): + url = self._get_infer_url(model_name, "generate") + r = requests.post( + url, data=inputs if isinstance(inputs, str) else json.dumps(inputs) + ) + try: + r.raise_for_status() + self.assertTrue(False, f"Expected failure, success for {inputs}") + except requests.exceptions.HTTPError as e: + self.assertIn(msg, r.json()["error"]) + + def generate_stream_expect_failure(self, model_name, inputs, msg): + r = self.generate_stream(model_name, inputs) + try: + r.raise_for_status() + self.assertTrue(False, f"Expected failure, success for {inputs}") + except requests.exceptions.HTTPError as e: + self.assertIn(msg, r.json()["error"]) + + def generate_stream_expect_success( + self, model_name, inputs, expected_output, rep_count + ): + r = self.generate_stream(model_name, inputs) + r.raise_for_status() + self.check_sse_responses(r, [{"TEXT": expected_output}] * rep_count) + + def check_sse_responses(self, res, expected_res): + # Validate SSE format + self.assertIn("Content-Type", res.headers) + self.assertIn("text/event-stream", res.headers["Content-Type"]) + + # SSE format (data: []) is hard to parse, use helper library for simplicity + client = sseclient.SSEClient(res) + res_count = 0 + for event in client.events(): + # Parse event data, join events into a single response + data = json.loads(event.data) + for key, value in expected_res[res_count].items(): + self.assertIn(key, data) + self.assertEqual(value, data[key]) + res_count += 1 + self.assertTrue(len(expected_res), res_count) + # Make sure there is no message in the wrong form + for remaining in client._read(): + self.assertTrue( + remaining.startswith(b"data:"), + f"SSE response not formed properly, got: {remaining}", + ) + self.assertTrue( + remaining.endswith(b"\n\n"), + f"SSE response not formed properly, got: {remaining}", + ) + + def test_generate(self): + # Setup text-based input + text = "hello world" + inputs = {"PROMPT": text, "STREAM": False} + + r = self.generate(self._model_name, inputs) + r.raise_for_status() + + self.assertIn("Content-Type", r.headers) + self.assertIn("application/json", r.headers["Content-Type"]) + + data = r.json() + self.assertIn("TEXT", data) + self.assertEqual(text, data["TEXT"]) + + def test_generate_stream(self): + # Setup text-based input + text = "hello world" + rep_count = 3 + inputs = {"PROMPT": [text], "STREAM": True, "REPETITION": rep_count} + 
self.generate_stream_expect_success(self._model_name, inputs, text, rep_count) + + def test_streaming(self): + # verify the responses are streamed as soon as it is generated + text = "hello world" + rep_count = 3 + inputs = {"PROMPT": [text], "STREAM": True, "REPETITION": rep_count, "DELAY": 2} + past = time.time() + res = self.generate_stream(self._model_name, inputs, stream=True) + client = sseclient.SSEClient(res) + # This test does not focus on event content + for _ in client.events(): + now = time.time() + self.assertTrue(1 < (now - past) < 3) + past = now + + def test_missing_inputs(self): + missing_all_inputs = [ + # Missing all inputs + {}, + {"abc": 123}, + ] + missing_one_input = [ + # Missing 1 input + {"PROMPT": "hello"}, + {"STREAM": False}, + {"STREAM": False, "other": "param"}, + ] + for inputs in missing_all_inputs: + self.generate_expect_failure( + self._model_name, inputs, "expected 2 inputs but got 0" + ) + self.generate_stream_expect_failure( + self._model_name, inputs, "expected 2 inputs but got 0" + ) + + for inputs in missing_one_input: + self.generate_expect_failure( + self._model_name, inputs, "expected 2 inputs but got 1" + ) + self.generate_stream_expect_failure( + self._model_name, inputs, "expected 2 inputs but got 1" + ) + + def test_invalid_input_types(self): + invalid_bool = "attempt to access JSON non-boolean as boolean" + invalid_string = "attempt to access JSON non-string as string" + invalid_type_inputs = [ + # Prompt bad type + ({"PROMPT": 123, "STREAM": False}, invalid_string), + # Stream bad type + ({"PROMPT": "hello", "STREAM": "false"}, invalid_bool), + # Both bad type, parsed in order + ({"PROMPT": True, "STREAM": 123}, invalid_string), + ({"STREAM": 123, "PROMPT": True}, invalid_bool), + ] + + for inputs, error_msg in invalid_type_inputs: + self.generate_expect_failure(self._model_name, inputs, error_msg) + self.generate_stream_expect_failure(self._model_name, inputs, error_msg) + + def test_duplicate_inputs(self): + dupe_prompt = "input 'PROMPT' already exists in request" + dupe_stream = "input 'STREAM' already exists in request" + # Use JSON string directly as Python Dict doesn't support duplicate keys + invalid_type_inputs = [ + # One duplicate + ( + '{"PROMPT": "hello", "STREAM": false, "PROMPT": "duplicate"}', + dupe_prompt, + ), + ('{"PROMPT": "hello", "STREAM": false, "STREAM": false}', dupe_stream), + # Multiple duplicates, parsed in order + ( + '{"PROMPT": "hello", "STREAM": false, "PROMPT": "duplicate", "STREAM": true}', + dupe_prompt, + ), + ( + '{"PROMPT": "hello", "STREAM": false, "STREAM": true, "PROMPT": "duplicate"}', + dupe_stream, + ), + ] + for inputs, error_msg in invalid_type_inputs: + self.generate_expect_failure(self._model_name, inputs, error_msg) + self.generate_stream_expect_failure(self._model_name, inputs, error_msg) + + def test_generate_stream_response_error(self): + # Setup text-based input + text = "hello world" + inputs = {"PROMPT": [text], "STREAM": True, "REPETITION": 0, "FAIL_LAST": True} + r = self.generate_stream(self._model_name, inputs) + + # With "REPETITION": 0, error will be first response and the HTTP code + # will be set properly + try: + r.raise_for_status() + except requests.exceptions.HTTPError as e: + self.check_sse_responses(r, [{"error": "An Error Occurred"}]) + + # With "REPETITION" > 0, the first response is valid response and set + # HTTP code to success, so user must validate each response + inputs["REPETITION"] = 1 + r = self.generate_stream(self._model_name, inputs) + r.raise_for_status() + 
+ self.check_sse_responses(r, [{"TEXT": text}, {"error": "An Error Occurred"}]) + + def test_race_condition(self): + # In Triton HTTP frontend, the HTTP response is sent in a different + # thread than Triton response complete thread, both programs have shared + # access to the same object, so this test is sending sufficient load to + # the endpoint, in attempt to expose race condition if any . + input1 = {"PROMPT": "hello", "STREAM": False, "param": "segfault"} + input2 = { + "PROMPT": "hello", + "STREAM": True, + "REPETITION": 3, + "param": "segfault", + } + threads = [] + + def thread_func(model_name, inputs): + self.generate_stream(model_name, inputs).raise_for_status() + + for _ in range(50): + threads.append( + threading.Thread(target=thread_func, args=((self._model_name, input1))) + ) + threads.append( + threading.Thread(target=thread_func, args=((self._model_name, input2))) + ) + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + def test_one_response(self): + # In the current 'inputs' setting, the model will send at least 1 + # response, "STREAM" controls model behavior on sending model responses: + # If True, the model sends two responses, one is the actual infer + # response and the other contains flag only to signal end of response. + # 'generate_stream' endpoint is designed for this case so it should send + # infer response and complete HTTP response appropriately. And + # 'generate' endpoint will be able to handle this case as at its core + # only one infer response is received, which is the same as typical HTTP + # usage. + # If False, the model sends one response containing infer response and + # end flag, which is the same as how non-decoupled model responds. + inputs = {"PROMPT": "hello world", "STREAM": True} + r = self.generate_stream(self._model_name, inputs) + r.raise_for_status() + r = self.generate(self._model_name, inputs) + r.raise_for_status() + + inputs["STREAM"] = False + r = self.generate_stream(self._model_name, inputs) + r.raise_for_status() + r = self.generate(self._model_name, inputs) + r.raise_for_status() + + def test_zero_response(self): + inputs = {"PROMPT": "hello world", "STREAM": True, "REPETITION": 0} + r = self.generate_stream(self._model_name, inputs) + r.raise_for_status() + # Expect generate fails the inference + r = self.generate(self._model_name, inputs) + try: + r.raise_for_status() + except requests.exceptions.HTTPError as e: + self.assertIn( + "generate expects model to produce exactly 1 response", + r.json()["error"], + ) + + def test_many_response(self): + inputs = {"PROMPT": "hello world", "STREAM": True, "REPETITION": 2} + r = self.generate_stream(self._model_name, inputs) + r.raise_for_status() + # Expect generate fails the inference + r = self.generate(self._model_name, inputs) + try: + r.raise_for_status() + except requests.exceptions.HTTPError as e: + self.assertIn( + "generate expects model to produce exactly 1 response", + r.json()["error"], + ) + + def test_complex_schema(self): + # Currently only the fundamental conversion is supported, nested object + # in the request will results in parsing error + + # complex object to parameters (specifying non model input) + inputs = { + "PROMPT": "hello world", + "STREAM": True, + "PARAMS": {"PARAM_0": 0, "PARAM_1": True}, + } + r = self.generate(self._model_name, inputs) + try: + r.raise_for_status() + except requests.exceptions.HTTPError as e: + self.assertIn("parameter 'PARAMS' has invalid type", r.json()["error"]) + + # complex object to model input + 
inputs = { + "PROMPT": {"USER": "hello world", "BOT": "world hello"}, + "STREAM": True, + } + r = self.generate(self._model_name, inputs) + try: + r.raise_for_status() + except requests.exceptions.HTTPError as e: + self.assertIn( + "attempt to access JSON non-string as string", r.json()["error"] + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/qa/L0_http/generate_models/mock_llm/1/model.py b/qa/L0_http/generate_models/mock_llm/1/model.py new file mode 100644 index 0000000000..987d22d99b --- /dev/null +++ b/qa/L0_http/generate_models/mock_llm/1/model.py @@ -0,0 +1,104 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
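+
+# Mock "LLM" used by generate_endpoint_test.py. It echoes the PROMPT input
+# back as the TEXT output. In decoupled mode (the configuration used by the
+# tests), optional request parameters control the response pattern:
+#   REPETITION - number of responses sent when STREAM is true (default 1)
+#   DELAY      - seconds to sleep before sending each response
+#   FAIL_LAST  - if true, the final (flags-only) response carries an error
+# When STREAM is false, exactly one response is sent. In the non-decoupled
+# path, STREAM=true is rejected with an error and the PROMPT is instead
+# repeated REPETITION times along the last dimension.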
+ +import json +import time + +import numpy as np +import triton_python_backend_utils as pb_utils + + +class TritonPythonModel: + def initialize(self, args): + self.model_config = json.loads(args["model_config"]) + self.decoupled = self.model_config.get("model_transaction_policy", {}).get( + "decoupled" + ) + + def execute(self, requests): + if self.decoupled: + return self.exec_decoupled(requests) + else: + return self.exec(requests) + + def exec(self, requests): + responses = [] + for request in requests: + params = json.loads(request.parameters()) + rep_count = params["REPETITION"] if "REPETITION" in params else 1 + + input_np = pb_utils.get_input_tensor_by_name(request, "PROMPT").as_numpy() + stream_np = pb_utils.get_input_tensor_by_name(request, "STREAM").as_numpy() + stream = stream_np.flatten()[0] + if stream: + responses.append( + pb_utils.InferenceResponse( + error=pb_utils.TritonError( + "STREAM only supported in decoupled mode" + ) + ) + ) + else: + out_tensor = pb_utils.Tensor( + "TEXT", np.repeat(input_np, rep_count, axis=1) + ) + responses.append(pb_utils.InferenceResponse([out_tensor])) + return responses + + def exec_decoupled(self, requests): + for request in requests: + params = json.loads(request.parameters()) + rep_count = params["REPETITION"] if "REPETITION" in params else 1 + fail_last = params["FAIL_LAST"] if "FAIL_LAST" in params else False + delay = params["DELAY"] if "DELAY" in params else None + + sender = request.get_response_sender() + input_np = pb_utils.get_input_tensor_by_name(request, "PROMPT").as_numpy() + stream_np = pb_utils.get_input_tensor_by_name(request, "STREAM").as_numpy() + out_tensor = pb_utils.Tensor("TEXT", input_np) + response = pb_utils.InferenceResponse([out_tensor]) + # If stream enabled, just send multiple copies of response + # FIXME: Could split up response string into tokens, but this is simpler for now. + stream = stream_np.flatten()[0] + if stream: + for _ in range(rep_count): + if delay is not None: + time.sleep(delay) + sender.send(response) + sender.send( + None + if not fail_last + else pb_utils.InferenceResponse( + error=pb_utils.TritonError("An Error Occurred") + ), + flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL, + ) + # If stream disabled, just send one response + else: + sender.send( + response, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL + ) + return None diff --git a/qa/L0_http/generate_models/mock_llm/config.pbtxt b/qa/L0_http/generate_models/mock_llm/config.pbtxt new file mode 100644 index 0000000000..6871661525 --- /dev/null +++ b/qa/L0_http/generate_models/mock_llm/config.pbtxt @@ -0,0 +1,60 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +backend: "python" + +max_batch_size: 0 + +model_transaction_policy { + decoupled: True +} + +input [ + { + name: "PROMPT" + data_type: TYPE_STRING + dims: [ 1, 1 ] + }, + { + name: "STREAM" + data_type: TYPE_BOOL + dims: [ 1, 1 ] + } +] + +output [ + { + name: "TEXT" + data_type: TYPE_STRING + dims: [ 1, -1 ] + } +] + +instance_group [ + { + count: 1 + kind: KIND_MODEL + } +] diff --git a/qa/L0_http/test.sh b/qa/L0_http/test.sh index c08a5fba74..23d13e0b29 100755 --- a/qa/L0_http/test.sh +++ b/qa/L0_http/test.sh @@ -243,7 +243,7 @@ if [ $? -ne 0 ]; then RET=1 fi -python3 $CLIENT_PLUGIN_TEST >> ${CLIENT_LOG}.python.plugin 2>&1 +python $CLIENT_PLUGIN_TEST >> ${CLIENT_LOG}.python.plugin 2>&1 if [ $? -ne 0 ]; then cat ${CLIENT_LOG}.python.plugin RET=1 @@ -254,7 +254,7 @@ echo -n 'username:' > pswd echo "password" | openssl passwd -stdin -apr1 >> pswd nginx -c `pwd`/$NGINX_CONF -python3 $BASIC_AUTH_TEST +python $BASIC_AUTH_TEST if [ $? -ne 0 ]; then cat ${CLIENT_LOG}.python.plugin.auth RET=1 @@ -612,7 +612,7 @@ TEST_RESULT_FILE='test_results.txt' PYTHON_TEST=http_test.py EXPECTED_NUM_TESTS=8 set +e -python3 $PYTHON_TEST >$CLIENT_LOG 2>&1 +python $PYTHON_TEST >$CLIENT_LOG 2>&1 if [ $? -ne 0 ]; then cat $CLIENT_LOG RET=1 @@ -629,6 +629,46 @@ set -e kill $SERVER_PID wait $SERVER_PID +### LLM / Generate REST API Endpoint Tests ### + +# Helper library to parse SSE events +# https://github.com/mpetazzoni/sseclient +pip install sseclient-py + +SERVER_ARGS="--model-repository=`pwd`/generate_models" +SERVER_LOG="./inference_server_generate_endpoint_test.log" +CLIENT_LOG="./generate_endpoint_test.log" +run_server +if [ "$SERVER_PID" == "0" ]; then + echo -e "\n***\n*** Failed to start $SERVER\n***" + cat $SERVER_LOG + exit 1 +fi + +## Python Unit Tests +TEST_RESULT_FILE='test_results.txt' +PYTHON_TEST=generate_endpoint_test.py +EXPECTED_NUM_TESTS=12 +set +e +python $PYTHON_TEST >$CLIENT_LOG 2>&1 +if [ $? -ne 0 ]; then + cat $CLIENT_LOG + RET=1 +else + check_test_results $TEST_RESULT_FILE $EXPECTED_NUM_TESTS + if [ $? 
-ne 0 ]; then + cat $CLIENT_LOG + echo -e "\n***\n*** Test Result Verification Failed\n***" + RET=1 + fi +fi +set -e + +kill $SERVER_PID +wait $SERVER_PID + +### + if [ $RET -eq 0 ]; then echo -e "\n***\n*** Test Passed\n***" else diff --git a/src/common.h b/src/common.h index fc9857fed9..c11254a6cc 100644 --- a/src/common.h +++ b/src/common.h @@ -63,10 +63,12 @@ const std::vector TRITON_RESERVED_REQUEST_PARAMS{ do { \ TRITONSERVER_Error* err__ = (X); \ if (err__ != nullptr) { \ - return TRITONSERVER_ErrorNew( \ + auto new_err = TRITONSERVER_ErrorNew( \ TRITONSERVER_ErrorCode(err__), \ (std::string(MSG) + ": " + TRITONSERVER_ErrorMessage(err__)) \ .c_str()); \ + TRITONSERVER_ErrorDelete(err__); \ + return new_err; \ } \ } while (false) diff --git a/src/http_server.cc b/src/http_server.cc index bdc4912981..f9d98ade00 100644 --- a/src/http_server.cc +++ b/src/http_server.cc @@ -108,6 +108,46 @@ AddContentTypeHeader(evhtp_request_t* req, const char* type) req->headers_out, evhtp_header_new(kContentTypeHeader, type, 1, 1)); } +TRITONSERVER_Error* +SetTritonParameterFromJsonParameter( + const std::string& parameter, + triton::common::TritonJson::Value& params_json, + TRITONSERVER_InferenceRequest* irequest) +{ + triton::common::TritonJson::Value value; + if (!params_json.Find(parameter.c_str(), &value)) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + ("parameter key '" + parameter + "' was not found in the JSON") + .c_str()); + } + + if (value.IsString()) { + std::string string_value; + RETURN_IF_ERR(value.AsString(&string_value)); + RETURN_IF_ERR(TRITONSERVER_InferenceRequestSetStringParameter( + irequest, parameter.c_str(), string_value.c_str())); + } else if (value.IsInt()) { + int64_t int_value; + RETURN_IF_ERR(value.AsInt(&int_value)); + RETURN_IF_ERR(TRITONSERVER_InferenceRequestSetIntParameter( + irequest, parameter.c_str(), int_value)); + } else if (value.IsBool()) { + bool bool_value; + RETURN_IF_ERR(value.AsBool(&bool_value)); + RETURN_IF_ERR(TRITONSERVER_InferenceRequestSetBoolParameter( + irequest, parameter.c_str(), bool_value)); + } else { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + ("parameter '" + parameter + + "' has invalid type. It should be either " + "'int', 'bool', or 'string'.") + .c_str()); + } + return nullptr; // success +} + } // namespace TRITONSERVER_Error* @@ -296,28 +336,23 @@ JsonBytesArrayByteSize( triton::common::TritonJson::Value& tensor_data, size_t* byte_size) { *byte_size = 0; - - for (size_t i = 0; i < tensor_data.ArraySize(); i++) { - triton::common::TritonJson::Value el; - RETURN_IF_ERR(tensor_data.At(i, &el)); - - // Recurse if not last dimension... - TRITONSERVER_Error* assert_err = - el.AssertType(triton::common::TritonJson::ValueType::ARRAY); - if (assert_err == nullptr) { + // Recurse if not last dimension... + if (tensor_data.IsArray()) { + for (size_t i = 0; i < tensor_data.ArraySize(); i++) { + triton::common::TritonJson::Value el; + RETURN_IF_ERR(tensor_data.At(i, &el)); size_t byte_size_; RETURN_IF_ERR(JsonBytesArrayByteSize(el, &byte_size_)); *byte_size += byte_size_; - } else { - // Serialized data size is the length of the string itself plus - // 4 bytes to record the string length. - const char* str; - size_t len = 0; - RETURN_MSG_IF_ERR( - el.AsString(&str, &len), "Unable to parse JSON bytes array"); - *byte_size += len + sizeof(uint32_t); } - TRITONSERVER_ErrorDelete(assert_err); + } else { + // Serialized data size is the length of the string itself plus + // 4 bytes to record the string length. 
+ const char* str; + size_t len = 0; + RETURN_MSG_IF_ERR( + tensor_data.AsString(&str, &len), "Unable to parse JSON bytes array"); + *byte_size += len + sizeof(uint32_t); } return nullptr; // success @@ -329,141 +364,140 @@ ReadDataFromJsonHelper( triton::common::TritonJson::Value& tensor_data, int* counter, int64_t expected_cnt) { - // FIXME should invert loop and switch so don't have to do a switch - // each iteration. - for (size_t i = 0; i < tensor_data.ArraySize(); i++) { - triton::common::TritonJson::Value el; - RETURN_IF_ERR(tensor_data.At(i, &el)); - - // Recurse if not last dimension... - TRITONSERVER_Error* assert_err = - el.AssertType(triton::common::TritonJson::ValueType::ARRAY); - if (assert_err == nullptr) { + // FIXME should move 'switch' statement outside the recursive function and + // pass in a read data callback once data type is confirmed. + // Currently 'switch' is performed on each element even through all elements + // have the same data type. + + // Recurse on array element if not last dimension... + if (tensor_data.IsArray()) { + for (size_t i = 0; i < tensor_data.ArraySize(); i++) { + triton::common::TritonJson::Value el; + RETURN_IF_ERR(tensor_data.At(i, &el)); RETURN_IF_ERR( ReadDataFromJsonHelper(base, dtype, el, counter, expected_cnt)); - } else { - // Check if writing to 'serialized' is overrunning the expected byte_size - if (*counter >= expected_cnt) { - return TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INTERNAL, - "Shape does not match true shape of 'data' field"); - } - switch (dtype) { - case TRITONSERVER_TYPE_BOOL: { - bool b = false; - RETURN_IF_ERR(el.AsBool(&b)); - uint8_t* data_vec = reinterpret_cast(base); - // FIXME for unsigned should bounds check and raise error - // since otherwise the actually used value will be - // unexpected. - data_vec[*counter] = (uint8_t)(b ? 1 : 0); - *counter += 1; - break; - } - case TRITONSERVER_TYPE_UINT8: { - uint64_t ui = 0; - RETURN_IF_ERR(el.AsUInt(&ui)); - uint8_t* data_vec = reinterpret_cast(base); - data_vec[*counter] = (uint8_t)ui; - *counter += 1; - break; - } - case TRITONSERVER_TYPE_UINT16: { - uint64_t ui = 0; - RETURN_IF_ERR(el.AsUInt(&ui)); - uint16_t* data_vec = reinterpret_cast(base); - data_vec[*counter] = (uint16_t)ui; - *counter += 1; - break; - } - case TRITONSERVER_TYPE_UINT32: { - uint64_t ui = 0; - RETURN_IF_ERR(el.AsUInt(&ui)); - uint32_t* data_vec = reinterpret_cast(base); - data_vec[*counter] = (uint32_t)ui; - *counter += 1; - break; - } - case TRITONSERVER_TYPE_UINT64: { - uint64_t ui = 0; - RETURN_IF_ERR(el.AsUInt(&ui)); - uint64_t* data_vec = reinterpret_cast(base); - data_vec[*counter] = ui; - *counter += 1; - break; - } - case TRITONSERVER_TYPE_INT8: { - // FIXME for signed type just assigning to smaller type is - // "implementation defined" and so really need to bounds - // check. 
- int64_t si = 0; - RETURN_IF_ERR(el.AsInt(&si)); - int8_t* data_vec = reinterpret_cast(base); - data_vec[*counter] = (int8_t)si; - *counter += 1; - break; - } - case TRITONSERVER_TYPE_INT16: { - int64_t si = 0; - RETURN_IF_ERR(el.AsInt(&si)); - int16_t* data_vec = reinterpret_cast(base); - data_vec[*counter] = (int16_t)si; - *counter += 1; - break; - } - case TRITONSERVER_TYPE_INT32: { - int64_t si = 0; - RETURN_IF_ERR(el.AsInt(&si)); - int32_t* data_vec = reinterpret_cast(base); - data_vec[*counter] = (int32_t)si; - *counter += 1; - break; - } - case TRITONSERVER_TYPE_INT64: { - int64_t si = 0; - RETURN_IF_ERR(el.AsInt(&si)); - int64_t* data_vec = reinterpret_cast(base); - data_vec[*counter] = si; - *counter += 1; - break; - } - case TRITONSERVER_TYPE_FP32: { - double fp64 = 0; - RETURN_IF_ERR(el.AsDouble(&fp64)); - float* data_vec = reinterpret_cast(base); - data_vec[*counter] = fp64; - *counter += 1; - break; - } - case TRITONSERVER_TYPE_FP64: { - double fp64 = 0; - RETURN_IF_ERR(el.AsDouble(&fp64)); - double* data_vec = reinterpret_cast(base); - data_vec[*counter] = fp64; - *counter += 1; - break; - } - case TRITONSERVER_TYPE_BYTES: { - const char* cstr; - size_t len = 0; - RETURN_IF_ERR(el.AsString(&cstr, &len)); - if (static_cast(*counter + len + sizeof(uint32_t)) > - expected_cnt) { - return TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INTERNAL, - "Shape does not match true shape of 'data' field"); - } - memcpy( - base + *counter, reinterpret_cast(&len), sizeof(uint32_t)); - std::copy(cstr, cstr + len, base + *counter + sizeof(uint32_t)); - *counter += len + sizeof(uint32_t); - break; + } + } else { + // Check if writing to 'serialized' is overrunning the expected byte_size + if (*counter >= expected_cnt) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + "Shape does not match true shape of 'data' field"); + } + switch (dtype) { + case TRITONSERVER_TYPE_BOOL: { + bool b = false; + RETURN_IF_ERR(tensor_data.AsBool(&b)); + uint8_t* data_vec = reinterpret_cast(base); + // FIXME for unsigned should bounds check and raise error + // since otherwise the actually used value will be + // unexpected. + data_vec[*counter] = (uint8_t)(b ? 1 : 0); + *counter += 1; + break; + } + case TRITONSERVER_TYPE_UINT8: { + uint64_t ui = 0; + RETURN_IF_ERR(tensor_data.AsUInt(&ui)); + uint8_t* data_vec = reinterpret_cast(base); + data_vec[*counter] = (uint8_t)ui; + *counter += 1; + break; + } + case TRITONSERVER_TYPE_UINT16: { + uint64_t ui = 0; + RETURN_IF_ERR(tensor_data.AsUInt(&ui)); + uint16_t* data_vec = reinterpret_cast(base); + data_vec[*counter] = (uint16_t)ui; + *counter += 1; + break; + } + case TRITONSERVER_TYPE_UINT32: { + uint64_t ui = 0; + RETURN_IF_ERR(tensor_data.AsUInt(&ui)); + uint32_t* data_vec = reinterpret_cast(base); + data_vec[*counter] = (uint32_t)ui; + *counter += 1; + break; + } + case TRITONSERVER_TYPE_UINT64: { + uint64_t ui = 0; + RETURN_IF_ERR(tensor_data.AsUInt(&ui)); + uint64_t* data_vec = reinterpret_cast(base); + data_vec[*counter] = ui; + *counter += 1; + break; + } + case TRITONSERVER_TYPE_INT8: { + // FIXME for signed type just assigning to smaller type is + // "implementation defined" and so really need to bounds + // check. 
+ int64_t si = 0; + RETURN_IF_ERR(tensor_data.AsInt(&si)); + int8_t* data_vec = reinterpret_cast(base); + data_vec[*counter] = (int8_t)si; + *counter += 1; + break; + } + case TRITONSERVER_TYPE_INT16: { + int64_t si = 0; + RETURN_IF_ERR(tensor_data.AsInt(&si)); + int16_t* data_vec = reinterpret_cast(base); + data_vec[*counter] = (int16_t)si; + *counter += 1; + break; + } + case TRITONSERVER_TYPE_INT32: { + int64_t si = 0; + RETURN_IF_ERR(tensor_data.AsInt(&si)); + int32_t* data_vec = reinterpret_cast(base); + data_vec[*counter] = (int32_t)si; + *counter += 1; + break; + } + case TRITONSERVER_TYPE_INT64: { + int64_t si = 0; + RETURN_IF_ERR(tensor_data.AsInt(&si)); + int64_t* data_vec = reinterpret_cast(base); + data_vec[*counter] = si; + *counter += 1; + break; + } + case TRITONSERVER_TYPE_FP32: { + double fp64 = 0; + RETURN_IF_ERR(tensor_data.AsDouble(&fp64)); + float* data_vec = reinterpret_cast(base); + data_vec[*counter] = fp64; + *counter += 1; + break; + } + case TRITONSERVER_TYPE_FP64: { + double fp64 = 0; + RETURN_IF_ERR(tensor_data.AsDouble(&fp64)); + double* data_vec = reinterpret_cast(base); + data_vec[*counter] = fp64; + *counter += 1; + break; + } + case TRITONSERVER_TYPE_BYTES: { + const char* cstr; + size_t len = 0; + RETURN_IF_ERR(tensor_data.AsString(&cstr, &len)); + if (static_cast(*counter + len + sizeof(uint32_t)) > + expected_cnt) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + "Shape does not match true shape of 'data' field"); } - default: - break; + memcpy( + base + *counter, reinterpret_cast(&len), sizeof(uint32_t)); + std::copy(cstr, cstr + len, base + *counter + sizeof(uint32_t)); + *counter += len + sizeof(uint32_t); + break; } + default: + break; } - TRITONSERVER_ErrorDelete(assert_err); } return nullptr; // success @@ -1009,7 +1043,7 @@ HTTPAPIServer::HTTPAPIServer( server_(server), trace_manager_(trace_manager), shm_manager_(shm_manager), allocator_(nullptr), server_regex_(R"(/v2(?:/health/(live|ready))?)"), model_regex_( - R"(/v2/models/([^/]+)(?:/versions/([0-9]+))?(?:/(infer|ready|config|stats|trace/setting))?)"), + R"(/v2/models/([^/]+)(?:/versions/([0-9]+))?(?:/(infer|generate|generate_stream|ready|config|stats|trace/setting))?)"), modelcontrol_regex_( R"(/v2/repository(?:/([^/]+))?/(index|models/([^/]+)/(load|unload)))"), systemsharedmemory_regex_( @@ -1547,6 +1581,36 @@ HTTPAPIServer::HandleModelMetadata( RETURN_AND_RESPOND_IF_ERR(req, err); } +TRITONSERVER_Error* +HTTPAPIServer::GetModelConfig( + const std::string& model_name, int64_t requested_model_version, + std::string* config_json) +{ + if (model_name.empty()) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + "Missing model name in ModelConfig request"); + } + + TRITONSERVER_Message* message = nullptr; + RETURN_IF_ERR(TRITONSERVER_ServerModelConfig( + server_.get(), model_name.c_str(), requested_model_version, + 1 /* config_version */, &message)); + const char* buffer; + size_t byte_size; + TRITONSERVER_Error* err = nullptr; + err = TRITONSERVER_MessageSerializeToJson(message, &buffer, &byte_size); + if (err == nullptr) { + // Copy config into string for simplicity + *config_json = std::string(buffer, byte_size); + } + if (message) { + TRITONSERVER_MessageDelete(message); + } + + return err; +} + void HTTPAPIServer::HandleModelConfig( evhtp_request_t* req, const std::string& model_name, @@ -1558,33 +1622,18 @@ HTTPAPIServer::HandleModelConfig( req, EVHTP_RES_METHNALLOWED, "Method Not Allowed"); } - if (model_name.empty()) { - RETURN_AND_RESPOND_WITH_ERR( - 
req, EVHTP_RES_BADREQ, "Missing model name in ModelConfig request"); - } - - TRITONSERVER_Message* message = nullptr; - int64_t requested_model_version; - auto err = - GetModelVersionFromString(model_version_str, &requested_model_version); - if (err == nullptr) { - err = TRITONSERVER_ServerModelConfig( - server_.get(), model_name.c_str(), requested_model_version, - 1 /* config_version */, &message); - if (err == nullptr) { - const char* buffer; - size_t byte_size; - err = TRITONSERVER_MessageSerializeToJson(message, &buffer, &byte_size); - if (err == nullptr) { - evbuffer_add(req->buffer_out, buffer, byte_size); - evhtp_send_reply(req, EVHTP_RES_OK); - } - TRITONSERVER_MessageDelete(message); - } - } + RETURN_AND_RESPOND_IF_ERR( + req, + GetModelVersionFromString(model_version_str, &requested_model_version)); - RETURN_AND_RESPOND_IF_ERR(req, err); + std::string config_json_str = ""; + RETURN_AND_RESPOND_IF_ERR( + req, + GetModelConfig(model_name, requested_model_version, &config_json_str)); + evbuffer_add( + req->buffer_out, config_json_str.c_str(), config_json_str.size()); + evhtp_send_reply(req, EVHTP_RES_OK); } void @@ -2716,37 +2765,8 @@ HTTPAPIServer::ParseJsonTritonParams( "usage " "and should not be specified.")); } else { - triton::common::TritonJson::Value value; - if (!params_json.Find(parameter.c_str(), &value)) { - return TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INTERNAL, - ("parameter key '" + parameter + "' was not found in the JSON") - .c_str()); - } - - if (value.IsString()) { - std::string string_value; - RETURN_IF_ERR(value.AsString(&string_value)); - RETURN_IF_ERR(TRITONSERVER_InferenceRequestSetStringParameter( - irequest, parameter.c_str(), string_value.c_str())); - } else if (value.IsInt()) { - int64_t int_value; - RETURN_IF_ERR(value.AsInt(&int_value)); - RETURN_IF_ERR(TRITONSERVER_InferenceRequestSetIntParameter( - irequest, parameter.c_str(), int_value)); - } else if (value.IsBool()) { - bool bool_value; - RETURN_IF_ERR(value.AsBool(&bool_value)); - RETURN_IF_ERR(TRITONSERVER_InferenceRequestSetBoolParameter( - irequest, parameter.c_str(), bool_value)); - } else { - return TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INVALID_ARG, - ("parameter '" + parameter + - "' has invalid type. It should be either " - "'int', 'bool', or 'string'.") - .c_str()); - } + RETURN_IF_ERR(SetTritonParameterFromJsonParameter( + parameter, params_json, irequest)); } } @@ -2775,6 +2795,28 @@ HTTPAPIServer::ParseJsonTritonRequestID( return nullptr; // Success } +// TODO: Can refactor other non-inference routes to re-use this helper instead. 
+TRITONSERVER_Error* +HTTPAPIServer::EVRequestToJson( + evhtp_request_t* req, triton::common::TritonJson::Value* request_json_ptr) +{ + struct evbuffer_iovec* v = nullptr; + int v_idx = 0; + int n = evbuffer_peek(req->buffer_in, -1, NULL, NULL, 0); + if (n > 0) { + v = static_cast( + alloca(sizeof(struct evbuffer_iovec) * n)); + if (evbuffer_peek(req->buffer_in, -1, NULL, v, n) != n) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + "Unexpected error getting request buffers"); + } + } + size_t buffer_len = evbuffer_get_length(req->buffer_in); + RETURN_IF_ERR(EVBufferToJson(request_json_ptr, v, &v_idx, buffer_len, n)); + return nullptr; // success +} + TRITONSERVER_Error* HTTPAPIServer::EVBufferToInput( const std::string& model_name, TRITONSERVER_InferenceRequest* irequest, @@ -2977,7 +3019,7 @@ HTTPAPIServer::DecompressBuffer( } TRITONSERVER_Error* -HTTPAPIServer::EVBufferToTritonRequest( +HTTPAPIServer::EVRequestToTritonRequest( evhtp_request_t* req, const std::string& model_name, TRITONSERVER_InferenceRequest* irequest, evbuffer* decompressed_buffer, InferRequestClass* infer_req, size_t header_length) @@ -3015,10 +3057,11 @@ HTTPAPIServer::ForwardHeaders( } void -HTTPAPIServer::HandleInfer( +HTTPAPIServer::HandleGenerate( evhtp_request_t* req, const std::string& model_name, - const std::string& model_version_str) + const std::string& model_version_str, bool streaming) { + AddContentTypeHeader(req, "application/json"); if (req->method != htp_method_POST) { RETURN_AND_RESPOND_WITH_ERR( req, EVHTP_RES_METHNALLOWED, "Method Not Allowed"); @@ -3026,29 +3069,31 @@ HTTPAPIServer::HandleInfer( int64_t requested_model_version; RETURN_AND_RESPOND_IF_ERR( - req, GetModelVersionFromString( - model_version_str.c_str(), &requested_model_version)); - RETURN_AND_RESPOND_IF_ERR( - req, CheckTransactionPolicy(req, model_name, requested_model_version)); + req, + GetModelVersionFromString(model_version_str, &requested_model_version)); // If tracing is enabled see if this request should be traced. TRITONSERVER_InferenceTrace* triton_trace = nullptr; std::shared_ptr trace = StartTrace(req, model_name, &triton_trace); - // Decompress request body if it is compressed in supported type - evbuffer* decompressed_buffer = nullptr; - RETURN_AND_RESPOND_IF_ERR(req, DecompressBuffer(req, &decompressed_buffer)); - - // Get content length as a default header_length if no header specified - int32_t content_length = 0; + std::map input_metadata; + triton::common::TritonJson::Value meta_data_root; RETURN_AND_RESPOND_IF_ERR( - req, GetContentLength(req, decompressed_buffer, &content_length)); + req, ModelInputMetadata( + model_name, requested_model_version, &input_metadata, + &meta_data_root)); - // Get the header length - size_t header_length = 0; - RETURN_AND_RESPOND_IF_ERR( - req, GetInferenceHeaderLength(req, content_length, &header_length)); + + // [FIXME] decompression should have been done here. before parsing request + // body + if (GetRequestCompressionType(req) != DataCompressor::Type::IDENTITY) { + RETURN_AND_RESPOND_IF_ERR( + req, + TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + "Unsupported content-encoding, only 'identity' is supported.")); + } // Create the inference request object which provides all information needed // for an inference. Make sure it is cleaned up on early error. @@ -3060,70 +3105,96 @@ HTTPAPIServer::HandleInfer( // HTTP request paused when creating inference request. Resume it on exit if // this function returns early due to error. 
Otherwise resumed in callback. - bool connection_paused = true; - auto infer_request = CreateInferRequest(req); - infer_request->trace_ = trace; + std::unique_ptr generate_request; + if (streaming) { + generate_request.reset(new GenerateRequestClass( + server_.get(), req, GetResponseCompressionType(req), + generate_stream_request_schema_.get(), + generate_stream_response_schema_.get(), streaming, irequest)); + } else { + generate_request.reset(new GenerateRequestClass( + server_.get(), req, GetResponseCompressionType(req), + generate_request_schema_.get(), generate_response_schema_.get(), + streaming, irequest)); + } + generate_request->trace_ = trace; const char* request_id = ""; // Callback to cleanup on any errors encountered below. Capture everything // by reference to capture local updates, except for shared pointers which // should be captured by value in case of ref count issues. + // The callback does not own the error object. auto error_callback = [&, trace](TRITONSERVER_Error* error) { if (error != nullptr) { + // Get request ID for logging in case of error. + if (irequest != nullptr) { + LOG_TRITONSERVER_ERROR( + TRITONSERVER_InferenceRequestId(irequest, &request_id), + "unable to retrieve request ID string"); + } + if (!strncmp(request_id, "", 1)) { + request_id = ""; + } + LOG_VERBOSE(1) << "[request id: " << request_id << "] " << "Infer failed: " << TRITONSERVER_ErrorMessage(error); + LOG_TRITONSERVER_ERROR( + TRITONSERVER_InferenceRequestDelete(irequest), + "deleting HTTP/REST inference request"); AddContentTypeHeader(req, "application/json"); EVBufferAddErrorJson(req->buffer_out, error); evhtp_send_reply(req, EVHTP_RES_BADREQ); - if (connection_paused) { - evhtp_request_resume(req); - } + evhtp_request_resume(req); + #ifdef TRITON_ENABLE_TRACING // If HTTP server still owns Triton trace if ((trace != nullptr) && (trace->trace_ != nullptr)) { TraceManager::TraceRelease(trace->trace_, trace->trace_userp_); } #endif // TRITON_ENABLE_TRACING - - LOG_TRITONSERVER_ERROR( - TRITONSERVER_InferenceRequestDelete(irequest), - "deleting HTTP/REST inference request"); } }; - // Parse EV buffer and fill Triton request fields from it + // Option 1: Form tensor-like JSON request and try to re-use HandleInfer + // as much as possible. Probably need to do something like overwrite + // req->buffer_in or create a new evhtp_request to pass and handle. + // Option 2: Do inference logic directly here after parsing request. + // Note: + // Currently option 2 is selected. It is true that HandleInfer() includes + // handling for features that will be requested for generate endpoints + // (i.e. tracing), however, it is currently tied to infer endpoint logic and + // some decoupling must be done to properly reuse it (for example, response + // callback is tied to infer logic and inflexible for response streaming). + // For the time being, it is less mental burden to support this endpoint + // without early optimization for code reuse. + // Also, there is limitation on Triton JSON library that makes forming + // arbitrary JSON message convoluted (added key is reference to a string and + // thus the string must live as long as the JSON message). 
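+  //
+  // Roughly, the conversion below maps a flat JSON request such as
+  //   {"PROMPT": "hello", "STREAM": false, "temperature": 0.5}
+  // onto a Triton request: keys matching an input name in the model
+  // metadata ("PROMPT", "STREAM") become input tensors, and the remaining
+  // keys ("temperature", used here only as an illustration) become request
+  // parameters.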
+ triton::common::TritonJson::Value request; + RETURN_AND_CALLBACK_IF_ERR(EVRequestToJson(req, &request), error_callback); + RETURN_AND_CALLBACK_IF_ERR( - EVBufferToTritonRequest( - req, model_name, irequest, decompressed_buffer, infer_request.get(), - header_length), + generate_request->ConvertGenerateRequest( + input_metadata, generate_request->RequestSchema(), request), error_callback); - // Get request ID for logging in case of error. - LOG_TRITONSERVER_ERROR( - TRITONSERVER_InferenceRequestId(irequest, &request_id), - "unable to retrieve request ID string"); - // Reset id to unknown if empty in core. - if (!strncmp(request_id, "", 1)) { - request_id = ""; - } - - RETURN_AND_CALLBACK_IF_ERR(ForwardHeaders(req, irequest), error_callback); - + // [FIXME] decompression.. RETURN_AND_CALLBACK_IF_ERR( TRITONSERVER_InferenceRequestSetReleaseCallback( - irequest, InferRequestClass::InferRequestComplete, - decompressed_buffer), + irequest, InferRequestClass::InferRequestComplete, nullptr), error_callback); RETURN_AND_CALLBACK_IF_ERR( TRITONSERVER_InferenceRequestSetResponseCallback( irequest, allocator_, - reinterpret_cast(&infer_request->alloc_payload_), - InferRequestClass::InferResponseComplete, - reinterpret_cast(infer_request.get())), + reinterpret_cast(&generate_request->alloc_payload_), + GenerateRequestClass::InferResponseComplete, + reinterpret_cast(generate_request.get())), + error_callback); + + RETURN_AND_CALLBACK_IF_ERR( + TRITONSERVER_ServerInferAsync(server_.get(), irequest, triton_trace), error_callback); - auto err = - TRITONSERVER_ServerInferAsync(server_.get(), irequest, triton_trace); #ifdef TRITON_ENABLE_TRACING // Ownership of trace passed to Triton core, set trace to null to mark it // as no longer owned here. @@ -3131,47 +3202,354 @@ HTTPAPIServer::HandleInfer( trace->trace_ = nullptr; } #endif // TRITON_ENABLE_TRACING - - RETURN_AND_CALLBACK_IF_ERR(err, error_callback); - infer_request.release(); + generate_request.release(); } -void -HTTPAPIServer::OKReplyCallback(evthr_t* thr, void* arg, void* shared) +TRITONSERVER_Error* +HTTPAPIServer::ModelInputMetadata( + const std::string& model_name, const int64_t model_version, + std::map* input_metadata, + triton::common::TritonJson::Value* metadata_root) { - HTTPAPIServer::InferRequestClass* infer_request = - reinterpret_cast(arg); + { + if (model_name.empty()) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + "Missing model name in metadata request"); + } - evhtp_request_t* request = infer_request->EvHtpRequest(); - evhtp_send_reply(request, EVHTP_RES_OK); - evhtp_request_resume(request); + TRITONSERVER_Message* message = nullptr; + RETURN_IF_ERR(TRITONSERVER_ServerModelMetadata( + server_.get(), model_name.c_str(), model_version, &message)); + const char* buffer; + size_t byte_size; + TRITONSERVER_Error* err = nullptr; + err = TRITONSERVER_MessageSerializeToJson(message, &buffer, &byte_size); + if (err == nullptr) { + RETURN_IF_ERR(metadata_root->Parse(buffer, byte_size)); + } + if (message) { + TRITONSERVER_MessageDelete(message); + } + } -#ifdef TRITON_ENABLE_TRACING - if (infer_request->trace_ != nullptr) { - infer_request->trace_->CaptureTimestamp( - "HTTP_SEND_START", request->send_start_ns); - infer_request->trace_->CaptureTimestamp( - "HTTP_SEND_END", request->send_end_ns); + // input + triton::common::TritonJson::Value inputs; + RETURN_IF_ERR(metadata_root->MemberAsArray("inputs", &inputs)); + for (size_t i = 0; i < inputs.ArraySize(); ++i) { + triton::common::TritonJson::Value input; + 
RETURN_IF_ERR(inputs.At(i, &input)); + std::string name = ""; + RETURN_IF_ERR(input.MemberAsString("name", &name)); + (*input_metadata)[name] = std::move(input); } -#endif // TRITON_ENABLE_TRACING - delete infer_request; + return nullptr; // success } -void -HTTPAPIServer::BADReplyCallback(evthr_t* thr, void* arg, void* shared) +TRITONSERVER_Error* +HTTPAPIServer::GenerateRequestClass::ConvertGenerateRequest( + std::map& input_metadata, + const MappingSchema* schema, + triton::common::TritonJson::Value& generate_request) { - HTTPAPIServer::InferRequestClass* infer_request = - reinterpret_cast(arg); - - evhtp_request_t* request = infer_request->EvHtpRequest(); - evhtp_send_reply(request, EVHTP_RES_BADREQ); - evhtp_request_resume(request); - -#ifdef TRITON_ENABLE_TRACING - if (infer_request->trace_ != nullptr) { - infer_request->trace_->CaptureTimestamp( - "HTTP_SEND_START", request->send_start_ns); + // First find all top-level keys in JSON + std::vector members; + RETURN_IF_ERR(generate_request.Members(&members)); + + for (const auto& m : members) { + auto it = schema->children_.find(m); + if (it != schema->children_.end()) { + switch (it->second->kind_) { + case MappingSchema::Kind::EXACT_MAPPING: { + // Read meta data + RETURN_IF_ERR(ExactMappingInput(m, generate_request, input_metadata)); + break; + } + case MappingSchema::Kind::MAPPING_SCHEMA: { + // The key is nested schema + triton::common::TritonJson::Value nested_generate_request; + RETURN_IF_ERR(generate_request.MemberAsObject( + m.c_str(), &nested_generate_request)); + RETURN_IF_ERR(ConvertGenerateRequest( + input_metadata, it->second.get(), nested_generate_request)); + break; + } + default: + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, "Unsupported schema kind"); + } + } else if (schema->allow_unspecified_) { + // Unspecified key follows EXACT_MAPPING + RETURN_IF_ERR(ExactMappingInput(m, generate_request, input_metadata)); + } else { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, + "The schema disallow unspecified key"); + } + } + return nullptr; // success +} + +TRITONSERVER_Error* +HTTPAPIServer::GenerateRequestClass::ExactMappingInput( + const std::string& name, + triton::common::TritonJson::Value& generate_request, + std::map& input_metadata) +{ + auto it = input_metadata.find(name); + if (it == input_metadata.end()) { + RETURN_IF_ERR(SetTritonParameterFromJsonParameter( + name, generate_request, triton_request_)); + } else { + // Parse data type and shape + std::string value; + it->second.MemberAsString("datatype", &value); + auto dtype = TRITONSERVER_StringToDataType(value.c_str()); + + // Perform shape validation, assume the value must be either + // primitive type or 1-D array. + triton::common::TritonJson::Value tensor_data; + if (!generate_request.Find(name.c_str(), &tensor_data)) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string("unexpected key not found in generate request, " + "expecting key '") + + name + "'") + .c_str()); + } + + size_t element_cnt = tensor_data.IsArray() ? 
tensor_data.ArraySize() : 1; + + size_t byte_size = 0; + if (dtype == TRITONSERVER_TYPE_BYTES) { + RETURN_IF_ERR(JsonBytesArrayByteSize(tensor_data, &byte_size)); + } else { + byte_size = element_cnt * TRITONSERVER_DataTypeByteSize(dtype); + } + + std::vector shape_vec; + { + triton::common::TritonJson::Value value; + if (!it->second.Find("shape", &value)) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + (std::string( + "Unexpected 'shape' not found in model metadata for input '") + + name) + .c_str()); + } + for (size_t i = 0; i < value.ArraySize(); ++i) { + int64_t d = 0; + RETURN_IF_ERR(value.IndexAsInt(i, &d)); + shape_vec.push_back(d); + } + // Because generate request don't carry too much shape information, using + // a two-pass process to pad the request value to match input shape. + // 1. iterate shape for fixed dimension to distribute 'element_cnt'. + // 2. Set most inner dynamic shape to the remaining element count, + // other dynamic shape to be 1. + for (auto rit = shape_vec.rbegin(); rit != shape_vec.rend(); ++rit) { + if (*rit != -1) { + if (element_cnt % *rit) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string("The schema can not convert input '") + name + + "' to tensor with proper shape") + .c_str()); + } + element_cnt /= *rit; + } + } + for (auto rit = shape_vec.rbegin(); rit != shape_vec.rend(); ++rit) { + if (*rit == -1) { + *rit = element_cnt; + element_cnt = 1; + } + } + if (element_cnt != 1) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + (std::string("The schema can not convert input '") + name + + "' to tensor with proper shape") + .c_str()); + } + } + + serialized_data_.emplace_back(); + std::vector& serialized = serialized_data_.back(); + serialized.resize(byte_size); + RETURN_IF_ERR(ReadDataFromJson( + name.c_str(), tensor_data, &serialized[0], dtype, + dtype == TRITONSERVER_TYPE_BYTES ? byte_size : element_cnt)); + + RETURN_IF_ERR(TRITONSERVER_InferenceRequestAddInput( + triton_request_, name.c_str(), dtype, &shape_vec[0], shape_vec.size())); + RETURN_IF_ERR(TRITONSERVER_InferenceRequestAppendInputData( + triton_request_, name.c_str(), &serialized[0], serialized.size(), + TRITONSERVER_MEMORY_CPU, 0 /* memory_type_id */)); + } + return nullptr; // success +} + +void +HTTPAPIServer::HandleInfer( + evhtp_request_t* req, const std::string& model_name, + const std::string& model_version_str) +{ + if (req->method != htp_method_POST) { + RETURN_AND_RESPOND_WITH_ERR( + req, EVHTP_RES_METHNALLOWED, "Method Not Allowed"); + } + + int64_t requested_model_version; + RETURN_AND_RESPOND_IF_ERR( + req, GetModelVersionFromString( + model_version_str.c_str(), &requested_model_version)); + RETURN_AND_RESPOND_IF_ERR( + req, CheckTransactionPolicy(req, model_name, requested_model_version)); + + // If tracing is enabled see if this request should be traced. 
+ TRITONSERVER_InferenceTrace* triton_trace = nullptr; + std::shared_ptr trace = + StartTrace(req, model_name, &triton_trace); + + // Decompress request body if it is compressed in supported type + evbuffer* decompressed_buffer = nullptr; + RETURN_AND_RESPOND_IF_ERR(req, DecompressBuffer(req, &decompressed_buffer)); + + // Get content length as a default header_length if no header specified + int32_t content_length = 0; + RETURN_AND_RESPOND_IF_ERR( + req, GetContentLength(req, decompressed_buffer, &content_length)); + + // Get the header length + size_t header_length = 0; + RETURN_AND_RESPOND_IF_ERR( + req, GetInferenceHeaderLength(req, content_length, &header_length)); + + // Create the inference request object which provides all information needed + // for an inference. Make sure it is cleaned up on early error. + TRITONSERVER_InferenceRequest* irequest = nullptr; + RETURN_AND_RESPOND_IF_ERR( + req, TRITONSERVER_InferenceRequestNew( + &irequest, server_.get(), model_name.c_str(), + requested_model_version)); + + // HTTP request paused when creating inference request. Resume it on exit if + // this function returns early due to error. Otherwise resumed in callback. + bool connection_paused = true; + auto infer_request = CreateInferRequest(req); + infer_request->trace_ = trace; + + const char* request_id = ""; + // Callback to cleanup on any errors encountered below. Capture everything + // by reference to capture local updates, except for shared pointers which + // should be captured by value in case of ref count issues. + auto error_callback = [&, trace](TRITONSERVER_Error* error) { + if (error != nullptr) { + LOG_VERBOSE(1) << "[request id: " << request_id << "] " + << "Infer failed: " << TRITONSERVER_ErrorMessage(error); + AddContentTypeHeader(req, "application/json"); + EVBufferAddErrorJson(req->buffer_out, error); + evhtp_send_reply(req, EVHTP_RES_BADREQ); + if (connection_paused) { + evhtp_request_resume(req); + } +#ifdef TRITON_ENABLE_TRACING + // If HTTP server still owns Triton trace + if ((trace != nullptr) && (trace->trace_ != nullptr)) { + TraceManager::TraceRelease(trace->trace_, trace->trace_userp_); + } +#endif // TRITON_ENABLE_TRACING + + LOG_TRITONSERVER_ERROR( + TRITONSERVER_InferenceRequestDelete(irequest), + "deleting HTTP/REST inference request"); + } + }; + + // Parse EV request and fill Triton request fields from it + RETURN_AND_CALLBACK_IF_ERR( + EVRequestToTritonRequest( + req, model_name, irequest, decompressed_buffer, infer_request.get(), + header_length), + error_callback); + + // Get request ID for logging in case of error. + LOG_TRITONSERVER_ERROR( + TRITONSERVER_InferenceRequestId(irequest, &request_id), + "unable to retrieve request ID string"); + // Reset id to unknown if empty in core. 
+ if (!strncmp(request_id, "", 1)) { + request_id = ""; + } + + RETURN_AND_CALLBACK_IF_ERR(ForwardHeaders(req, irequest), error_callback); + + RETURN_AND_CALLBACK_IF_ERR( + TRITONSERVER_InferenceRequestSetReleaseCallback( + irequest, InferRequestClass::InferRequestComplete, + decompressed_buffer), + error_callback); + RETURN_AND_CALLBACK_IF_ERR( + TRITONSERVER_InferenceRequestSetResponseCallback( + irequest, allocator_, + reinterpret_cast(&infer_request->alloc_payload_), + InferRequestClass::InferResponseComplete, + reinterpret_cast(infer_request.get())), + error_callback); + + auto err = + TRITONSERVER_ServerInferAsync(server_.get(), irequest, triton_trace); +#ifdef TRITON_ENABLE_TRACING + // Ownership of trace passed to Triton core, set trace to null to mark it + // as no longer owned here. + if (trace != nullptr) { + trace->trace_ = nullptr; + } +#endif // TRITON_ENABLE_TRACING + + RETURN_AND_CALLBACK_IF_ERR(err, error_callback); + infer_request.release(); +} + +void +HTTPAPIServer::OKReplyCallback(evthr_t* thr, void* arg, void* shared) +{ + HTTPAPIServer::InferRequestClass* infer_request = + reinterpret_cast(arg); + + evhtp_request_t* request = infer_request->EvHtpRequest(); + evhtp_send_reply(request, EVHTP_RES_OK); + evhtp_request_resume(request); + +#ifdef TRITON_ENABLE_TRACING + if (infer_request->trace_ != nullptr) { + infer_request->trace_->CaptureTimestamp( + "HTTP_SEND_START", request->send_start_ns); + infer_request->trace_->CaptureTimestamp( + "HTTP_SEND_END", request->send_end_ns); + } +#endif // TRITON_ENABLE_TRACING + + delete infer_request; +} + +void +HTTPAPIServer::BADReplyCallback(evthr_t* thr, void* arg, void* shared) +{ + HTTPAPIServer::InferRequestClass* infer_request = + reinterpret_cast(arg); + + evhtp_request_t* request = infer_request->EvHtpRequest(); + evhtp_send_reply(request, EVHTP_RES_BADREQ); + evhtp_request_resume(request); + +#ifdef TRITON_ENABLE_TRACING + if (infer_request->trace_ != nullptr) { + infer_request->trace_->CaptureTimestamp( + "HTTP_SEND_START", request->send_start_ns); infer_request->trace_->CaptureTimestamp( "HTTP_SEND_END", request->send_end_ns); } @@ -3568,6 +3946,423 @@ HTTPAPIServer::InferRequestClass::IncrementResponseCount() return response_count_++; } +HTTPAPIServer::GenerateRequestClass::~GenerateRequestClass() +{ + while (!pending_http_responses_.empty()) { + evbuffer_free(pending_http_responses_.front()); + pending_http_responses_.pop(); + } +} + +void +HTTPAPIServer::GenerateRequestClass::InferResponseComplete( + TRITONSERVER_InferenceResponse* response, const uint32_t flags, void* userp) +{ + // FIXME can't use InferRequestClass object here since it's lifetime + // is different than response. For response we need to know how to + // send each output (as json, shm, or binary) and that information + // has to be maintained in a way that allows us to clean it up + // appropriately if connection closed or last response sent. + // + // But for now userp is the InferRequestClass object and the end of + // its life is in the OK or BAD ReplyCallback. + + auto infer_request = + reinterpret_cast(userp); + + // Assuming responses of the same request is sent in sequence. + + TRITONSERVER_Error* err = nullptr; + if (response != nullptr) { + err = infer_request->FinalizeResponse(response); + } + if (err != nullptr) { + infer_request->AddErrorJson(err); + } + + + // First response starts the chunked response, the response code is set here + // so user should check response body in case of error at later time. 
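
Because the first response opens the chunked reply and the FINAL flag closes it, a streaming client simply reads Server-Sent Events until the connection ends. A minimal sketch (model name and request fields are illustrative; with the basic EXACT_MAPPING schema each JSON key maps directly to a model input tensor):

    import json
    import requests

    # generate_stream emits "data: <json>\n\n" events until the final response.
    with requests.post(
        "http://localhost:8000/v2/models/mock_llm/generate_stream",
        json={"PROMPT": "hello", "STREAM": True},
        stream=True,
    ) as resp:
        for line in resp.iter_lines(decode_unicode=True):
            if line and line.startswith("data: "):
                event = json.loads(line[len("data: "):])
                print(event)
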
+ if (infer_request->IncrementResponseCount() == 0) { + infer_request->StartResponse( + (err == nullptr) ? EVHTP_RES_OK : EVHTP_RES_BADREQ); + } + +#ifdef TRITON_ENABLE_TRACING + if (infer_request->trace_ != nullptr) { + infer_request->trace_->CaptureTimestamp( + "INFER_RESPONSE_COMPLETE", TraceManager::CaptureTimestamp()); + } +#endif // TRITON_ENABLE_TRACING + + // Final flag indicates there is no more responses, ending chunked response. + if ((flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL) != 0) { + evthr_defer(infer_request->thread_, EndResponseCallback, infer_request); + } else { + evthr_defer(infer_request->thread_, ChunkResponseCallback, infer_request); + } + + LOG_TRITONSERVER_ERROR( + TRITONSERVER_InferenceResponseDelete(response), + "deleting inference response"); +} + +void +HTTPAPIServer::GenerateRequestClass::StartResponse(evhtp_res code) +{ + if (streaming_) { + AddContentTypeHeader(req_, "text/event-stream; charset=utf-8"); + } else { + AddContentTypeHeader(req_, "application/json"); + } + evhtp_send_reply_chunk_start(req_, code); + evhtp_request_resume(req_); +} + +void +HTTPAPIServer::GenerateRequestClass::ChunkResponseCallback( + evthr_t* thr, void* arg, void* shared) +{ + auto infer_request = + reinterpret_cast(arg); + infer_request->SendChunkResponse(false /* end */); +} + +void +HTTPAPIServer::GenerateRequestClass::EndResponseCallback( + evthr_t* thr, void* arg, void* shared) +{ + auto infer_request = + reinterpret_cast(arg); + + infer_request->SendChunkResponse(true /* end */); + evhtp_send_reply_chunk_end(infer_request->EvHtpRequest()); + delete infer_request; +} + +void +HTTPAPIServer::GenerateRequestClass::SendChunkResponse(bool end) +{ + // check if response count in the case of non-streaming + if (!streaming_) { + std::lock_guard lk(res_mtx_); + // For non-streaming, wait until end + if (!end) { + return; + } + if (pending_http_responses_.size() != 1) { + EVBufferAddErrorJson( + req_->buffer_out, TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + "generate expects model to produce exactly 1 " + "response, use generate stream for model that " + "generates various number of responses")); + evhtp_send_reply_chunk(req_, req_->buffer_out); + return; + } + } + + evbuffer* buffer = nullptr; + { + std::lock_guard lk(res_mtx_); + // This function may be called with no pending responses when + // response complete callback is invoked with flag-only + if (pending_http_responses_.empty()) { + return; + } + buffer = pending_http_responses_.front(); + pending_http_responses_.pop(); + } + evhtp_send_reply_chunk(req_, buffer); + evbuffer_free(buffer); + +#ifdef TRITON_ENABLE_TRACING + if (trace_ != nullptr) { + // [FIXME] currently send_start_ns / send_end_ns is + // not captured in evhtp when response is sent in chunks + trace_->CaptureTimestamp("HTTP_SEND_START", req_->send_start_ns); + trace_->CaptureTimestamp("HTTP_SEND_END", req_->send_end_ns); + } +#endif // TRITON_ENABLE_TRACING +} + +TRITONSERVER_Error* +HTTPAPIServer::GenerateRequestClass::FinalizeResponse( + TRITONSERVER_InferenceResponse* response) +{ + triton_response_ = response; + RETURN_IF_ERR(TRITONSERVER_InferenceResponseError(response)); + + triton::common::TritonJson::Value response_json( + triton::common::TritonJson::ValueType::OBJECT); + + // Response metadata in addition to output tensor / parameter falls under + // "unspecified field" with predefined name: + // "id", "model_name", "model_version" + std::map triton_outputs; + const char* id = ""; + 
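
SendChunkResponse enforces that the non-streaming endpoint carries exactly one JSON object, so the plain /generate route is meant for models that produce a single response per request; decoupled models should use /generate_stream. A sketch of the corresponding client call (names are illustrative):

    import requests

    # Non-streaming generate: one request in, exactly one JSON response out.
    resp = requests.post(
        "http://localhost:8000/v2/models/mock_llm/generate",
        json={"PROMPT": "hello", "STREAM": False},
    )
    resp.raise_for_status()
    result = resp.json()
    # Reserved keys plus one key per model output (EXACT_MAPPING).
    print(result["model_name"], result["model_version"])
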
RETURN_IF_ERR(TRITONSERVER_InferenceResponseId(response, &id)); + if (strncmp(id, "", 1)) { + triton_outputs.emplace( + "id", TritonOutput(TritonOutput::Type::RESERVED, id)); + } + const char* model_name; + int64_t model_version; + RETURN_IF_ERR(TRITONSERVER_InferenceResponseModel( + response, &model_name, &model_version)); + triton_outputs.emplace( + "model_name", TritonOutput(TritonOutput::Type::RESERVED, model_name)); + triton_outputs.emplace( + "model_version", + TritonOutput( + TritonOutput::Type::RESERVED, std::to_string(model_version))); + + // If the response has any parameters, convert them to JSON. + uint32_t parameter_count; + RETURN_IF_ERR( + TRITONSERVER_InferenceResponseParameterCount(response, ¶meter_count)); + if (parameter_count > 0) { + for (uint32_t pidx = 0; pidx < parameter_count; ++pidx) { + const char* name; + TRITONSERVER_ParameterType type; + const void* vvalue; + RETURN_IF_ERR(TRITONSERVER_InferenceResponseParameter( + response, pidx, &name, &type, &vvalue)); + switch (type) { + case TRITONSERVER_PARAMETER_BOOL: + case TRITONSERVER_PARAMETER_INT: + case TRITONSERVER_PARAMETER_STRING: + triton_outputs.emplace( + name, TritonOutput(TritonOutput::Type::PARAMETER, pidx)); + break; + case TRITONSERVER_PARAMETER_BYTES: + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, + (std::string("Response parameter '") + name + + "' has type 'TRITONSERVER_PARAMETER_BYTES' which is " + "not currently supported") + .c_str()); + break; + } + } + } + + // Go through each response output and transfer information to JSON + uint32_t output_count; + RETURN_IF_ERR( + TRITONSERVER_InferenceResponseOutputCount(response, &output_count)); + + for (uint32_t idx = 0; idx < output_count; ++idx) { + const char* cname; + TRITONSERVER_DataType datatype; + const int64_t* shape; + uint64_t dim_count; + const void* base; + size_t byte_size; + TRITONSERVER_MemoryType memory_type; + int64_t memory_type_id; + void* userp; + + RETURN_IF_ERR(TRITONSERVER_InferenceResponseOutput( + response, idx, &cname, &datatype, &shape, &dim_count, &base, &byte_size, + &memory_type, &memory_type_id, &userp)); + triton_outputs.emplace( + cname, TritonOutput(TritonOutput::Type::TENSOR, idx)); + } + + std::set mapped_outputs; + RETURN_IF_ERR(ConvertGenerateResponse( + triton_outputs, response_schema_, &response_json, &mapped_outputs)); + if (response_schema_->allow_unspecified_) { + for (const auto& to : triton_outputs) { + if (mapped_outputs.find(to.first) == mapped_outputs.end()) { + RETURN_IF_ERR(ExactMappingOutput( + to.first, to.second, &response_json, &mapped_outputs)); + } + } + } + + // [FIXME] compression + evbuffer* response_body = evbuffer_new(); + if (streaming_) { + static std::string sse_prefix = "data: "; + evbuffer_add(response_body, sse_prefix.c_str(), sse_prefix.length()); + } + // Write json metadata into response evbuffer + triton::common::TritonJson::WriteBuffer buffer; + RETURN_IF_ERR(response_json.Write(&buffer)); + evbuffer_add(response_body, buffer.Base(), buffer.Size()); + if (streaming_) { + static std::string sse_suffix = "\n\n"; + evbuffer_add(response_body, sse_suffix.c_str(), sse_suffix.length()); + } + + { + std::lock_guard lk(res_mtx_); + pending_http_responses_.emplace(response_body); + } + + return nullptr; // success +} + +void +HTTPAPIServer::GenerateRequestClass::AddErrorJson(TRITONSERVER_Error* error) +{ + evbuffer* buffer = evbuffer_new(); + if (streaming_) { + static std::string sse_prefix = "data: "; + evbuffer_add(buffer, sse_prefix.c_str(), sse_prefix.length()); + } + 
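
FinalizeResponse serializes the JSON once and, in streaming mode, wraps it in SSE framing: a "data: " prefix and a blank-line terminator. Roughly, per response:

    import json

    def to_sse_event(response_json: dict) -> bytes:
        # Streaming generate responses are framed as Server-Sent Events:
        # "data: " + <serialized JSON> + "\n\n".
        return b"data: " + json.dumps(response_json).encode() + b"\n\n"

    # Illustrative payload only; "model_name" and "model_version" are reserved
    # keys added by the server, "text_output" stands in for a model output.
    print(to_sse_event({"model_name": "mock_llm", "model_version": "1",
                        "text_output": "hi"}))
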
EVBufferAddErrorJson(buffer, error); + if (streaming_) { + static std::string sse_suffix = "\n\n"; + evbuffer_add(buffer, sse_suffix.c_str(), sse_suffix.length()); + } + TRITONSERVER_ErrorDelete(error); + { + std::lock_guard lk(res_mtx_); + pending_http_responses_.emplace(buffer); + } +} + +TRITONSERVER_Error* +HTTPAPIServer::GenerateRequestClass::ConvertGenerateResponse( + const std::map< + std::string, HTTPAPIServer::GenerateRequestClass::TritonOutput>& + output_metadata, + const MappingSchema* schema, + triton::common::TritonJson::Value* generate_response, + std::set* mapped_outputs) +{ + for (auto& nested : schema->children_) { + switch (nested.second->kind_) { + case MappingSchema::Kind::MAPPING_SCHEMA: { + triton::common::TritonJson::Value nested_response( + *generate_response, triton::common::TritonJson::ValueType::OBJECT); + RETURN_IF_ERR(ConvertGenerateResponse( + output_metadata, nested.second.get(), &nested_response, + mapped_outputs)); + RETURN_IF_ERR(generate_response->Add( + nested.first.c_str(), std::move(nested_response))); + break; + } + case MappingSchema::Kind::EXACT_MAPPING: { + auto it = output_metadata.find(nested.first); + if (it == output_metadata.end()) { + if (!nested.second->allow_unspecified_) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + (std::string("Schema requires output '") + nested.first + + "' to be produced by the model.") + .c_str()); + } + } else { + RETURN_IF_ERR(ExactMappingOutput( + nested.first, it->second, generate_response, mapped_outputs)); + } + break; + } + default: + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, "Unsupported schema kind"); + } + } + return nullptr; // success +} + +TRITONSERVER_Error* +HTTPAPIServer::GenerateRequestClass::ExactMappingOutput( + const std::string& name, + const HTTPAPIServer::GenerateRequestClass::TritonOutput& triton_output, + triton::common::TritonJson::Value* generate_response, + std::set* mapped_outputs) +{ + mapped_outputs->emplace(name); + + switch (triton_output.type) { + case TritonOutput::Type::RESERVED: { + generate_response->AddStringRef( + name.c_str(), triton_output.value.c_str()); + break; + } + case TritonOutput::Type::PARAMETER: { + const char* name; + TRITONSERVER_ParameterType type; + const void* vvalue; + RETURN_IF_ERR(TRITONSERVER_InferenceResponseParameter( + triton_response_, triton_output.index, &name, &type, &vvalue)); + switch (type) { + case TRITONSERVER_PARAMETER_BOOL: + RETURN_IF_ERR(generate_response->AddBool( + name, *(reinterpret_cast(vvalue)))); + break; + case TRITONSERVER_PARAMETER_INT: + RETURN_IF_ERR(generate_response->AddInt( + name, *(reinterpret_cast(vvalue)))); + break; + case TRITONSERVER_PARAMETER_STRING: + RETURN_IF_ERR(generate_response->AddStringRef( + name, reinterpret_cast(vvalue))); + break; + case TRITONSERVER_PARAMETER_BYTES: + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_UNSUPPORTED, + (std::string("Response parameter '") + name + + "' has type 'TRITONSERVER_PARAMETER_BYTES' which is " + "not currently supported") + .c_str()); + break; + } + break; + } + case TritonOutput::Type::TENSOR: { + const char* cname; + TRITONSERVER_DataType datatype; + const int64_t* shape; + uint64_t dim_count; + const void* base; + size_t byte_size; + TRITONSERVER_MemoryType memory_type; + int64_t memory_type_id; + void* userp; + + RETURN_IF_ERR(TRITONSERVER_InferenceResponseOutput( + triton_response_, triton_output.index, &cname, &datatype, &shape, + &dim_count, &base, &byte_size, &memory_type, &memory_type_id, + &userp)); + + auto info = 
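
The schema walk above can be modeled as a small recursive function: EXACT_MAPPING copies an output under its own name, a nested MAPPING_SCHEMA produces a nested JSON object, and anything the schema does not mention is appended at the top level when allow_unspecified_ is true. A simplified Python sketch (not the server code):

    NESTED = "MAPPING_SCHEMA"  # anything else is treated as EXACT_MAPPING

    def convert_response(outputs, schema, mapped):
        result = {}
        for name, child in schema.get("children", {}).items():
            if child["kind"] == NESTED:
                result[name] = convert_response(outputs, child, mapped)
            elif name in outputs:
                result[name] = outputs[name]
                mapped.add(name)
            elif not child.get("allow_unspecified", True):
                raise KeyError(f"schema requires output '{name}'")
        return result

    def finalize(outputs, schema):
        mapped = set()
        result = convert_response(outputs, schema, mapped)
        if schema.get("allow_unspecified", True):
            for name, value in outputs.items():
                if name not in mapped:
                    result[name] = value
        return result

    # With the basic (empty) schema in this patch, every output passes through.
    print(finalize({"TEXT": "hi", "model_name": "mock_llm"},
                   {"kind": NESTED, "children": {}}))
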
reinterpret_cast(userp); + // sanity check + if (info->kind_ != AllocPayload::OutputInfo::JSON) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + (std::string("non-JSON output response type is requested for '") + + cname + "'") + .c_str()); + } + + size_t element_count = 1; + for (size_t j = 0; j < dim_count; j++) { + element_count *= shape[j]; + } + + triton::common::TritonJson::Value data_json( + *generate_response, triton::common::TritonJson::ValueType::ARRAY); + RETURN_IF_ERR(WriteDataToJson( + &data_json, cname, datatype, base, byte_size, element_count)); + if (element_count > 1) { + RETURN_IF_ERR(generate_response->Add(cname, std::move(data_json))); + } else { + // if only 1 element, strip out the array + triton::common::TritonJson::Value el; + RETURN_IF_ERR(data_json.At(0, &el)); + RETURN_IF_ERR(generate_response->Add(cname, std::move(el))); + } + break; + } + } + return nullptr; // success +} void HTTPAPIServer::Handle(evhtp_request_t* req) @@ -3597,6 +4392,14 @@ HTTPAPIServer::Handle(evhtp_request_t* req) // model infer HandleInfer(req, model_name, version); return; + } else if (kind == "generate") { + // text generation + HandleGenerate(req, model_name, version, false /* streaming */); + return; + } else if (kind == "generate_stream") { + // text generation (streaming) + HandleGenerate(req, model_name, version, true /* streaming */); + return; } else if (kind == "config") { // model configuration HandleModelConfig(req, model_name, version); diff --git a/src/http_server.h b/src/http_server.h index 265e012896..0b10adb6cc 100644 --- a/src/http_server.h +++ b/src/http_server.h @@ -1,4 +1,4 @@ -// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -31,6 +31,7 @@ #include #include #include +#include #include #include #include @@ -44,6 +45,26 @@ namespace triton { namespace server { +class MappingSchema { + public: + enum class Kind { + EXACT_MAPPING, + // An object of this kind means it is a nested mapping schema. + MAPPING_SCHEMA + }; + std::map> children_; + // Whether an unspecified key is allowed. If true, + // * for requests, the unspecified key will be converted to Triton input + // following the EXACT_MAPPING rule. + // * for responses, the Triton output will be converted to JSON key-value + // pairs at top level if the name is unspecified in the schema, + // following the EXACT_MAPPING rule. + const bool allow_unspecified_{true}; + const Kind kind_{Kind::EXACT_MAPPING}; + + private: +}; + // Generic HTTP server using evhtp class HTTPServer { public: @@ -186,6 +207,9 @@ class HTTPAPIServer : public HTTPServer { // send the response. class InferRequestClass { public: + // [FIXME] decompression / compression should be handled implicitly + // within InferRequestClass. This alleviate the check for decompressed + // buffer in HTTPServer code. 
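
ExactMappingOutput strips the surrounding array when a tensor holds a single element, so clients should expect either a scalar or a JSON array depending on the output shape. A tiny illustration of that behavior (values are made up):

    def unwrap_single_element(values):
        # Mirrors the response conversion: a one-element output is emitted as a
        # scalar, anything larger stays a JSON array.
        return values[0] if len(values) == 1 else values

    assert unwrap_single_element(["hello"]) == "hello"
    assert unwrap_single_element([0.1, 0.7, 0.2]) == [0.1, 0.7, 0.2]
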
explicit InferRequestClass( TRITONSERVER_Server* server, evhtp_request_t* req, DataCompressor::Type response_compression_type); @@ -199,7 +223,7 @@ class HTTPAPIServer : public HTTPServer { static void InferResponseComplete( TRITONSERVER_InferenceResponse* response, const uint32_t flags, void* userp); - TRITONSERVER_Error* FinalizeResponse( + virtual TRITONSERVER_Error* FinalizeResponse( TRITONSERVER_InferenceResponse* response); // Helper function to set infer response header in the form specified by @@ -230,6 +254,96 @@ class HTTPAPIServer : public HTTPServer { std::atomic response_count_; }; + class GenerateRequestClass : public InferRequestClass { + public: + explicit GenerateRequestClass( + TRITONSERVER_Server* server, evhtp_request_t* req, + DataCompressor::Type response_compression_type, + const MappingSchema* request_schema, + const MappingSchema* response_schema, bool streaming, + TRITONSERVER_InferenceRequest* triton_request) + : InferRequestClass(server, req, response_compression_type), + request_schema_(request_schema), response_schema_(response_schema), + streaming_(streaming), triton_request_(triton_request) + { + } + virtual ~GenerateRequestClass(); + + // [FIXME] Specialize response complete function for now, should have + // been a dispatcher and call into object specific response function. + static void InferResponseComplete( + TRITONSERVER_InferenceResponse* response, const uint32_t flags, + void* userp); + static void ChunkResponseCallback(evthr_t* thr, void* arg, void* shared); + static void EndResponseCallback(evthr_t* thr, void* arg, void* shared); + // Return whether the response is ending + void SendChunkResponse(bool end); + + // Response preparation + TRITONSERVER_Error* FinalizeResponse( + TRITONSERVER_InferenceResponse* response) override; + void AddErrorJson(TRITONSERVER_Error* error); + void StartResponse(evhtp_res code); + + // [DLIS-5551] currently always performs basic conversion, only maps schema + // of EXACT_MAPPING kind. MAPPING_SCHEMA and upcoming kinds are for + // customized conversion where a detailed schema will be provided. + TRITONSERVER_Error* ConvertGenerateRequest( + std::map& + input_metadata, + const MappingSchema* schema, + triton::common::TritonJson::Value& generate_request); + + const MappingSchema* RequestSchema() { return request_schema_; } + const MappingSchema* ResponseSchema() { return response_schema_; } + + private: + struct TritonOutput { + enum class Type { RESERVED, TENSOR, PARAMETER }; + TritonOutput(Type t, const std::string& val) : type(t), value(val) {} + explicit TritonOutput(Type t, uint32_t i) : type(t), index(i) {} + Type type; + // RESERVED type + std::string value; + // TENSOR, PARAMETER type + uint32_t index; + }; + TRITONSERVER_Error* ExactMappingInput( + const std::string& name, triton::common::TritonJson::Value& value, + std::map& + input_metadata); + + // [DLIS-5551] currently always performs basic conversion, only maps schema + // of EXACT_MAPPING kind. MAPPING_SCHEMA and upcoming kinds are for + // customized conversion where a detailed schema will be provided. 
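
ConvertGenerateRequest is the request-side counterpart: each JSON key in the generate body is matched against the model's input metadata and turned into an input tensor. A rough Python model of that exact mapping (names and helper are illustrative, not the server implementation):

    def convert_generate_request(request_json, input_metadata):
        # request_json:   body of POST .../generate, e.g. {"PROMPT": "hi"}
        # input_metadata: {input_name: {"datatype": ..., "shape": [...]}}
        inputs = {}
        for key, value in request_json.items():
            if key not in input_metadata:
                raise KeyError(f"no model input named '{key}'")
            # Scalars become single-element tensors; lists keep their length.
            data = value if isinstance(value, list) else [value]
            inputs[key] = {"datatype": input_metadata[key]["datatype"], "data": data}
        return inputs

    meta = {"PROMPT": {"datatype": "BYTES", "shape": [1, 1]},
            "STREAM": {"datatype": "BOOL", "shape": [1, 1]}}
    print(convert_generate_request({"PROMPT": "hello", "STREAM": False}, meta))
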
+ TRITONSERVER_Error* ConvertGenerateResponse( + const std::map& output_metadata, + const MappingSchema* schema, + triton::common::TritonJson::Value* generate_response, + std::set* mapped_outputs); + TRITONSERVER_Error* ExactMappingOutput( + const std::string& name, const TritonOutput& triton_output, + triton::common::TritonJson::Value* generate_response, + std::set* mapped_outputs); + + const MappingSchema* request_schema_{nullptr}; + const MappingSchema* response_schema_{nullptr}; + const bool streaming_{false}; + // Pointer to associated Triton request, this class does not own the + // request and must not reference it after a successful + // TRITONSERVER_ServerInferAsync. + TRITONSERVER_InferenceRequest* triton_request_{nullptr}; + // Placeholder to completing response, this class does not own + // the response. + TRITONSERVER_InferenceResponse* triton_response_{nullptr}; + // As InferResponseComplete and ChunkResponseCallback are called in + // different threads, need to have dedicated buffers for each response and + // ensure mutual exclusive access. + std::mutex res_mtx_; + std::queue pending_http_responses_; + bool end_{false}; + }; + protected: explicit HTTPAPIServer( const std::shared_ptr& server, @@ -238,6 +352,7 @@ class HTTPAPIServer : public HTTPServer { const int32_t port, const bool reuse_port, const std::string& address, const std::string& header_forward_pattern, const int thread_cnt); virtual void Handle(evhtp_request_t* req) override; + // [FIXME] extract to "infer" class virtual std::unique_ptr CreateInferRequest( evhtp_request_t* req) { @@ -255,6 +370,10 @@ class HTTPAPIServer : public HTTPServer { virtual DataCompressor::Type GetRequestCompressionType(evhtp_request_t* req); virtual DataCompressor::Type GetResponseCompressionType(evhtp_request_t* req); + + TRITONSERVER_Error* GetModelConfig( + const std::string& model_name, int64_t requested_model_version, + std::string* config_json); TRITONSERVER_Error* GetContentLength( evhtp_request_t* req, evbuffer* decompressed_buffer, int32_t* content_length); @@ -318,7 +437,27 @@ class HTTPAPIServer : public HTTPServer { void HandleTrace(evhtp_request_t* req, const std::string& model_name = ""); void HandleLogging(evhtp_request_t* req); - TRITONSERVER_Error* EVBufferToTritonRequest( + // Text Generation / LLM format + //'streaming' selects the schema pair to convert request / response. + // 'streaming' also controls the response convention, if true, + // Server-Sent Events format will be used to send responses. + void HandleGenerate( + evhtp_request_t* req, const std::string& model_name, + const std::string& model_version_str, bool streaming); + + // 'meta_data_root' is the root JSON document for 'input_metadata'. + // In TritonJson, the Value objects are references to the root document. + // Therefore the document must stay valid. + TRITONSERVER_Error* ModelInputMetadata( + const std::string& model_name, const int64_t model_version, + std::map* input_metadata, + triton::common::TritonJson::Value* meta_data_root); + + // Parses full evhtp request and its evbuffers into JSON. + TRITONSERVER_Error* EVRequestToJson( + evhtp_request_t* req, triton::common::TritonJson::Value* request_json); + // Parses evhtp request buffers into Triton Inference Request. 
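
ModelInputMetadata builds the name-to-metadata map the converters rely on; the same information is visible to clients through the standard model metadata endpoint. A small sketch (hypothetical model name):

    import requests

    # The "inputs" entries carry name, datatype and shape, which is exactly
    # what the generate request/response converters key on.
    meta = requests.get("http://localhost:8000/v2/models/mock_llm").json()
    input_metadata = {i["name"]: i for i in meta["inputs"]}
    print(sorted(input_metadata))
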
+ TRITONSERVER_Error* EVRequestToTritonRequest( evhtp_request_t* req, const std::string& model_name, TRITONSERVER_InferenceRequest* irequest, evbuffer* decompressed_buffer, InferRequestClass* infer_req, size_t header_length); @@ -367,6 +506,16 @@ class HTTPAPIServer : public HTTPServer { re2::RE2 systemsharedmemory_regex_; re2::RE2 cudasharedmemory_regex_; re2::RE2 trace_regex_; + + // [DLIS-5551] currently always performs basic conversion, only maps schema + // of EXACT_MAPPING kind. MAPPING_SCHEMA and upcoming kinds are for + // customized conversion where a detailed schema will be provided. + std::unique_ptr generate_request_schema_{new MappingSchema()}; + std::unique_ptr generate_response_schema_{new MappingSchema()}; + std::unique_ptr generate_stream_response_schema_{ + new MappingSchema()}; + std::unique_ptr generate_stream_request_schema_{ + new MappingSchema()}; }; }} // namespace triton::server diff --git a/src/test/CMakeLists.txt b/src/test/CMakeLists.txt index d021a51a15..25049624f8 100644 --- a/src/test/CMakeLists.txt +++ b/src/test/CMakeLists.txt @@ -84,6 +84,7 @@ if(${TRITON_ENABLE_HTTP} OR ${TRITON_ENABLE_METRICS} OR data_compressor_test PRIVATE triton-core-serverapi # from repo-core + triton-core-serverstub # from repo-core GTest::gtest GTest::gtest_main ${LIBEVENT_LIBRARIES} From 2bf543b0bbe6108d8aa03060eab222183cc101a5 Mon Sep 17 00:00:00 2001 From: Jacky <18255193+kthui@users.noreply.github.com> Date: Fri, 6 Oct 2023 09:51:07 -0700 Subject: [PATCH 2/2] Add Python backend request cancellation test (#6364) * Add cancelled response status test * Add Python backend request cancellation test * Add Python backend decoupled request cancellation test * Simplified response if cancelled * Test response_sender.send() after closed * Rollback test response_sender.send() after closed * Rollback non-decoupled any response on cancel --- .../decoupled/decoupled_test.py | 33 ++++++ qa/L0_backend_python/decoupled/test.sh | 7 +- .../lifecycle/lifecycle_test.py | 38 ++++++ qa/L0_backend_python/lifecycle/test.sh | 6 +- qa/python_models/error_code/model.py | 1 + qa/python_models/execute_cancel/config.pbtxt | 47 ++++++++ qa/python_models/execute_cancel/model.py | 108 ++++++++++++++++++ 7 files changed, 238 insertions(+), 2 deletions(-) create mode 100644 qa/python_models/execute_cancel/config.pbtxt create mode 100644 qa/python_models/execute_cancel/model.py diff --git a/qa/L0_backend_python/decoupled/decoupled_test.py b/qa/L0_backend_python/decoupled/decoupled_test.py index 21fa8b757e..f0ca870664 100755 --- a/qa/L0_backend_python/decoupled/decoupled_test.py +++ b/qa/L0_backend_python/decoupled/decoupled_test.py @@ -256,6 +256,39 @@ def test_decoupled_send_after_close_error(self): "The completed request size must be zero.", ) + def test_decoupled_execute_cancel(self): + model_name = "execute_cancel" + log_path = "decoupled_server.log" + execute_delay = 4.0 # seconds + shape = [1, 1] + + user_data = UserData() + with grpcclient.InferenceServerClient("localhost:8001") as client: + client.start_stream(callback=partial(callback, user_data)) + input_data = np.array([[execute_delay]], dtype=np.float32) + inputs = [ + grpcclient.InferInput( + "EXECUTE_DELAY", shape, np_to_triton_dtype(input_data.dtype) + ) + ] + inputs[0].set_data_from_numpy(input_data) + client.async_stream_infer(model_name, inputs) + time.sleep(2) # model delay for decoupling request and response sender + time.sleep(2) # ensure the request is executing + client.stop_stream(cancel_requests=True) + time.sleep(2) # ensure the 
cancellation is delivered + + self.assertFalse(user_data._completed_requests.empty()) + while not user_data._completed_requests.empty(): + data_item = user_data._completed_requests.get() + self.assertIsInstance(data_item, InferenceServerException) + self.assertEqual(data_item.status(), "StatusCode.CANCELLED") + + with open(log_path, mode="r", encoding="utf-8", errors="strict") as f: + log_text = f.read() + self.assertIn("[execute_cancel] Request not cancelled at 1.0 s", log_text) + self.assertIn("[execute_cancel] Request cancelled at ", log_text) + if __name__ == "__main__": unittest.main() diff --git a/qa/L0_backend_python/decoupled/test.sh b/qa/L0_backend_python/decoupled/test.sh index b4fa4ffe75..07c8f5b4ee 100755 --- a/qa/L0_backend_python/decoupled/test.sh +++ b/qa/L0_backend_python/decoupled/test.sh @@ -27,7 +27,7 @@ CLIENT_PY=./decoupled_test.py CLIENT_LOG="./decoupled_client.log" -EXPECTED_NUM_TESTS="5" +EXPECTED_NUM_TESTS="6" TEST_RESULT_FILE='test_results.txt' TRITON_DIR=${TRITON_DIR:="/opt/tritonserver"} SERVER=${TRITON_DIR}/bin/tritonserver @@ -50,6 +50,11 @@ mkdir -p models/dlpack_add_sub/1/ cp ../../python_models/dlpack_add_sub/model.py models/dlpack_add_sub/1/ cp ../../python_models/dlpack_add_sub/config.pbtxt models/dlpack_add_sub/ +mkdir -p models/execute_cancel/1/ +cp ../../python_models/execute_cancel/model.py ./models/execute_cancel/1/ +cp ../../python_models/execute_cancel/config.pbtxt ./models/execute_cancel/ +echo "model_transaction_policy { decoupled: True }" >> ./models/execute_cancel/config.pbtxt + git clone https://github.com/triton-inference-server/python_backend -b $PYTHON_BACKEND_REPO_TAG mkdir -p models/square_int32/1/ cp python_backend/examples/decoupled/square_model.py models/square_int32/1/model.py diff --git a/qa/L0_backend_python/lifecycle/lifecycle_test.py b/qa/L0_backend_python/lifecycle/lifecycle_test.py index 9c3bf7efa9..82856bbd32 100755 --- a/qa/L0_backend_python/lifecycle/lifecycle_test.py +++ b/qa/L0_backend_python/lifecycle/lifecycle_test.py @@ -31,6 +31,7 @@ sys.path.append("../../common") import queue +import time import unittest from functools import partial @@ -70,6 +71,7 @@ def test_error_code(self): ("UNAVAILABLE", "[StatusCode.UNAVAILABLE]"), ("UNSUPPORTED", "[StatusCode.UNIMPLEMENTED]"), ("ALREADY_EXISTS", "[StatusCode.ALREADY_EXISTS]"), + ("CANCELLED", "[StatusCode.CANCELLED]"), ("(default)", "[StatusCode.INTERNAL] unrecognized"), ] with self._shm_leak_detector.Probe() as shm_probe: @@ -91,6 +93,42 @@ def test_error_code(self): expected_grpc_error_start + " error code: " + error, ) + def test_execute_cancel(self): + model_name = "execute_cancel" + log_path = "lifecycle_server.log" + execute_delay = 4.0 # seconds + shape = [1, 1] + response = {"responded": False, "result": None, "error": None} + + def callback(result, error): + response["responded"] = True + response["result"] = result + response["error"] = error + + with self._shm_leak_detector.Probe() as shm_probe: + with grpcclient.InferenceServerClient("localhost:8001") as client: + input_data = np.array([[execute_delay]], dtype=np.float32) + inputs = [ + grpcclient.InferInput( + "EXECUTE_DELAY", shape, np_to_triton_dtype(input_data.dtype) + ) + ] + inputs[0].set_data_from_numpy(input_data) + exec_future = client.async_infer(model_name, inputs, callback) + time.sleep(2) # ensure the request is executing + self.assertFalse(response["responded"]) + exec_future.cancel() + time.sleep(2) # ensure the cancellation is delivered + self.assertTrue(response["responded"]) + + 
self.assertEqual(response["result"], None) + self.assertIsInstance(response["error"], InferenceServerException) + self.assertEqual(response["error"].status(), "StatusCode.CANCELLED") + with open(log_path, mode="r", encoding="utf-8", errors="strict") as f: + log_text = f.read() + self.assertIn("[execute_cancel] Request not cancelled at 1.0 s", log_text) + self.assertIn("[execute_cancel] Request cancelled at ", log_text) + def test_batch_error(self): # The execute_error model returns an error for the first and third # request and successfully processes the second request. This is making diff --git a/qa/L0_backend_python/lifecycle/test.sh b/qa/L0_backend_python/lifecycle/test.sh index 2abf107813..eb7f868940 100755 --- a/qa/L0_backend_python/lifecycle/test.sh +++ b/qa/L0_backend_python/lifecycle/test.sh @@ -26,7 +26,7 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. CLIENT_LOG="./lifecycle_client.log" -EXPECTED_NUM_TESTS="4" +EXPECTED_NUM_TESTS="5" TEST_RESULT_FILE='test_results.txt' source ../common.sh source ../../common/util.sh @@ -44,6 +44,10 @@ mkdir -p models/error_code/1/ cp ../../python_models/error_code/model.py ./models/error_code/1/ cp ../../python_models/error_code/config.pbtxt ./models/error_code/ +mkdir -p models/execute_cancel/1/ +cp ../../python_models/execute_cancel/model.py ./models/execute_cancel/1/ +cp ../../python_models/execute_cancel/config.pbtxt ./models/execute_cancel/ + mkdir -p models/execute_error/1/ cp ../../python_models/execute_error/model.py ./models/execute_error/1/ cp ../../python_models/execute_error/config.pbtxt ./models/execute_error/ diff --git a/qa/python_models/error_code/model.py b/qa/python_models/error_code/model.py index 350457ca79..078a4afb73 100644 --- a/qa/python_models/error_code/model.py +++ b/qa/python_models/error_code/model.py @@ -37,6 +37,7 @@ def execute(self, requests): "UNAVAILABLE": pb_utils.TritonError.UNAVAILABLE, "UNSUPPORTED": pb_utils.TritonError.UNSUPPORTED, "ALREADY_EXISTS": pb_utils.TritonError.ALREADY_EXISTS, + "CANCELLED": pb_utils.TritonError.CANCELLED, } responses = [] diff --git a/qa/python_models/execute_cancel/config.pbtxt b/qa/python_models/execute_cancel/config.pbtxt new file mode 100644 index 0000000000..df509863ad --- /dev/null +++ b/qa/python_models/execute_cancel/config.pbtxt @@ -0,0 +1,47 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +name: "execute_cancel" +backend: "python" +max_batch_size: 1 + +input [ + { + name: "EXECUTE_DELAY" + data_type: TYPE_FP32 + dims: [ 1 ] + } +] + +output [ + { + name: "DUMMY_OUT" + data_type: TYPE_FP32 + dims: [ 1 ] + } +] + +instance_group [{ kind: KIND_CPU }] diff --git a/qa/python_models/execute_cancel/model.py b/qa/python_models/execute_cancel/model.py new file mode 100644 index 0000000000..ec7b96ec1a --- /dev/null +++ b/qa/python_models/execute_cancel/model.py @@ -0,0 +1,108 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import json +import threading +import time + +import triton_python_backend_utils as pb_utils + + +class TritonPythonModel: + def initialize(self, args): + self._logger = pb_utils.Logger + self._model_config = json.loads(args["model_config"]) + self._using_decoupled = pb_utils.using_decoupled_model_transaction_policy( + self._model_config + ) + + def execute(self, requests): + processed_requests = [] + for request in requests: + delay_tensor = pb_utils.get_input_tensor_by_name( + request, "EXECUTE_DELAY" + ).as_numpy() + delay = delay_tensor[0][0] # seconds + if self._using_decoupled: + processed_requests.append( + {"response_sender": request.get_response_sender(), "delay": delay} + ) + else: + processed_requests.append({"request": request, "delay": delay}) + if self._using_decoupled: + return self._execute_decoupled(processed_requests) + return self._execute_processed_requests(processed_requests) + + def _execute_processed_requests(self, processed_requests): + responses = [] + for processed_request in processed_requests: + error = pb_utils.TritonError(message="not cancelled") + object_to_check_cancelled = None + if "response_sender" in processed_request: + object_to_check_cancelled = processed_request["response_sender"] + elif "request" in processed_request: + object_to_check_cancelled = processed_request["request"] + delay = processed_request["delay"] # seconds + time_elapsed = 0.0 # seconds + while time_elapsed < delay: + time.sleep(1) + time_elapsed += 1.0 + if object_to_check_cancelled.is_cancelled(): + self._logger.log_info( + "[execute_cancel] Request cancelled at " + + str(time_elapsed) + + " s" + ) + error = pb_utils.TritonError( + message="cancelled", code=pb_utils.TritonError.CANCELLED + ) + break + self._logger.log_info( + "[execute_cancel] Request not cancelled at " + + str(time_elapsed) + + " s" + ) + responses.append(pb_utils.InferenceResponse(error=error)) + return responses + + def _execute_decoupled(self, processed_requests): + def response_thread(execute_processed_requests, processed_requests): + time.sleep(2) # execute after requests are released + responses = execute_processed_requests(processed_requests) + for i in range(len(responses)): # len(responses) == len(processed_requests) + response_sender = processed_requests[i]["response_sender"] + response_sender.send(responses[i]) + response_sender.send( + flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL + ) + + thread = threading.Thread( + target=response_thread, + args=(self._execute_processed_requests, processed_requests), + ) + thread.daemon = True + thread.start() + return None