Skip to content

Commit

Permalink
Validate that the memory requested for the infer request is not out of bounds (#7083) (#7111)
Browse files Browse the repository at this point in the history

Validate that the requested memory offset and byte size are within the bounds
of the registered memory region.
Previously, #6914 added a check for out-of-bounds offsets in shared-memory
requests. This PR extends that check to the full requested span (offset +
byte size) and adds more testing to verify that the requested block of memory
is in fact in bounds.
Client change: triton-inference-server/client#565
  • Loading branch information
jbkyang-nvi authored Apr 12, 2024
1 parent 0d61b40 commit 5ff6935
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 29 deletions.
57 changes: 47 additions & 10 deletions qa/L0_shared_memory/shared_memory_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@


class SharedMemoryTest(tu.TestResultCollector):
DEFAULT_SHM_BYTE_SIZE = 64

def setUp(self):
self._setup_client()

Expand Down Expand Up @@ -125,7 +127,10 @@ def test_reregister_after_register(self):
shm.destroy_shared_memory_region(shm_op0_handle)

def _configure_server(
self, create_byte_size=64, register_byte_size=64, register_offset=0
self,
create_byte_size=DEFAULT_SHM_BYTE_SIZE,
register_byte_size=DEFAULT_SHM_BYTE_SIZE,
register_offset=0,
):
"""Creates and registers shared memory regions for testing.
Expand Down Expand Up @@ -189,8 +194,10 @@ def _basic_inference(
shm_op1_handle,
error_msg,
big_shm_name="",
big_shm_size=64,
big_shm_size=DEFAULT_SHM_BYTE_SIZE,
shm_output_offset=0,
shm_output_byte_size=DEFAULT_SHM_BYTE_SIZE,
default_shm_byte_size=DEFAULT_SHM_BYTE_SIZE,
):
input0_data = np.arange(start=0, stop=16, dtype=np.int32)
input1_data = np.ones(shape=16, dtype=np.int32)
Expand All @@ -209,17 +216,21 @@ def _basic_inference(
outputs.append(grpcclient.InferRequestedOutput("OUTPUT0"))
outputs.append(grpcclient.InferRequestedOutput("OUTPUT1"))

inputs[0].set_shared_memory("input0_data", 64)
inputs[0].set_shared_memory("input0_data", default_shm_byte_size)

if type(shm_ip1_handle) == np.array:
inputs[1].set_data_from_numpy(input0_data, binary_data=True)
elif big_shm_name != "":
inputs[1].set_shared_memory(big_shm_name, big_shm_size)
else:
inputs[1].set_shared_memory("input1_data", 64)
inputs[1].set_shared_memory("input1_data", default_shm_byte_size)

outputs[0].set_shared_memory("output0_data", 64, offset=shm_output_offset)
outputs[1].set_shared_memory("output1_data", 64, offset=shm_output_offset)
outputs[0].set_shared_memory(
"output0_data", shm_output_byte_size, offset=shm_output_offset
)
outputs[1].set_shared_memory(
"output1_data", shm_output_byte_size, offset=shm_output_offset
)

try:
results = self.triton_client.infer(
Expand Down Expand Up @@ -248,7 +259,11 @@ def test_unregister_after_inference(self):
error_msg = []
shm_handles = self._configure_server()
self._basic_inference(
shm_handles[0], shm_handles[1], shm_handles[2], shm_handles[3], error_msg
shm_handles[0],
shm_handles[1],
shm_handles[2],
shm_handles[3],
error_msg,
)
if len(error_msg) > 0:
raise Exception(str(error_msg))
Expand All @@ -270,10 +285,10 @@ def test_register_after_inference(self):
if len(error_msg) > 0:
raise Exception(str(error_msg))
shm_ip2_handle = shm.create_shared_memory_region(
"input2_data", "/input2_data", 64
"input2_data", "/input2_data", self.DEFAULT_SHM_BYTE_SIZE
)
self.triton_client.register_system_shared_memory(
"input2_data", "/input2_data", 64
"input2_data", "/input2_data", self.DEFAULT_SHM_BYTE_SIZE
)
shm_status = self.triton_client.get_system_shared_memory_status()
if self.protocol == "http":
Expand Down Expand Up @@ -362,8 +377,30 @@ def test_infer_offset_out_of_bound(self):
self.assertIn("Invalid offset for shared memory region", error_msg[0])
self._cleanup_server(shm_handles)

def test_infer_byte_size_out_of_bound(self):
    """Requesting the full default region size from a nonzero offset pushes
    the span past the end of the registered output region, so the server
    must reject the inference with an out-of-bounds error."""
    errors = []
    handles = self._configure_server()
    out_offset = 60
    out_byte_size = self.DEFAULT_SHM_BYTE_SIZE

    self._basic_inference(
        handles[0],
        handles[1],
        handles[2],
        handles[3],
        errors,
        shm_output_offset=out_offset,
        shm_output_byte_size=out_byte_size,
    )
    # Exactly one error is expected, naming the offending region.
    self.assertEqual(len(errors), 1)
    self.assertIn(
        "Invalid offset + byte size for shared memory region", errors[0]
    )
    self._cleanup_server(handles)

def test_register_out_of_bound(self):
create_byte_size = 64
create_byte_size = self.DEFAULT_SHM_BYTE_SIZE

# Verify various edge cases of registered region size (offset+byte_size)
# don't go out of bounds of the actual created shm file object's size.
Expand Down
19 changes: 13 additions & 6 deletions qa/L0_shared_memory/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ for i in \
test_mixed_raw_shm \
test_unregisterall \
test_infer_offset_out_of_bound \
test_infer_byte_size_out_of_bound \
test_register_out_of_bound; do
for client_type in http grpc; do
SERVER_ARGS="--model-repository=`pwd`/models --log-verbose=1 ${SERVER_ARGS_EXTRA}"
Expand All @@ -63,32 +64,38 @@ for i in \
fi

export CLIENT_TYPE=$client_type
echo "Test: $i, client type: $client_type" >>$CLIENT_LOG
TMP_CLIENT_LOG="./tmp_client.log"
echo "Test: $i, client type: $client_type" >>$TMP_CLIENT_LOG

set +e
python3 $SHM_TEST SharedMemoryTest.$i >>$CLIENT_LOG 2>&1
python3 $SHM_TEST SharedMemoryTest.$i >>$TMP_CLIENT_LOG 2>&1
if [ $? -ne 0 ]; then
cat $TMP_CLIENT_LOG
echo -e "\n***\n*** Test Failed\n***"
RET=1
else
check_test_results $TEST_RESULT_FILE 1
if [ $? -ne 0 ]; then
cat $CLIENT_LOG
cat $TEST_RESULT_FILE
echo -e "\n***\n*** Test Result Verification Failed\n***"
RET=1
fi
fi
set -e

cat $TMP_CLIENT_LOG >>$CLIENT_LOG
rm $TMP_CLIENT_LOG
kill $SERVER_PID
wait $SERVER_PID
if [ $? -ne 0 ]; then
echo -e "\n***\n*** Test Server shut down non-gracefully\n***"
RET=1
fi
set -e
done
done

if [ $RET -eq 0 ]; then
echo -e "\n***\n*** Test Passed\n***"
else
cat $CLIENT_LOG
echo -e "\n***\n*** Test Failed\n***"
fi

Expand Down
2 changes: 1 addition & 1 deletion src/grpc/infer_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ InferGRPCToInput(
}
void* tmp;
RETURN_IF_ERR(shm_manager->GetMemoryInfo(
region_name, offset, &tmp, &memory_type, &memory_type_id));
region_name, offset, byte_size, &tmp, &memory_type, &memory_type_id));
base = tmp;
if (memory_type == TRITONSERVER_MEMORY_GPU) {
#ifdef TRITON_ENABLE_GPU
Expand Down
3 changes: 2 additions & 1 deletion src/grpc/infer_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,8 @@ InferAllocatorPayload(
TRITONSERVER_MemoryType memory_type;
int64_t memory_type_id;
RETURN_IF_ERR(shm_manager->GetMemoryInfo(
region_name, offset, &base, &memory_type, &memory_type_id));
region_name, offset, byte_size, &base, &memory_type,
&memory_type_id));

if (memory_type == TRITONSERVER_MEMORY_GPU) {
#ifdef TRITON_ENABLE_GPU
Expand Down
6 changes: 4 additions & 2 deletions src/http_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2682,7 +2682,8 @@ HTTPAPIServer::ParseJsonTritonIO(
TRITONSERVER_MemoryType memory_type;
int64_t memory_type_id;
RETURN_IF_ERR(shm_manager_->GetMemoryInfo(
shm_region, shm_offset, &base, &memory_type, &memory_type_id));
shm_region, shm_offset, byte_size, &base, &memory_type,
&memory_type_id));
if (memory_type == TRITONSERVER_MEMORY_GPU) {
#ifdef TRITON_ENABLE_GPU
cudaIpcMemHandle_t* cuda_handle;
Expand Down Expand Up @@ -2796,7 +2797,8 @@ HTTPAPIServer::ParseJsonTritonIO(
TRITONSERVER_MemoryType memory_type;
int64_t memory_type_id;
RETURN_IF_ERR(shm_manager_->GetMemoryInfo(
shm_region, offset, &base, &memory_type, &memory_type_id));
shm_region, offset, byte_size, &base, &memory_type,
&memory_type_id));

if (memory_type == TRITONSERVER_MEMORY_GPU) {
#ifdef TRITON_ENABLE_GPU
Expand Down
24 changes: 17 additions & 7 deletions src/shared_memory_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -384,8 +384,9 @@ SharedMemoryManager::RegisterCUDASharedMemory(

TRITONSERVER_Error*
SharedMemoryManager::GetMemoryInfo(
const std::string& name, size_t offset, void** shm_mapped_addr,
TRITONSERVER_MemoryType* memory_type, int64_t* device_id)
const std::string& name, size_t offset, size_t byte_size,
void** shm_mapped_addr, TRITONSERVER_MemoryType* memory_type,
int64_t* device_id)
{
// protect shared_memory_map_ from concurrent access
std::lock_guard<std::mutex> lock(mu_);
Expand All @@ -399,20 +400,29 @@ SharedMemoryManager::GetMemoryInfo(
}

// validate offset
size_t max_offset = 0;
size_t shm_region_end = 0;
if (it->second->kind_ == TRITONSERVER_MEMORY_CPU) {
max_offset = it->second->offset_;
shm_region_end = it->second->offset_;
}
if (it->second->byte_size_ > 0) {
max_offset += it->second->byte_size_ - 1;
shm_region_end += it->second->byte_size_ - 1;
}
if (offset > max_offset) {
if (offset > shm_region_end) {
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INVALID_ARG,
std::string("Invalid offset for shared memory region: '" + name + "'")
.c_str());
}
// TODO: should also validate byte_size from caller
// validate byte_size + offset is within memory bounds
size_t total_req_shm = offset + byte_size - 1;
if (total_req_shm > shm_region_end) {
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INVALID_ARG,
std::string(
"Invalid offset + byte size for shared memory region: '" + name +
"'")
.c_str());
}

if (it->second->kind_ == TRITONSERVER_MEMORY_CPU) {
*shm_mapped_addr = (void*)((uint8_t*)it->second->mapped_addr_ +
Expand Down
6 changes: 4 additions & 2 deletions src/shared_memory_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,17 @@ class SharedMemoryManager {
/// if named block doesn't exist.
/// \param name The name of the shared memory block to get.
/// \param offset The offset in the block
/// \param byte_size The byte size to request for the shm region
/// \param shm_mapped_addr Returns the pointer to the shared
/// memory block with the specified name and offset
/// \param memory_type Returns the type of the memory
/// \param device_id Returns the device id associated with the
/// memory block
/// \return a TRITONSERVER_Error indicating success or failure.
TRITONSERVER_Error* GetMemoryInfo(
const std::string& name, size_t offset, void** shm_mapped_addr,
TRITONSERVER_MemoryType* memory_type, int64_t* device_id);
const std::string& name, size_t offset, size_t byte_size,
void** shm_mapped_addr, TRITONSERVER_MemoryType* memory_type,
int64_t* device_id);

#ifdef TRITON_ENABLE_GPU
/// Get the CUDA memory handle associated with the block name.
Expand Down

0 comments on commit 5ff6935

Please sign in to comment.