Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Resolve integer overflow in Load API file decoding #7787

Merged
merged 7 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 105 additions & 0 deletions qa/L0_cuda_shared_memory/cuda_shared_memory_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,15 @@

sys.path.append("../common")

import base64
import os
import time
import unittest
from functools import partial

import infer_util as iu
import numpy as np
import requests
import test_util as tu
import tritonclient.grpc as grpcclient
import tritonclient.http as httpclient
Expand Down Expand Up @@ -564,5 +566,108 @@ def callback(user_data, result, error):
self._cleanup_server(shm_handles)


class CudaSharedMemoryTestRawHttpRequest(unittest.TestCase):
def setUp(self):
self.url = "localhost:8000"
self.client = httpclient.InferenceServerClient(url=self.url, verbose=True)
self.valid_shm_handle = None

def tearDown(self):
self.client.unregister_cuda_shared_memory()
if self.valid_shm_handle:
cshm.destroy_shared_memory_region(self.valid_shm_handle)
self.client.close()

def _generate_mock_base64_raw_handle(self, data_length):
original_data_length = data_length * 3 // 4
random_data = b"A" * original_data_length
encoded_data = base64.b64encode(random_data)

assert (
len(encoded_data) == data_length
), "Encoded data length does not match the required length."
return encoded_data

def _send_register_cshm_request(self, raw_handle, device_id, byte_size, shm_name):
cuda_shared_memory_register_request = {
"raw_handle": {"b64": raw_handle.decode("utf-8")},
"device_id": device_id,
"byte_size": byte_size,
}

url = "http://{}/v2/cudasharedmemory/region/{}/register".format(
self.url, shm_name
)
headers = {"Content-Type": "application/json"}

# Send POST request
response = requests.post(
url, headers=headers, json=cuda_shared_memory_register_request
)
return response

def test_exceeds_cshm_handle_size_limit(self):
# byte_size greater than INT_MAX
byte_size = 1 << 31
device_id = 0
shm_name = "invalid_shm"

raw_handle = self._generate_mock_base64_raw_handle(byte_size)
response = self._send_register_cshm_request(
raw_handle, device_id, byte_size, shm_name
)
self.assertNotEqual(response.status_code, 200)

try:
error_message = response.json().get("error", "")
self.assertIn(
"The length of 'raw_handle' exceeds the maximum allowed limit INT_MAX",
error_message,
)
except ValueError:
self.fail("Response is not valid JSON")

def test_invalid_small_cshm_handle(self):
byte_size = 64
device_id = 0
shm_name = "invalid_shm"

raw_handle = self._generate_mock_base64_raw_handle(byte_size)
response = self._send_register_cshm_request(
raw_handle, device_id, byte_size, shm_name
)
self.assertNotEqual(response.status_code, 200)

try:
error_message = response.json().get("error", "")
self.assertIn(
"'raw_handle' must be a valid base64 encoded cudaIpcMemHandle_t",
error_message,
)
except ValueError:
self.fail("Response is not valid JSON")

def test_valid_cshm_handle(self):
byte_size = 64
device_id = 0
shm_name = "test_shm"

# Create valid shared memory
self.valid_shm_handle = cshm.create_shared_memory_region(
shm_name, byte_size, device_id
)
raw_handle = cshm.get_raw_handle(self.valid_shm_handle)

response = self._send_register_cshm_request(
raw_handle, device_id, byte_size, shm_name
)
self.assertEqual(response.status_code, 200)

# Verify shared memory status
status = self.client.get_cuda_shared_memory_status()
self.assertEqual(len(status), 1)
self.assertEqual(status[0]["name"], shm_name)


pskiran1 marked this conversation as resolved.
Show resolved Hide resolved
if __name__ == "__main__":
unittest.main()
32 changes: 32 additions & 0 deletions qa/L0_cuda_shared_memory/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,38 @@ for i in \
done
done

