Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate the memory requested for the infer request is not out of bounds #7083

Merged
merged 11 commits into from
Apr 12, 2024
57 changes: 47 additions & 10 deletions qa/L0_shared_memory/shared_memory_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@


class SharedMemoryTest(tu.TestResultCollector):
DEFAULT_SHM_BYTE_SIZE = 64

def setUp(self):
self._setup_client()

Expand Down Expand Up @@ -125,7 +127,10 @@ def test_reregister_after_register(self):
shm.destroy_shared_memory_region(shm_op0_handle)

def _configure_server(
self, create_byte_size=64, register_byte_size=64, register_offset=0
self,
create_byte_size=DEFAULT_SHM_BYTE_SIZE,
register_byte_size=DEFAULT_SHM_BYTE_SIZE,
register_offset=0,
):
"""Creates and registers shared memory regions for testing.

Expand Down Expand Up @@ -189,8 +194,10 @@ def _basic_inference(
shm_op1_handle,
error_msg,
big_shm_name="",
big_shm_size=64,
big_shm_size=DEFAULT_SHM_BYTE_SIZE,
shm_output_offset=0,
shm_output_byte_size=DEFAULT_SHM_BYTE_SIZE,
default_shm_byte_size=DEFAULT_SHM_BYTE_SIZE,
):
input0_data = np.arange(start=0, stop=16, dtype=np.int32)
input1_data = np.ones(shape=16, dtype=np.int32)
Expand All @@ -209,17 +216,21 @@ def _basic_inference(
outputs.append(grpcclient.InferRequestedOutput("OUTPUT0"))
outputs.append(grpcclient.InferRequestedOutput("OUTPUT1"))

inputs[0].set_shared_memory("input0_data", 64)
inputs[0].set_shared_memory("input0_data", default_shm_byte_size)

if type(shm_ip1_handle) == np.array:
inputs[1].set_data_from_numpy(input0_data, binary_data=True)
elif big_shm_name != "":
inputs[1].set_shared_memory(big_shm_name, big_shm_size)
else:
inputs[1].set_shared_memory("input1_data", 64)
inputs[1].set_shared_memory("input1_data", default_shm_byte_size)

outputs[0].set_shared_memory("output0_data", 64, offset=shm_output_offset)
outputs[1].set_shared_memory("output1_data", 64, offset=shm_output_offset)
outputs[0].set_shared_memory(
"output0_data", shm_output_byte_size, offset=shm_output_offset
)
outputs[1].set_shared_memory(
"output1_data", shm_output_byte_size, offset=shm_output_offset
)

try:
results = self.triton_client.infer(
Expand Down Expand Up @@ -248,7 +259,11 @@ def test_unregister_after_inference(self):
error_msg = []
shm_handles = self._configure_server()
self._basic_inference(
shm_handles[0], shm_handles[1], shm_handles[2], shm_handles[3], error_msg
shm_handles[0],
shm_handles[1],
shm_handles[2],
shm_handles[3],
error_msg,
)
if len(error_msg) > 0:
raise Exception(str(error_msg))
Expand All @@ -270,10 +285,10 @@ def test_register_after_inference(self):
if len(error_msg) > 0:
raise Exception(str(error_msg))
shm_ip2_handle = shm.create_shared_memory_region(
"input2_data", "/input2_data", 64
"input2_data", "/input2_data", self.DEFAULT_SHM_BYTE_SIZE
)
self.triton_client.register_system_shared_memory(
"input2_data", "/input2_data", 64
"input2_data", "/input2_data", self.DEFAULT_SHM_BYTE_SIZE
)
shm_status = self.triton_client.get_system_shared_memory_status()
if self.protocol == "http":
Expand Down Expand Up @@ -362,8 +377,30 @@ def test_infer_offset_out_of_bound(self):
self.assertIn("Invalid offset for shared memory region", error_msg[0])
self._cleanup_server(shm_handles)

def test_infer_byte_size_out_of_bound(self):
    """Verify the server rejects an output shm request whose offset + byte_size
    exceeds the registered region.

    The regions created by _configure_server() are DEFAULT_SHM_BYTE_SIZE (64)
    bytes. An offset of 60 is valid on its own, but combined with a byte size
    of 64 the request reaches past the end of the region, so the inference
    must fail with exactly one error.
    """
    error_msg = []
    shm_handles = self._configure_server()
    # Offset alone is in bounds; the overflow comes from offset + byte_size.
    offset = 60
    byte_size = self.DEFAULT_SHM_BYTE_SIZE

    self._basic_inference(
        shm_handles[0],
        shm_handles[1],
        shm_handles[2],
        shm_handles[3],
        error_msg,
        shm_output_offset=offset,
        shm_output_byte_size=byte_size,
    )
    # Exactly one error is expected, raised for the out-of-bounds output
    # region; its text must match the server-side validation message.
    self.assertEqual(len(error_msg), 1)
    self.assertIn(
        "Invalid offset + byte size for shared memory region", error_msg[0]
    )
    self._cleanup_server(shm_handles)

def test_register_out_of_bound(self):
create_byte_size = 64
create_byte_size = self.DEFAULT_SHM_BYTE_SIZE

# Verify various edge cases of registered region size (offset+byte_size)
# don't go out of bounds of the actual created shm file object's size.
Expand Down
19 changes: 13 additions & 6 deletions qa/L0_shared_memory/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ for i in \
test_mixed_raw_shm \
test_unregisterall \
test_infer_offset_out_of_bound \
test_infer_byte_size_out_of_bound \
test_register_out_of_bound; do
for client_type in http grpc; do
SERVER_ARGS="--model-repository=`pwd`/models --log-verbose=1 ${SERVER_ARGS_EXTRA}"
Expand All @@ -63,32 +64,38 @@ for i in \
fi

export CLIENT_TYPE=$client_type
echo "Test: $i, client type: $client_type" >>$CLIENT_LOG
TMP_CLIENT_LOG="./tmp_client.log"
echo "Test: $i, client type: $client_type" >>$TMP_CLIENT_LOG

set +e
python3 $SHM_TEST SharedMemoryTest.$i >>$CLIENT_LOG 2>&1
python3 $SHM_TEST SharedMemoryTest.$i >>$TMP_CLIENT_LOG 2>&1
if [ $? -ne 0 ]; then
cat $TMP_CLIENT_LOG
echo -e "\n***\n*** Test Failed\n***"
RET=1
else
check_test_results $TEST_RESULT_FILE 1
if [ $? -ne 0 ]; then
cat $CLIENT_LOG
cat $TEST_RESULT_FILE
echo -e "\n***\n*** Test Result Verification Failed\n***"
RET=1
fi
fi
set -e

cat $TMP_CLIENT_LOG >>$CLIENT_LOG
rm $TMP_CLIENT_LOG
kill $SERVER_PID
wait $SERVER_PID
if [ $? -ne 0 ]; then
echo -e "\n***\n*** Test Server shut down non-gracefully\n***"
RET=1
fi
set -e
# (removed: pasted GitHub review-thread artifact — not part of the script)
done
done

if [ $RET -eq 0 ]; then
echo -e "\n***\n*** Test Passed\n***"
else
cat $CLIENT_LOG
echo -e "\n***\n*** Test Failed\n***"
fi

Expand Down
2 changes: 1 addition & 1 deletion src/grpc/infer_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ InferGRPCToInput(
}
void* tmp;
RETURN_IF_ERR(shm_manager->GetMemoryInfo(
region_name, offset, &tmp, &memory_type, &memory_type_id));
region_name, offset, byte_size, &tmp, &memory_type, &memory_type_id));
base = tmp;
if (memory_type == TRITONSERVER_MEMORY_GPU) {
#ifdef TRITON_ENABLE_GPU
Expand Down
3 changes: 2 additions & 1 deletion src/grpc/infer_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,8 @@ InferAllocatorPayload(
TRITONSERVER_MemoryType memory_type;
int64_t memory_type_id;
RETURN_IF_ERR(shm_manager->GetMemoryInfo(
region_name, offset, &base, &memory_type, &memory_type_id));
region_name, offset, byte_size, &base, &memory_type,
&memory_type_id));

if (memory_type == TRITONSERVER_MEMORY_GPU) {
#ifdef TRITON_ENABLE_GPU
Expand Down
6 changes: 4 additions & 2 deletions src/http_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2682,7 +2682,8 @@ HTTPAPIServer::ParseJsonTritonIO(
TRITONSERVER_MemoryType memory_type;
int64_t memory_type_id;
RETURN_IF_ERR(shm_manager_->GetMemoryInfo(
shm_region, shm_offset, &base, &memory_type, &memory_type_id));
shm_region, shm_offset, byte_size, &base, &memory_type,
&memory_type_id));
if (memory_type == TRITONSERVER_MEMORY_GPU) {
#ifdef TRITON_ENABLE_GPU
cudaIpcMemHandle_t* cuda_handle;
Expand Down Expand Up @@ -2796,7 +2797,8 @@ HTTPAPIServer::ParseJsonTritonIO(
TRITONSERVER_MemoryType memory_type;
int64_t memory_type_id;
RETURN_IF_ERR(shm_manager_->GetMemoryInfo(
shm_region, offset, &base, &memory_type, &memory_type_id));
shm_region, offset, byte_size, &base, &memory_type,
&memory_type_id));

if (memory_type == TRITONSERVER_MEMORY_GPU) {
#ifdef TRITON_ENABLE_GPU
Expand Down
24 changes: 17 additions & 7 deletions src/shared_memory_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -384,8 +384,9 @@ SharedMemoryManager::RegisterCUDASharedMemory(

TRITONSERVER_Error*
SharedMemoryManager::GetMemoryInfo(
const std::string& name, size_t offset, void** shm_mapped_addr,
TRITONSERVER_MemoryType* memory_type, int64_t* device_id)
const std::string& name, size_t offset, size_t byte_size,
void** shm_mapped_addr, TRITONSERVER_MemoryType* memory_type,
int64_t* device_id)
{
// protect shared_memory_map_ from concurrent access
std::lock_guard<std::mutex> lock(mu_);
Expand All @@ -399,20 +400,29 @@ SharedMemoryManager::GetMemoryInfo(
}

// validate offset
size_t max_offset = 0;
size_t shm_region_end = 0;
if (it->second->kind_ == TRITONSERVER_MEMORY_CPU) {
max_offset = it->second->offset_;
shm_region_end = it->second->offset_;
}
if (it->second->byte_size_ > 0) {
max_offset += it->second->byte_size_ - 1;
shm_region_end += it->second->byte_size_ - 1;
}
if (offset > max_offset) {
if (offset > shm_region_end) {
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INVALID_ARG,
std::string("Invalid offset for shared memory region: '" + name + "'")
.c_str());
}
// TODO: should also validate byte_size from caller
// validate byte_size + offset is within memory bounds
size_t total_req_shm = offset + byte_size - 1;
if (total_req_shm > shm_region_end) {
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INVALID_ARG,
std::string(
"Invalid offset + byte size for shared memory region: '" + name +
"'")
.c_str());
}

if (it->second->kind_ == TRITONSERVER_MEMORY_CPU) {
*shm_mapped_addr = (void*)((uint8_t*)it->second->mapped_addr_ +
Expand Down
6 changes: 4 additions & 2 deletions src/shared_memory_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,17 @@ class SharedMemoryManager {
/// if named block doesn't exist.
/// \param name The name of the shared memory block to get.
/// \param offset The offset in the block
/// \param byte_size The byte size to request for the shm region
/// \param shm_mapped_addr Returns the pointer to the shared
/// memory block with the specified name and offset
/// \param memory_type Returns the type of the memory
/// \param device_id Returns the device id associated with the
/// memory block
/// \return a TRITONSERVER_Error indicating success or failure.
TRITONSERVER_Error* GetMemoryInfo(
const std::string& name, size_t offset, void** shm_mapped_addr,
TRITONSERVER_MemoryType* memory_type, int64_t* device_id);
const std::string& name, size_t offset, size_t byte_size,
void** shm_mapped_addr, TRITONSERVER_MemoryType* memory_type,
int64_t* device_id);

#ifdef TRITON_ENABLE_GPU
/// Get the CUDA memory handle associated with the block name.
Expand Down
Loading