diff --git a/qa/L0_cuda_shared_memory/cuda_shared_memory_test.py b/qa/L0_cuda_shared_memory/cuda_shared_memory_test.py index 51137e8934..01d50d67cc 100755 --- a/qa/L0_cuda_shared_memory/cuda_shared_memory_test.py +++ b/qa/L0_cuda_shared_memory/cuda_shared_memory_test.py @@ -30,6 +30,7 @@ sys.path.append("../common") +import base64 import os import time import unittest @@ -37,6 +38,7 @@ import infer_util as iu import numpy as np +import requests import test_util as tu import tritonclient.grpc as grpcclient import tritonclient.http as httpclient @@ -564,5 +566,108 @@ def callback(user_data, result, error): self._cleanup_server(shm_handles) +class CudaSharedMemoryTestRawHttpRequest(unittest.TestCase): + def setUp(self): + self.url = "localhost:8000" + self.client = httpclient.InferenceServerClient(url=self.url, verbose=True) + self.valid_shm_handle = None + + def tearDown(self): + self.client.unregister_cuda_shared_memory() + if self.valid_shm_handle: + cshm.destroy_shared_memory_region(self.valid_shm_handle) + self.client.close() + + def _generate_mock_base64_raw_handle(self, data_length): + original_data_length = data_length * 3 // 4 + random_data = b"A" * original_data_length + encoded_data = base64.b64encode(random_data) + + assert ( + len(encoded_data) == data_length + ), "Encoded data length does not match the required length." + return encoded_data + + def _send_register_cshm_request(self, raw_handle, device_id, byte_size, shm_name): + cuda_shared_memory_register_request = { + "raw_handle": {"b64": raw_handle.decode("utf-8")}, + "device_id": device_id, + "byte_size": byte_size, + } + + url = "http://{}/v2/cudasharedmemory/region/{}/register".format( + self.url, shm_name + ) + headers = {"Content-Type": "application/json"} + + # Send POST request + response = requests.post( + url, headers=headers, json=cuda_shared_memory_register_request + ) + return response + + def test_exceeds_cshm_handle_size_limit(self): + # byte_size greater than INT_MAX + byte_size = 1 << 31 + device_id = 0 + shm_name = "invalid_shm" + + raw_handle = self._generate_mock_base64_raw_handle(byte_size) + response = self._send_register_cshm_request( + raw_handle, device_id, byte_size, shm_name + ) + self.assertNotEqual(response.status_code, 200) + + try: + error_message = response.json().get("error", "") + self.assertIn( + "'raw_handle' exceeds the maximum allowed data size limit INT_MAX", + error_message, + ) + except ValueError: + self.fail("Response is not valid JSON") + + def test_invalid_small_cshm_handle(self): + byte_size = 64 + device_id = 0 + shm_name = "invalid_shm" + + raw_handle = self._generate_mock_base64_raw_handle(byte_size) + response = self._send_register_cshm_request( + raw_handle, device_id, byte_size, shm_name + ) + self.assertNotEqual(response.status_code, 200) + + try: + error_message = response.json().get("error", "") + self.assertIn( + "'raw_handle' must be a valid base64 encoded cudaIpcMemHandle_t", + error_message, + ) + except ValueError: + self.fail("Response is not valid JSON") + + def test_valid_cshm_handle(self): + byte_size = 64 + device_id = 0 + shm_name = "test_shm" + + # Create valid shared memory + self.valid_shm_handle = cshm.create_shared_memory_region( + shm_name, byte_size, device_id + ) + raw_handle = cshm.get_raw_handle(self.valid_shm_handle) + + response = self._send_register_cshm_request( + raw_handle, device_id, byte_size, shm_name + ) + self.assertEqual(response.status_code, 200) + + # Verify shared memory status + status = self.client.get_cuda_shared_memory_status() + self.assertEqual(len(status), 1) + self.assertEqual(status[0]["name"], shm_name) + + if __name__ == "__main__": unittest.main() diff --git a/qa/L0_cuda_shared_memory/test.sh b/qa/L0_cuda_shared_memory/test.sh index b7126a9295..2ab80c913f 100755 --- a/qa/L0_cuda_shared_memory/test.sh +++ b/qa/L0_cuda_shared_memory/test.sh @@ -84,6 +84,38 @@ for i in \ done done +for i in \ + test_exceeds_cshm_handle_size_limit \ + test_invalid_small_cshm_handle \ + test_valid_cshm_handle; do + SERVER_ARGS="--model-repository=`pwd`/models --log-verbose=1" + SERVER_LOG="./$i.server.log" + CLIENT_LOG="./$i.client.log" + run_server + if [ "$SERVER_PID" == "0" ]; then + echo -e "\n***\n*** Failed to start $SERVER\n***" + cat $SERVER_LOG + exit 1 + fi + echo "Test: $i, client type: HTTP" >>$CLIENT_LOG + set +e + python $SHM_TEST CudaSharedMemoryTestRawHttpRequest.$i >>$CLIENT_LOG 2>&1 + if [ $? -ne 0 ]; then + echo -e "\n***\n*** Test Failed\n***" + RET=1 + else + check_test_results $TEST_RESULT_FILE 1 + if [ $? -ne 0 ]; then + cat $CLIENT_LOG + echo -e "\n***\n*** Test Result Verification Failed\n***" + RET=1 + fi + fi + set -e + kill $SERVER_PID + wait $SERVER_PID +done + mkdir -p python_models/simple/1/ cp ../python_models/execute_delayed_model/model.py ./python_models/simple/1/ cp ../python_models/execute_delayed_model/config.pbtxt ./python_models/simple/ diff --git a/qa/L0_http/http_test.py b/qa/L0_http/http_test.py index 4432fe9186..638ccbbbf8 100755 --- a/qa/L0_http/http_test.py +++ b/qa/L0_http/http_test.py @@ -1,5 +1,5 @@ #!/usr/bin/python -# Copyright 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions @@ -29,6 +29,8 @@ sys.path.append("../common") +import base64 +import json import threading import time import unittest @@ -44,6 +46,9 @@ class HttpTest(tu.TestResultCollector): def _get_infer_url(self, model_name): return "http://localhost:8000/v2/models/{}/infer".format(model_name) + def _get_load_model_url(self, model_name): + return "http://localhost:8000/v2/repository/models/{}/load".format(model_name) + def _raw_binary_helper( self, model, input_bytes, expected_output_bytes, extra_headers={} ): @@ -231,6 +236,43 @@ def test_descriptive_status_code(self): ) t.join() + def test_loading_large_invalid_model(self): + # Generate large base64 encoded data + data_length = 1 << 31 + int_max = (1 << 31) - 1 + random_data = b"A" * data_length + encoded_data = base64.b64encode(random_data) + + assert ( + len(encoded_data) > int_max + ), "Encoded data length does not match the required length." + + # Prepare payload with large base64 encoded data + payload = { + "parameters": { + "config": json.dumps({"backend": "onnxruntime"}), + "file:1/model.onnx": encoded_data.decode("utf-8"), + } + } + headers = {"Content-Type": "application/json"} + + # Send POST request + response = requests.post( + self._get_load_model_url("invalid_onnx"), headers=headers, json=payload + ) + + # Assert the response is not successful + self.assertNotEqual(response.status_code, 200) + try: + error_message = response.json().get("error", "") + self.assertIn( + "'file:1/model.onnx' exceeds the maximum allowed data size limit " + "INT_MAX", + error_message, + ) + except ValueError: + self.fail("Response is not valid JSON") + if __name__ == "__main__": unittest.main() diff --git a/qa/L0_http/test.sh b/qa/L0_http/test.sh index 572c527ba4..c8edf32a41 100755 --- a/qa/L0_http/test.sh +++ b/qa/L0_http/test.sh @@ -624,7 +624,7 @@ fi TEST_RESULT_FILE='test_results.txt' PYTHON_TEST=http_test.py -EXPECTED_NUM_TESTS=9 +EXPECTED_NUM_TESTS=10 set +e python $PYTHON_TEST >$CLIENT_LOG 2>&1 if [ $? -ne 0 ]; then diff --git a/src/common.cc b/src/common.cc index 289d868866..bf697e6752 100644 --- a/src/common.cc +++ b/src/common.cc @@ -1,4 +1,4 @@ -// Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -27,11 +27,16 @@ #include "common.h" #include +#include #include #include "restricted_features.h" #include "triton/core/tritonserver.h" +extern "C" { +#include +} + namespace triton { namespace server { TRITONSERVER_Error* @@ -102,4 +107,27 @@ Contains(const std::vector& vec, const std::string& str) return std::find(vec.begin(), vec.end(), str) != vec.end(); } +TRITONSERVER_Error* +DecodeBase64( + const char* input, size_t input_len, std::vector& decoded_data, + size_t& decoded_size, const std::string& name) +{ + if (input_len > static_cast(INT_MAX)) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + ("'" + name + "' exceeds the maximum allowed data size limit INT_MAX") + .c_str()); + } + + // The decoded size cannot be larger than the input + decoded_data.resize(input_len + 1); + base64_decodestate state; + base64_init_decodestate(&state); + + decoded_size = + base64_decode_block(input, input_len, decoded_data.data(), &state); + + return nullptr; +} + }} // namespace triton::server diff --git a/src/common.h b/src/common.h index 011546d637..fb8f9fdee6 100644 --- a/src/common.h +++ b/src/common.h @@ -1,4 +1,4 @@ -// Copyright 2019-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -167,6 +167,18 @@ int64_t GetElementCount(const std::vector& dims); /// \return True if the str is found, false otherwise. bool Contains(const std::vector& vec, const std::string& str); +/// Decodes a Base64 encoded string and stores the result in a vector. +/// +/// \param input The Base64 encoded input string to decode. +/// \param input_len The length of the input string. +/// \param decoded_data A vector to store the decoded data. +/// \param decoded_size The size of the decoded data. +/// \param name The name associated with the decoding process. +/// \return The error status. +TRITONSERVER_Error* DecodeBase64( + const char* input, size_t input_len, std::vector& decoded_data, + size_t& decoded_size, const std::string& name); + /// Joins container of strings into a single string delimited by /// 'delim'. /// diff --git a/src/http_server.cc b/src/http_server.cc index 958417bc24..baa13be6e8 100644 --- a/src/http_server.cc +++ b/src/http_server.cc @@ -46,10 +46,6 @@ #define TRITONJSON_STATUSSUCCESS nullptr #include "triton/common/triton_json.h" -extern "C" { -#include -} - namespace triton { namespace server { #define RETURN_AND_CALLBACK_IF_ERR(X, CALLBACK) \ @@ -1546,14 +1542,12 @@ HTTPAPIServer::HandleRepositoryControl( param = TRITONSERVER_ParameterNew( m.c_str(), TRITONSERVER_PARAMETER_STRING, param_str); } else if (m.rfind("file:", 0) == 0) { - // Decode base64 - base64_decodestate s; - base64_init_decodestate(&s); - - // The decoded can not be larger than the input... - binary_files.emplace_back(std::vector(param_len + 1)); - size_t decoded_size = base64_decode_block( - param_str, param_len, binary_files.back().data(), &s); + size_t decoded_size; + binary_files.emplace_back(std::vector()); + RETURN_AND_RESPOND_IF_ERR( + req, DecodeBase64( + param_str, param_len, binary_files.back(), + decoded_size, m)); param = TRITONSERVER_ParameterBytesNew( m.c_str(), binary_files.back().data(), decoded_size); } @@ -2443,13 +2437,13 @@ HTTPAPIServer::HandleCudaSharedMemory( } if (err == nullptr) { - base64_decodestate s; - base64_init_decodestate(&s); + size_t decoded_size; + std::vector raw_handle; + RETURN_AND_RESPOND_IF_ERR( + req, DecodeBase64( + b64_handle, b64_handle_len, raw_handle, decoded_size, + "raw_handle")); - // The decoded can not be larger than the input... - std::vector raw_handle(b64_handle_len + 1); - size_t decoded_size = base64_decode_block( - b64_handle, b64_handle_len, raw_handle.data(), &s); if (decoded_size != sizeof(cudaIpcMemHandle_t)) { err = TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_INVALID_ARG,