for i in \
test_exceeds_cshm_handle_size_limit \
test_invalid_small_cshm_handle \
test_valid_cshm_handle; do
SERVER_ARGS="--model-repository=`pwd`/models --log-verbose=1"
SERVER_LOG="./$i.server.log"
CLIENT_LOG="./$i.client.log"
run_server
if [ "$SERVER_PID" == "0" ]; then
echo -e "\n***\n*** Failed to start $SERVER\n***"
cat $SERVER_LOG
exit 1
fi
echo "Test: $i, client type: HTTP" >>$CLIENT_LOG
set +e
python $SHM_TEST CudaSharedMemoryTestRawHttpRequest.$i >>$CLIENT_LOG 2>&1
if [ $? -ne 0 ]; then
echo -e "\n***\n*** Test Failed\n***"
RET=1
else
check_test_results $TEST_RESULT_FILE 1
if [ $? -ne 0 ]; then
cat $CLIENT_LOG
echo -e "\n***\n*** Test Result Verification Failed\n***"
RET=1
fi
fi
set -e
kill $SERVER_PID
wait $SERVER_PID
done

mkdir -p python_models/simple/1/
cp ../python_models/execute_delayed_model/model.py ./python_models/simple/1/
cp ../python_models/execute_delayed_model/config.pbtxt ./python_models/simple/
Expand Down
44 changes: 43 additions & 1 deletion qa/L0_http/http_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/python
# Copyright 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -29,6 +29,8 @@

sys.path.append("../common")

import base64
import json
import threading
import time
import unittest
Expand All @@ -44,6 +46,9 @@ class HttpTest(tu.TestResultCollector):
def _get_infer_url(self, model_name):
return "http://localhost:8000/v2/models/{}/infer".format(model_name)

def _get_load_model_url(self, model_name):
return "http://localhost:8000/v2/repository/models/{}/load".format(model_name)

def _raw_binary_helper(
self, model, input_bytes, expected_output_bytes, extra_headers={}
):
Expand Down Expand Up @@ -231,6 +236,43 @@ def test_descriptive_status_code(self):
)
t.join()

def test_loading_large_invalid_model(self):
# Generate large base64 encoded data
data_length = 1 << 31
int_max = (1 << 31) - 1
random_data = b"A" * data_length
encoded_data = base64.b64encode(random_data)

assert (
len(encoded_data) > int_max
), "Encoded data length does not match the required length."

# Prepare payload with large base64 encoded data
payload = {
"parameters": {
"config": json.dumps({"backend": "onnxruntime"}),
"file:1/model.onnx": encoded_data.decode("utf-8"),
}
}
headers = {"Content-Type": "application/json"}

# Send POST request
response = requests.post(
self._get_load_model_url("invalid_onnx"), headers=headers, json=payload
)

# Assert the response is not successful
self.assertNotEqual(response.status_code, 200)
try:
error_message = response.json().get("error", "")
self.assertIn(
"'file:1/model.onnx' exceeds the maximum allowed data size limit "
"INT_MAX",
error_message,
)
except ValueError:
self.fail("Response is not valid JSON")


if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion qa/L0_http/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,7 @@ fi

TEST_RESULT_FILE='test_results.txt'
PYTHON_TEST=http_test.py
EXPECTED_NUM_TESTS=9
EXPECTED_NUM_TESTS=10
set +e
python $PYTHON_TEST >$CLIENT_LOG 2>&1
if [ $? -ne 0 ]; then
Expand Down
18 changes: 18 additions & 0 deletions src/http_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1546,6 +1546,16 @@ HTTPAPIServer::HandleRepositoryControl(
param = TRITONSERVER_ParameterNew(
m.c_str(), TRITONSERVER_PARAMETER_STRING, param_str);
} else if (m.rfind("file:", 0) == 0) {
if (param_len > INT_MAX) {
RETURN_AND_RESPOND_IF_ERR(
req, TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INVALID_ARG,
("'" + m +
"' exceeds the maximum allowed data size limit "
"INT_MAX")
.c_str()));
}

// Decode base64
base64_decodestate s;
base64_init_decodestate(&s);
Expand Down Expand Up @@ -2443,6 +2453,14 @@ HTTPAPIServer::HandleCudaSharedMemory(
}

if (err == nullptr) {
if (b64_handle_len > INT_MAX) {
pskiran1 marked this conversation as resolved.
Show resolved Hide resolved
RETURN_AND_RESPOND_IF_ERR(
req, TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INVALID_ARG,
"The length of 'raw_handle' exceeds the maximum "
"allowed limit INT_MAX"));
}

base64_decodestate s;
base64_init_decodestate(&s);

Expand Down
Loading