Skip to content

Commit

Permalink
Enhance bound check for shm offset (#6914)
Browse files Browse the repository at this point in the history
* Enhance bound check for shm offset

* Add test for enhance bound check for shm offset

* Fix off by 1 on max offset

* Improve comments

* Improve comment and offset

* Separate logic between computation and validation
  • Loading branch information
kthui authored Mar 8, 2024
1 parent 2255663 commit 60071e1
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 6 deletions.
31 changes: 28 additions & 3 deletions qa/L0_shared_memory/shared_memory_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env python3

# Copyright 2019-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -162,6 +162,7 @@ def _basic_inference(
error_msg,
big_shm_name="",
big_shm_size=64,
shm_output_offset=0,
):
input0_data = np.arange(start=0, stop=16, dtype=np.int32)
input1_data = np.ones(shape=16, dtype=np.int32)
Expand Down Expand Up @@ -191,8 +192,8 @@ def _basic_inference(
else:
inputs[1].set_shared_memory("input1_data", 64)

outputs[0].set_shared_memory("output0_data", 64)
outputs[1].set_shared_memory("output1_data", 64)
outputs[0].set_shared_memory("output0_data", 64, offset=shm_output_offset)
outputs[1].set_shared_memory("output1_data", 64, offset=shm_output_offset)

try:
results = triton_client.infer(
Expand Down Expand Up @@ -323,6 +324,30 @@ def test_unregisterall(self):
self.assertTrue(len(status_after.regions) == 0)
self._cleanup_server(shm_handles)

def test_infer_offset_out_of_bound(self):
# Shared memory offset outside output region - Throws error
error_msg = []
shm_handles = self._configure_sever()
if _protocol == "http":
# -32 when placed in an int64 signed type, to get a negative offset
# by overflowing
offset = 2**64 - 32
else:
# gRPC will throw an error if > 2**63 - 1, so instead test for
# exceeding shm region size by 1 byte, given its size is 64 bytes
offset = 64
self._basic_inference(
shm_handles[0],
shm_handles[1],
shm_handles[2],
shm_handles[3],
error_msg,
shm_output_offset=offset,
)
self.assertEqual(len(error_msg), 1)
self.assertIn("Invalid offset for shared memory region", error_msg[0])
self._cleanup_server(shm_handles)


if __name__ == "__main__":
_protocol = os.environ.get("CLIENT_TYPE", "http")
Expand Down
5 changes: 3 additions & 2 deletions qa/L0_shared_memory/test.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/bin/bash
# Copyright 2019-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -49,7 +49,8 @@ for i in \
test_register_after_inference \
test_too_big_shm \
test_mixed_raw_shm \
test_unregisterall; do
test_unregisterall \
test_infer_offset_out_of_bound; do
for client_type in http grpc; do
SERVER_ARGS="--model-repository=`pwd`/models --log-verbose=1 ${SERVER_ARGS_EXTRA}"
SERVER_LOG="./$i.$client_type.server.log"
Expand Down
19 changes: 18 additions & 1 deletion src/shared_memory_manager.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2019-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -346,6 +346,23 @@ SharedMemoryManager::GetMemoryInfo(
std::string("Unable to find shared memory region: '" + name + "'")
.c_str());
}

// validate offset
size_t max_offset = 0;
if (it->second->kind_ == TRITONSERVER_MEMORY_CPU) {
max_offset = it->second->offset_;
}
if (it->second->byte_size_ > 0) {
max_offset += it->second->byte_size_ - 1;
}
if (offset > max_offset) {
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INVALID_ARG,
std::string("Invalid offset for shared memory region: '" + name + "'")
.c_str());
}
// TODO: should also validate byte_size from caller

if (it->second->kind_ == TRITONSERVER_MEMORY_CPU) {
*shm_mapped_addr = (void*)((uint8_t*)it->second->mapped_addr_ +
it->second->offset_ + offset);
Expand Down

0 comments on commit 60071e1

Please sign in to comment.