Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate the memory requested for the infer request is not out of bounds #7083

Merged
merged 11 commits into from
Apr 12, 2024
80 changes: 62 additions & 18 deletions qa/L0_shared_memory/shared_memory_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@


class SharedMemoryTest(tu.TestResultCollector):
DEFAULT_SHM_BYTE_SIZE = 64

def test_invalid_create_shm(self):
# Raises error since tried to create invalid system shared memory region
try:
Expand Down Expand Up @@ -122,18 +124,18 @@ def test_reregister_after_register(self):
self.assertTrue(len(shm_status.regions) == 1)
shm.destroy_shared_memory_region(shm_op0_handle)

def _configure_sever(self):
def _configure_sever(self, shm_byte_size=DEFAULT_SHM_BYTE_SIZE):
shm_ip0_handle = shm.create_shared_memory_region(
"input0_data", "/input0_data", 64
"input0_data", "/input0_data", shm_byte_size
)
shm_ip1_handle = shm.create_shared_memory_region(
"input1_data", "/input1_data", 64
"input1_data", "/input1_data", shm_byte_size
)
shm_op0_handle = shm.create_shared_memory_region(
"output0_data", "/output0_data", 64
"output0_data", "/output0_data", shm_byte_size
)
shm_op1_handle = shm.create_shared_memory_region(
"output1_data", "/output1_data", 64
"output1_data", "/output1_data", shm_byte_size
)
input0_data = np.arange(start=0, stop=16, dtype=np.int32)
input1_data = np.ones(shape=16, dtype=np.int32)
Expand All @@ -143,10 +145,18 @@ def _configure_sever(self):
triton_client = httpclient.InferenceServerClient(_url, verbose=True)
else:
triton_client = grpcclient.InferenceServerClient(_url, verbose=True)
triton_client.register_system_shared_memory("input0_data", "/input0_data", 64)
triton_client.register_system_shared_memory("input1_data", "/input1_data", 64)
triton_client.register_system_shared_memory("output0_data", "/output0_data", 64)
triton_client.register_system_shared_memory("output1_data", "/output1_data", 64)
triton_client.register_system_shared_memory(
"input0_data", "/input0_data", shm_byte_size
)
triton_client.register_system_shared_memory(
"input1_data", "/input1_data", shm_byte_size
)
triton_client.register_system_shared_memory(
"output0_data", "/output0_data", shm_byte_size
)
triton_client.register_system_shared_memory(
"output1_data", "/output1_data", shm_byte_size
)
return [shm_ip0_handle, shm_ip1_handle, shm_op0_handle, shm_op1_handle]

