diff --git a/qa/L0_shared_memory/shared_memory_test.py b/qa/L0_shared_memory/shared_memory_test.py index bfc48d6261..3c084bace4 100755 --- a/qa/L0_shared_memory/shared_memory_test.py +++ b/qa/L0_shared_memory/shared_memory_test.py @@ -42,6 +42,22 @@ class SharedMemoryTest(tu.TestResultCollector): + def setUp(self): + self._setup_client() + + def _setup_client(self): + self.protocol = os.environ.get("CLIENT_TYPE", "http") + if self.protocol == "http": + self.url = "localhost:8000" + self.triton_client = httpclient.InferenceServerClient( + self.url, verbose=True + ) + else: + self.url = "localhost:8001" + self.triton_client = grpcclient.InferenceServerClient( + self.url, verbose=True + ) + def test_invalid_create_shm(self): # Raises error since tried to create invalid system shared memory region try: @@ -54,17 +70,13 @@ def test_invalid_create_shm(self): def test_valid_create_set_register(self): # Create a valid system shared memory region, fill data in it and register - if _protocol == "http": - triton_client = httpclient.InferenceServerClient(_url, verbose=True) - else: - triton_client = grpcclient.InferenceServerClient(_url, verbose=True) shm_op0_handle = shm.create_shared_memory_region("dummy_data", "/dummy_data", 8) shm.set_shared_memory_region( shm_op0_handle, [np.array([1, 2], dtype=np.float32)] ) - triton_client.register_system_shared_memory("dummy_data", "/dummy_data", 8) - shm_status = triton_client.get_system_shared_memory_status() - if _protocol == "http": + self.triton_client.register_system_shared_memory("dummy_data", "/dummy_data", 8) + shm_status = self.triton_client.get_system_shared_memory_status() + if self.protocol == "http": self.assertTrue(len(shm_status) == 1) else: self.assertTrue(len(shm_status.regions) == 1) @@ -72,14 +84,10 @@ def test_valid_create_set_register(self): def test_unregister_before_register(self): # Create a valid system shared memory region and unregister before register - if _protocol == "http": - triton_client = httpclient.InferenceServerClient(_url, verbose=True) - else: - triton_client = grpcclient.InferenceServerClient(_url, verbose=True) shm_op0_handle = shm.create_shared_memory_region("dummy_data", "/dummy_data", 8) - triton_client.unregister_system_shared_memory("dummy_data") - shm_status = triton_client.get_system_shared_memory_status() - if _protocol == "http": + self.triton_client.unregister_system_shared_memory("dummy_data") + shm_status = self.triton_client.get_system_shared_memory_status() + if self.protocol == "http": self.assertTrue(len(shm_status) == 0) else: self.assertTrue(len(shm_status.regions) == 0) @@ -87,15 +95,11 @@ def test_unregister_before_register(self): def test_unregister_after_register(self): # Create a valid system shared memory region and unregister after register - if _protocol == "http": - triton_client = httpclient.InferenceServerClient(_url, verbose=True) - else: - triton_client = grpcclient.InferenceServerClient(_url, verbose=True) shm_op0_handle = shm.create_shared_memory_region("dummy_data", "/dummy_data", 8) - triton_client.register_system_shared_memory("dummy_data", "/dummy_data", 8) - triton_client.unregister_system_shared_memory("dummy_data") - shm_status = triton_client.get_system_shared_memory_status() - if _protocol == "http": + self.triton_client.register_system_shared_memory("dummy_data", "/dummy_data", 8) + self.triton_client.unregister_system_shared_memory("dummy_data") + shm_status = self.triton_client.get_system_shared_memory_status() + if self.protocol == "http": self.assertTrue(len(shm_status) == 0) else: self.assertTrue(len(shm_status.regions) == 0) @@ -103,50 +107,74 @@ def test_unregister_after_register(self): def test_reregister_after_register(self): # Create a valid system shared memory region and unregister after register - if _protocol == "http": - triton_client = httpclient.InferenceServerClient(_url, verbose=True) - else: - triton_client = grpcclient.InferenceServerClient(_url, verbose=True) shm_op0_handle = shm.create_shared_memory_region("dummy_data", "/dummy_data", 8) - triton_client.register_system_shared_memory("dummy_data", "/dummy_data", 8) + self.triton_client.register_system_shared_memory("dummy_data", "/dummy_data", 8) try: - triton_client.register_system_shared_memory("dummy_data", "/dummy_data", 8) + self.triton_client.register_system_shared_memory( + "dummy_data", "/dummy_data", 8 + ) except Exception as ex: self.assertTrue( "shared memory region 'dummy_data' already in manager" in str(ex) ) - shm_status = triton_client.get_system_shared_memory_status() - if _protocol == "http": + shm_status = self.triton_client.get_system_shared_memory_status() + if self.protocol == "http": self.assertTrue(len(shm_status) == 1) else: self.assertTrue(len(shm_status.regions) == 1) shm.destroy_shared_memory_region(shm_op0_handle) - def _configure_sever(self): + def _configure_server( + self, create_byte_size=64, register_byte_size=64, register_offset=0 + ): + """Creates and registers shared memory regions for testing. + + Parameters + ---------- + create_byte_size: int + Size of each system shared memory region to create. + NOTE: This should be sufficiently large to hold the inputs/outputs + stored in shared memory. + + register_byte_size: int + Size of each system shared memory region to register with server. + NOTE: The (offset + register_byte_size) should be less than or equal + to the create_byte_size. Otherwise an exception will be raised for + an invalid set of registration args. + + register_offset: int + Offset into the shared memory object to start the registered region. + + """ shm_ip0_handle = shm.create_shared_memory_region( - "input0_data", "/input0_data", 64 + "input0_data", "/input0_data", create_byte_size ) shm_ip1_handle = shm.create_shared_memory_region( - "input1_data", "/input1_data", 64 + "input1_data", "/input1_data", create_byte_size ) shm_op0_handle = shm.create_shared_memory_region( - "output0_data", "/output0_data", 64 + "output0_data", "/output0_data", create_byte_size ) shm_op1_handle = shm.create_shared_memory_region( - "output1_data", "/output1_data", 64 + "output1_data", "/output1_data", create_byte_size ) + # Implicit assumption that input and output byte_sizes are 64 bytes for now input0_data = np.arange(start=0, stop=16, dtype=np.int32) input1_data = np.ones(shape=16, dtype=np.int32) shm.set_shared_memory_region(shm_ip0_handle, [input0_data]) shm.set_shared_memory_region(shm_ip1_handle, [input1_data]) - if _protocol == "http": - triton_client = httpclient.InferenceServerClient(_url, verbose=True) - else: - triton_client = grpcclient.InferenceServerClient(_url, verbose=True) - triton_client.register_system_shared_memory("input0_data", "/input0_data", 64) - triton_client.register_system_shared_memory("input1_data", "/input1_data", 64) - triton_client.register_system_shared_memory("output0_data", "/output0_data", 64) - triton_client.register_system_shared_memory("output1_data", "/output1_data", 64) + self.triton_client.register_system_shared_memory( + "input0_data", "/input0_data", register_byte_size, offset=register_offset + ) + self.triton_client.register_system_shared_memory( + "input1_data", "/input1_data", register_byte_size, offset=register_offset + ) + self.triton_client.register_system_shared_memory( + "output0_data", "/output0_data", register_byte_size, offset=register_offset + ) + self.triton_client.register_system_shared_memory( + "output1_data", "/output1_data", register_byte_size, offset=register_offset + ) return [shm_ip0_handle, shm_ip1_handle, shm_op0_handle, shm_op1_handle] def _cleanup_server(self, shm_handles): @@ -168,8 +196,7 @@ def _basic_inference( input1_data = np.ones(shape=16, dtype=np.int32) inputs = [] outputs = [] - if _protocol == "http": - triton_client = httpclient.InferenceServerClient(_url, verbose=True) + if self.protocol == "http": inputs.append(httpclient.InferInput("INPUT0", [1, 16], "INT32")) inputs.append(httpclient.InferInput("INPUT1", [1, 16], "INT32")) outputs.append(httpclient.InferRequestedOutput("OUTPUT0", binary_data=True)) @@ -177,7 +204,6 @@ def _basic_inference( httpclient.InferRequestedOutput("OUTPUT1", binary_data=False) ) else: - triton_client = grpcclient.InferenceServerClient(_url, verbose=True) inputs.append(grpcclient.InferInput("INPUT0", [1, 16], "INT32")) inputs.append(grpcclient.InferInput("INPUT1", [1, 16], "INT32")) outputs.append(grpcclient.InferRequestedOutput("OUTPUT0")) @@ -196,11 +222,11 @@ def _basic_inference( outputs[1].set_shared_memory("output1_data", 64, offset=shm_output_offset) try: - results = triton_client.infer( + results = self.triton_client.infer( "simple", inputs, model_version="", outputs=outputs ) output = results.get_output("OUTPUT0") - if _protocol == "http": + if self.protocol == "http": output_datatype = output["datatype"] output_shape = output["shape"] else: @@ -220,19 +246,15 @@ def _basic_inference( def test_unregister_after_inference(self): # Unregister after inference error_msg = [] - shm_handles = self._configure_sever() + shm_handles = self._configure_server() self._basic_inference( shm_handles[0], shm_handles[1], shm_handles[2], shm_handles[3], error_msg ) if len(error_msg) > 0: raise Exception(str(error_msg)) - if _protocol == "http": - triton_client = httpclient.InferenceServerClient(_url, verbose=True) - else: - triton_client = grpcclient.InferenceServerClient(_url, verbose=True) - triton_client.unregister_system_shared_memory("output0_data") - shm_status = triton_client.get_system_shared_memory_status() - if _protocol == "http": + self.triton_client.unregister_system_shared_memory("output0_data") + shm_status = self.triton_client.get_system_shared_memory_status() + if self.protocol == "http": self.assertTrue(len(shm_status) == 3) else: self.assertTrue(len(shm_status.regions) == 3) @@ -241,11 +263,7 @@ def test_unregister_after_inference(self): def test_register_after_inference(self): # Register after inference error_msg = [] - shm_handles = self._configure_sever() - if _protocol == "http": - triton_client = httpclient.InferenceServerClient(_url, verbose=True) - else: - triton_client = grpcclient.InferenceServerClient(_url, verbose=True) + shm_handles = self._configure_server() self._basic_inference( shm_handles[0], shm_handles[1], shm_handles[2], shm_handles[3], error_msg ) @@ -254,9 +272,11 @@ def test_register_after_inference(self): shm_ip2_handle = shm.create_shared_memory_region( "input2_data", "/input2_data", 64 ) - triton_client.register_system_shared_memory("input2_data", "/input2_data", 64) - shm_status = triton_client.get_system_shared_memory_status() - if _protocol == "http": + self.triton_client.register_system_shared_memory( + "input2_data", "/input2_data", 64 + ) + shm_status = self.triton_client.get_system_shared_memory_status() + if self.protocol == "http": self.assertTrue(len(shm_status) == 5) else: self.assertTrue(len(shm_status.regions) == 5) @@ -266,15 +286,13 @@ def test_register_after_inference(self): def test_too_big_shm(self): # Shared memory input region larger than needed - Throws error error_msg = [] - shm_handles = self._configure_sever() + shm_handles = self._configure_server() shm_ip2_handle = shm.create_shared_memory_region( "input2_data", "/input2_data", 128 ) - if _protocol == "http": - triton_client = httpclient.InferenceServerClient(_url, verbose=True) - else: - triton_client = grpcclient.InferenceServerClient(_url, verbose=True) - triton_client.register_system_shared_memory("input2_data", "/input2_data", 128) + self.triton_client.register_system_shared_memory( + "input2_data", "/input2_data", 128 + ) self._basic_inference( shm_handles[0], shm_ip2_handle, @@ -295,7 +313,7 @@ def test_too_big_shm(self): def test_mixed_raw_shm(self): # Mix of shared memory and RAW inputs error_msg = [] - shm_handles = self._configure_sever() + shm_handles = self._configure_server() input1_data = np.ones(shape=16, dtype=np.int32) self._basic_inference( shm_handles[0], [input1_data], shm_handles[2], shm_handles[3], error_msg @@ -306,19 +324,15 @@ def test_mixed_raw_shm(self): def test_unregisterall(self): # Unregister all shared memory blocks - shm_handles = self._configure_sever() - if _protocol == "http": - triton_client = httpclient.InferenceServerClient(_url, verbose=True) - else: - triton_client = grpcclient.InferenceServerClient(_url, verbose=True) - status_before = triton_client.get_system_shared_memory_status() - if _protocol == "http": + shm_handles = self._configure_server() + status_before = self.triton_client.get_system_shared_memory_status() + if self.protocol == "http": self.assertTrue(len(status_before) == 4) else: self.assertTrue(len(status_before.regions) == 4) - triton_client.unregister_system_shared_memory() - status_after = triton_client.get_system_shared_memory_status() - if _protocol == "http": + self.triton_client.unregister_system_shared_memory() + status_after = self.triton_client.get_system_shared_memory_status() + if self.protocol == "http": self.assertTrue(len(status_after) == 0) else: self.assertTrue(len(status_after.regions) == 0) @@ -327,8 +341,8 @@ def test_unregisterall(self): def test_infer_offset_out_of_bound(self): # Shared memory offset outside output region - Throws error error_msg = [] - shm_handles = self._configure_sever() - if _protocol == "http": + shm_handles = self._configure_server() + if self.protocol == "http": # -32 when placed in an int64 signed type, to get a negative offset # by overflowing offset = 2**64 - 32 @@ -348,11 +362,51 @@ def test_infer_offset_out_of_bound(self): self.assertIn("Invalid offset for shared memory region", error_msg[0]) self._cleanup_server(shm_handles) + def test_register_out_of_bound(self): + create_byte_size = 64 + + # Verify various edge cases of registered region size (offset+byte_size) + # don't go out of bounds of the actual created shm file object's size. + with self.assertRaisesRegex( + utils.InferenceServerException, + "failed to register shared memory region.*invalid args", + ): + self._configure_server( + create_byte_size=create_byte_size, + register_byte_size=create_byte_size + 1, + register_offset=0, + ) + + with self.assertRaisesRegex( + utils.InferenceServerException, + "failed to register shared memory region.*invalid args", + ): + self._configure_server( + create_byte_size=create_byte_size, + register_byte_size=create_byte_size, + register_offset=1, + ) + + with self.assertRaisesRegex( + utils.InferenceServerException, + "failed to register shared memory region.*invalid args", + ): + self._configure_server( + create_byte_size=create_byte_size, + register_byte_size=1, + register_offset=create_byte_size, + ) + + with self.assertRaisesRegex( + utils.InferenceServerException, + "failed to register shared memory region.*invalid args", + ): + self._configure_server( + create_byte_size=create_byte_size, + register_byte_size=0, + register_offset=create_byte_size + 1, + ) + if __name__ == "__main__": - _protocol = os.environ.get("CLIENT_TYPE", "http") - if _protocol == "http": - _url = "localhost:8000" - else: - _url = "localhost:8001" unittest.main() diff --git a/qa/L0_shared_memory/test.sh b/qa/L0_shared_memory/test.sh index 30abfca545..84b5a8a857 100755 --- a/qa/L0_shared_memory/test.sh +++ b/qa/L0_shared_memory/test.sh @@ -50,7 +50,8 @@ for i in \ test_too_big_shm \ test_mixed_raw_shm \ test_unregisterall \ - test_infer_offset_out_of_bound; do + test_infer_offset_out_of_bound \ + test_register_out_of_bound; do for client_type in http grpc; do SERVER_ARGS="--model-repository=`pwd`/models --log-verbose=1 ${SERVER_ARGS_EXTRA}" SERVER_LOG="./$i.$client_type.server.log" @@ -65,7 +66,7 @@ for i in \ echo "Test: $i, client type: $client_type" >>$CLIENT_LOG set +e - python $SHM_TEST SharedMemoryTest.$i >>$CLIENT_LOG 2>&1 + python3 $SHM_TEST SharedMemoryTest.$i >>$CLIENT_LOG 2>&1 if [ $? -ne 0 ]; then echo -e "\n***\n*** Test Failed\n***" RET=1 diff --git a/src/shared_memory_manager.cc b/src/shared_memory_manager.cc index 268c6451d1..724c00a73d 100644 --- a/src/shared_memory_manager.cc +++ b/src/shared_memory_manager.cc @@ -120,6 +120,7 @@ SharedMemoryManager::UnregisterHelper( #include #include #include +#include #include #include "common.h" @@ -145,6 +146,51 @@ OpenSharedMemoryRegion(const std::string& shm_key, int* shm_fd) return nullptr; } +TRITONSERVER_Error* +GetSharedMemoryRegionSize( + const std::string& shm_key, int shm_fd, size_t* shm_region_size) +{ + struct stat file_status; + if (fstat(shm_fd, &file_status) == -1) { + LOG_VERBOSE(1) << "fstat on shm_fd failed, errno: " << errno; + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + std::string("Invalid shared memory region: '" + shm_key + "'").c_str()); + } + + // According to POSIX standard, type off_t can be negative, so for sake of + // catching possible under/overflows, assert that the size is non-negative. + if (file_status.st_size < 0) { + LOG_VERBOSE(1) << "File size of shared memory region must be non-negative"; + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + std::string("Invalid shared memory region: '" + shm_key + "'").c_str()); + } + + *shm_region_size = static_cast(file_status.st_size); + return nullptr; // success +} + +TRITONSERVER_Error* +CheckSharedMemoryRegionSize( + const std::string& name, const std::string& shm_key, int shm_fd, + size_t offset, size_t byte_size) +{ + size_t shm_region_size = 0; + RETURN_IF_ERR(GetSharedMemoryRegionSize(shm_key, shm_fd, &shm_region_size)); + // User-provided offset and byte_size should not go out-of-bounds. + if ((offset + byte_size) > shm_region_size) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + std::string( + "failed to register shared memory region '" + name + + "': invalid args") + .c_str()); + } + + return nullptr; // success +} + TRITONSERVER_Error* MapSharedMemory( const int shm_fd, const size_t offset, const size_t byte_size, @@ -249,6 +295,7 @@ SharedMemoryManager::RegisterSystemSharedMemory( for (auto itr = shared_memory_map_.begin(); itr != shared_memory_map_.end(); ++itr) { if (itr->second->shm_key_ == shm_key) { + // FIXME: Consider invalid file descriptors after close shm_fd = itr->second->shm_fd_; break; } @@ -259,6 +306,10 @@ SharedMemoryManager::RegisterSystemSharedMemory( RETURN_IF_ERR(OpenSharedMemoryRegion(shm_key, &shm_fd)); } + // Enforce that registered region is in-bounds of shm file object. + RETURN_IF_ERR( + CheckSharedMemoryRegionSize(name, shm_key, shm_fd, offset, byte_size)); + // Mmap and then close the shared memory descriptor TRITONSERVER_Error* err_mmap = MapSharedMemory(shm_fd, offset, byte_size, &mapped_addr);