def _cleanup_server(self, shm_handles):
Expand All @@ -161,8 +171,10 @@ def _basic_inference(
shm_op1_handle,
error_msg,
big_shm_name="",
big_shm_size=64,
big_shm_size=DEFAULT_SHM_BYTE_SIZE,
shm_output_offset=0,
shm_output_byte_size=DEFAULT_SHM_BYTE_SIZE,
default_shm_byte_size=DEFAULT_SHM_BYTE_SIZE,
):
input0_data = np.arange(start=0, stop=16, dtype=np.int32)
input1_data = np.ones(shape=16, dtype=np.int32)
Expand All @@ -183,17 +195,21 @@ def _basic_inference(
outputs.append(grpcclient.InferRequestedOutput("OUTPUT0"))
outputs.append(grpcclient.InferRequestedOutput("OUTPUT1"))

inputs[0].set_shared_memory("input0_data", 64)
inputs[0].set_shared_memory("input0_data", default_shm_byte_size)

if type(shm_ip1_handle) == np.array:
inputs[1].set_data_from_numpy(input0_data, binary_data=True)
elif big_shm_name != "":
inputs[1].set_shared_memory(big_shm_name, big_shm_size)
else:
inputs[1].set_shared_memory("input1_data", 64)
inputs[1].set_shared_memory("input1_data", default_shm_byte_size)

outputs[0].set_shared_memory("output0_data", 64, offset=shm_output_offset)
outputs[1].set_shared_memory("output1_data", 64, offset=shm_output_offset)
outputs[0].set_shared_memory(
"output0_data", shm_output_byte_size, offset=shm_output_offset
)
outputs[1].set_shared_memory(
"output1_data", shm_output_byte_size, offset=shm_output_offset
)

try:
results = triton_client.infer(
Expand Down Expand Up @@ -222,7 +238,11 @@ def test_unregister_after_inference(self):
error_msg = []
shm_handles = self._configure_sever()
self._basic_inference(
shm_handles[0], shm_handles[1], shm_handles[2], shm_handles[3], error_msg
shm_handles[0],
shm_handles[1],
shm_handles[2],
shm_handles[3],
error_msg,
)
if len(error_msg) > 0:
raise Exception(str(error_msg))
Expand Down Expand Up @@ -252,9 +272,11 @@ def test_register_after_inference(self):
if len(error_msg) > 0:
raise Exception(str(error_msg))
shm_ip2_handle = shm.create_shared_memory_region(
"input2_data", "/input2_data", 64
"input2_data", "/input2_data", self.DEFAULT_SHM_BYTE_SIZE
)
triton_client.register_system_shared_memory(
"input2_data", "/input2_data", self.DEFAULT_SHM_BYTE_SIZE
)
triton_client.register_system_shared_memory("input2_data", "/input2_data", 64)
shm_status = triton_client.get_system_shared_memory_status()
if _protocol == "http":
self.assertTrue(len(shm_status) == 5)
Expand Down Expand Up @@ -295,7 +317,7 @@ def test_too_big_shm(self):
def test_mixed_raw_shm(self):
# Mix of shared memory and RAW inputs
error_msg = []
shm_handles = self._configure_sever()
shm_handles = self._configure_sever(shm_byte_size=256)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My PR will have some conflicts with this PR, but you should be able to use the create_byte_size and register_byte_size args in the same way. Let me know if you're good with it: #7093

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Francesco's PR will also likely have some merge conflicts: #7048

input1_data = np.ones(shape=16, dtype=np.int32)
self._basic_inference(
shm_handles[0], [input1_data], shm_handles[2], shm_handles[3], error_msg
Expand Down Expand Up @@ -348,6 +370,28 @@ def test_infer_offset_out_of_bound(self):
self.assertIn("Invalid offset for shared memory region", error_msg[0])
self._cleanup_server(shm_handles)

def test_infer_byte_size_out_of_bound(self):
    # An output whose offset + byte_size reaches past the end of the
    # registered shared memory region must be rejected by the server.
    out_offset = 60
    out_byte_size = 64

    errors = []
    handles = self._configure_sever()
    self._basic_inference(
        handles[0],
        handles[1],
        handles[2],
        handles[3],
        errors,
        shm_output_offset=out_offset,
        shm_output_byte_size=out_byte_size,
    )
    # Exactly one failure is expected, carrying the server's bounds message.
    self.assertEqual(len(errors), 1)
    self.assertIn(
        "Invalid offset + byte size for shared memory region", errors[0]
    )
    self._cleanup_server(handles)


if __name__ == "__main__":
_protocol = os.environ.get("CLIENT_TYPE", "http")
Expand Down
17 changes: 10 additions & 7 deletions qa/L0_shared_memory/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ for i in \
test_too_big_shm \
test_mixed_raw_shm \
test_unregisterall \
test_infer_offset_out_of_bound; do
test_infer_offset_out_of_bound \
test_infer_byte_size_out_of_bound; do
for client_type in http grpc; do
SERVER_ARGS="--model-repository=`pwd`/models --log-verbose=1 ${SERVER_ARGS_EXTRA}"
SERVER_LOG="./$i.$client_type.server.log"
Expand All @@ -62,32 +63,34 @@ for i in \
fi

export CLIENT_TYPE=$client_type
echo "Test: $i, client type: $client_type" >>$CLIENT_LOG
TMP_CLIENT_LOG="./tmp_client.log"
echo "Test: $i, client type: $client_type" >>$TMP_CLIENT_LOG

set +e
python $SHM_TEST SharedMemoryTest.$i >>$CLIENT_LOG 2>&1
python $SHM_TEST SharedMemoryTest.$i >>$TMP_CLIENT_LOG 2>&1
Tabrizian marked this conversation as resolved.
Show resolved Hide resolved
if [ $? -ne 0 ]; then
cat $TMP_CLIENT_LOG
echo -e "\n***\n*** Test Failed\n***"
RET=1
else
check_test_results $TEST_RESULT_FILE 1
if [ $? -ne 0 ]; then
cat $CLIENT_LOG
cat $TEST_RESULT_FILE
echo -e "\n***\n*** Test Result Verification Failed\n***"
RET=1
fi
fi
set -e

cat $TMP_CLIENT_LOG >>$CLIENT_LOG
rm $TMP_CLIENT_LOG
kill $SERVER_PID
wait $SERVER_PID
set -e
Tabrizian marked this conversation as resolved.
Show resolved Hide resolved
done
done

if [ $RET -eq 0 ]; then
echo -e "\n***\n*** Test Passed\n***"
else
cat $CLIENT_LOG
echo -e "\n***\n*** Test Failed\n***"
fi

Expand Down
2 changes: 1 addition & 1 deletion src/grpc/infer_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ InferGRPCToInput(
}
void* tmp;
RETURN_IF_ERR(shm_manager->GetMemoryInfo(
region_name, offset, &tmp, &memory_type, &memory_type_id));
region_name, offset, byte_size, &tmp, &memory_type, &memory_type_id));
base = tmp;
if (memory_type == TRITONSERVER_MEMORY_GPU) {
#ifdef TRITON_ENABLE_GPU
Expand Down
3 changes: 2 additions & 1 deletion src/grpc/infer_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,8 @@ InferAllocatorPayload(
TRITONSERVER_MemoryType memory_type;
int64_t memory_type_id;
RETURN_IF_ERR(shm_manager->GetMemoryInfo(
region_name, offset, &base, &memory_type, &memory_type_id));
region_name, offset, byte_size, &base, &memory_type,
&memory_type_id));

if (memory_type == TRITONSERVER_MEMORY_GPU) {
#ifdef TRITON_ENABLE_GPU
Expand Down
6 changes: 4 additions & 2 deletions src/http_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2669,7 +2669,8 @@ HTTPAPIServer::ParseJsonTritonIO(
TRITONSERVER_MemoryType memory_type;
int64_t memory_type_id;
RETURN_IF_ERR(shm_manager_->GetMemoryInfo(
shm_region, shm_offset, &base, &memory_type, &memory_type_id));
shm_region, shm_offset, byte_size, &base, &memory_type,
&memory_type_id));
if (memory_type == TRITONSERVER_MEMORY_GPU) {
#ifdef TRITON_ENABLE_GPU
cudaIpcMemHandle_t* cuda_handle;
Expand Down Expand Up @@ -2783,7 +2784,8 @@ HTTPAPIServer::ParseJsonTritonIO(
TRITONSERVER_MemoryType memory_type;
int64_t memory_type_id;
RETURN_IF_ERR(shm_manager_->GetMemoryInfo(
shm_region, offset, &base, &memory_type, &memory_type_id));
shm_region, offset, byte_size, &base, &memory_type,
&memory_type_id));

if (memory_type == TRITONSERVER_MEMORY_GPU) {
#ifdef TRITON_ENABLE_GPU
Expand Down
24 changes: 17 additions & 7 deletions src/shared_memory_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -333,8 +333,9 @@ SharedMemoryManager::RegisterCUDASharedMemory(

TRITONSERVER_Error*
SharedMemoryManager::GetMemoryInfo(
const std::string& name, size_t offset, void** shm_mapped_addr,
TRITONSERVER_MemoryType* memory_type, int64_t* device_id)
const std::string& name, size_t offset, size_t byte_size,
void** shm_mapped_addr, TRITONSERVER_MemoryType* memory_type,
int64_t* device_id)
{
// protect shared_memory_map_ from concurrent access
std::lock_guard<std::mutex> lock(mu_);
Expand All @@ -348,20 +349,29 @@ SharedMemoryManager::GetMemoryInfo(
}

// validate offset
size_t max_offset = 0;
size_t shm_region_end = 0;
if (it->second->kind_ == TRITONSERVER_MEMORY_CPU) {
max_offset = it->second->offset_;
shm_region_end = it->second->offset_;
}
if (it->second->byte_size_ > 0) {
max_offset += it->second->byte_size_ - 1;
shm_region_end += it->second->byte_size_ - 1;
}
if (offset > max_offset) {
if (offset > shm_region_end) {
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INVALID_ARG,
std::string("Invalid offset for shared memory region: '" + name + "'")
.c_str());
}
// TODO: should also validate byte_size from caller
// validate byte_size + offset is within memory bounds
size_t total_req_shm = offset + byte_size - 1;
if (total_req_shm > shm_region_end) {
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INVALID_ARG,
std::string(
"Invalid offset + byte size for shared memory region: '" + name +
"'")
.c_str());
}

if (it->second->kind_ == TRITONSERVER_MEMORY_CPU) {
*shm_mapped_addr = (void*)((uint8_t*)it->second->mapped_addr_ +
Expand Down
6 changes: 4 additions & 2 deletions src/shared_memory_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,17 @@ class SharedMemoryManager {
/// if named block doesn't exist.
/// \param name The name of the shared memory block to get.
/// \param offset The offset in the block
/// \param byte_size The byte size to request for the shm region
/// \param shm_mapped_addr Returns the pointer to the shared
/// memory block with the specified name and offset
/// \param memory_type Returns the type of the memory
/// \param device_id Returns the device id associated with the
/// memory block
/// \return a TRITONSERVER_Error indicating success or failure.
TRITONSERVER_Error* GetMemoryInfo(
const std::string& name, size_t offset, void** shm_mapped_addr,
TRITONSERVER_MemoryType* memory_type, int64_t* device_id);
const std::string& name, size_t offset, size_t byte_size,
void** shm_mapped_addr, TRITONSERVER_MemoryType* memory_type,
int64_t* device_id);

#ifdef TRITON_ENABLE_GPU
/// Get the CUDA memory handle associated with the block name.
Expand Down
Loading