From d6cb396c73deded4bc316cb08cdb2da763847164 Mon Sep 17 00:00:00 2001 From: krishung5 Date: Wed, 1 May 2024 19:49:10 -0700 Subject: [PATCH 01/12] Validate CUDA SHM region registration size --- src/CMakeLists.txt | 1 + src/shared_memory_manager.cc | 52 ++++++++++++++++++++++++++++++++++++ src/shared_memory_manager.h | 1 + 3 files changed, 54 insertions(+) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 783275d8d7..53c8add989 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -180,6 +180,7 @@ if(${TRITON_ENABLE_GPU}) main PRIVATE CUDA::cudart + -lcuda ) endif() # TRITON_ENABLE_GPU diff --git a/src/shared_memory_manager.cc b/src/shared_memory_manager.cc index 1064982669..f290dded5b 100644 --- a/src/shared_memory_manager.cc +++ b/src/shared_memory_manager.cc @@ -264,6 +264,46 @@ OpenCudaIPCRegion( return nullptr; } + +TRITONSERVER_Error* +GetCudaSharedMemoryRegionSize(CUdeviceptr data_ptr, size_t& shm_region_size) +{ + CUdeviceptr* base = nullptr; + CUresult result = cuMemGetAddressRange(base, &shm_region_size, data_ptr); + if (result != CUDA_SUCCESS) { + const char* errorString; + if (cuGetErrorString(result, &errorString) != CUDA_SUCCESS) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, "Failed to get CUDA error string"); + } + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INTERNAL, + std::string( + "Failed to get CUDA address range: " + std::string(errorString)) + .c_str()); + } + return nullptr; +} + +TRITONSERVER_Error* +CheckCudaSharedMemoryRegionSize( + const std::string& name, CUdeviceptr data_ptr, size_t byte_size) +{ + size_t shm_region_size = 0; + RETURN_IF_ERR(GetCudaSharedMemoryRegionSize(data_ptr, shm_region_size)); + + // User-provided offset and byte_size should not go out-of-bounds. + if (byte_size > shm_region_size) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + std::string( + "failed to register CUDA shared memory region '" + name + + "': register size exceeds shared memory region size") + .c_str()); + } + + return nullptr; +} #endif // TRITON_ENABLE_GPU } // namespace @@ -365,6 +405,18 @@ SharedMemoryManager::RegisterCUDASharedMemory( // Get CUDA shared memory base address TRITONSERVER_Error* err = OpenCudaIPCRegion(cuda_shm_handle, &mapped_addr, device_id); + if (err != nullptr) { + return TRITONSERVER_ErrorNew( + TRITONSERVER_ERROR_INVALID_ARG, + std::string( + "failed to open CUDA shared memory region '" + name + + "': " + TRITONSERVER_ErrorMessage(err)) + .c_str()); + } + + // Enforce that registered region is in-bounds of shm file object. + RETURN_IF_ERR(CheckCudaSharedMemoryRegionSize( + name, reinterpret_cast(mapped_addr), byte_size)); if (err != nullptr) { return TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_INVALID_ARG, diff --git a/src/shared_memory_manager.h b/src/shared_memory_manager.h index f079308bd5..51eb0f0786 100644 --- a/src/shared_memory_manager.h +++ b/src/shared_memory_manager.h @@ -39,6 +39,7 @@ #include "triton/common/triton_json.h" #ifdef TRITON_ENABLE_GPU +#include #include #endif // TRITON_ENABLE_GPU From adb986f831acc06e68605144d775287d534e27bb Mon Sep 17 00:00:00 2001 From: krishung5 Date: Wed, 1 May 2024 19:51:32 -0700 Subject: [PATCH 02/12] Add test --- .../cuda_shared_memory_test.py | 313 +++++++++--------- qa/L0_cuda_shared_memory/test.sh | 5 +- qa/common/infer_util.py | 89 ++++- 3 files changed, 252 insertions(+), 155 deletions(-) diff --git a/qa/L0_cuda_shared_memory/cuda_shared_memory_test.py b/qa/L0_cuda_shared_memory/cuda_shared_memory_test.py index e17692ef56..231d9444d3 100755 --- a/qa/L0_cuda_shared_memory/cuda_shared_memory_test.py +++ b/qa/L0_cuda_shared_memory/cuda_shared_memory_test.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 -# Copyright 2019-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions @@ -33,6 +33,7 @@ import os import unittest +import infer_util as iu import numpy as np import test_util as tu import tritonclient.grpc as grpcclient @@ -42,6 +43,24 @@ class CudaSharedMemoryTest(tu.TestResultCollector): + DEFAULT_SHM_BYTE_SIZE = 64 + + def setUp(self): + self._setup_client() + + def _setup_client(self): + self.protocol = os.environ.get("CLIENT_TYPE", "http") + if self.protocol == "http": + self.url = "localhost:8000" + self.triton_client = httpclient.InferenceServerClient( + self.url, verbose=True + ) + else: + self.url = "localhost:8001" + self.triton_client = grpcclient.InferenceServerClient( + self.url, verbose=True + ) + def test_invalid_create_shm(self): # Raises error since tried to create invalid cuda shared memory region try: @@ -52,19 +71,15 @@ def test_invalid_create_shm(self): def test_valid_create_set_register(self): # Create a valid cuda shared memory region, fill data in it and register - if _protocol == "http": - triton_client = httpclient.InferenceServerClient(_url, verbose=True) - else: - triton_client = grpcclient.InferenceServerClient(_url, verbose=True) shm_op0_handle = cshm.create_shared_memory_region("dummy_data", 8, 0) cshm.set_shared_memory_region( shm_op0_handle, [np.array([1, 2], dtype=np.float32)] ) - triton_client.register_cuda_shared_memory( + self.triton_client.register_cuda_shared_memory( "dummy_data", cshm.get_raw_handle(shm_op0_handle), 0, 8 ) - shm_status = triton_client.get_cuda_shared_memory_status() - if _protocol == "http": + shm_status = self.triton_client.get_cuda_shared_memory_status() + if self.protocol == "http": self.assertEqual(len(shm_status), 1) else: self.assertEqual(len(shm_status.regions), 1) @@ -72,14 +87,10 @@ def test_valid_create_set_register(self): def test_unregister_before_register(self): # Create a valid cuda shared memory region and unregister before register - if _protocol == "http": - triton_client = httpclient.InferenceServerClient(_url, verbose=True) - else: - triton_client = grpcclient.InferenceServerClient(_url, verbose=True) shm_op0_handle = cshm.create_shared_memory_region("dummy_data", 8, 0) - triton_client.unregister_cuda_shared_memory("dummy_data") - shm_status = triton_client.get_cuda_shared_memory_status() - if _protocol == "http": + self.triton_client.unregister_cuda_shared_memory("dummy_data") + shm_status = self.triton_client.get_cuda_shared_memory_status() + if self.protocol == "http": self.assertEqual(len(shm_status), 0) else: self.assertEqual(len(shm_status.regions), 0) @@ -87,17 +98,13 @@ def test_unregister_before_register(self): def test_unregister_after_register(self): # Create a valid cuda shared memory region and unregister after register - if _protocol == "http": - triton_client = httpclient.InferenceServerClient(_url, verbose=True) - else: - triton_client = grpcclient.InferenceServerClient(_url, verbose=True) shm_op0_handle = cshm.create_shared_memory_region("dummy_data", 8, 0) - triton_client.register_cuda_shared_memory( + self.triton_client.register_cuda_shared_memory( "dummy_data", cshm.get_raw_handle(shm_op0_handle), 0, 8 ) - triton_client.unregister_cuda_shared_memory("dummy_data") - shm_status = triton_client.get_cuda_shared_memory_status() - if _protocol == "http": + self.triton_client.unregister_cuda_shared_memory("dummy_data") + shm_status = self.triton_client.get_cuda_shared_memory_status() + if self.protocol == "http": self.assertEqual(len(shm_status), 0) else: self.assertEqual(len(shm_status.regions), 0) @@ -105,54 +112,92 @@ def test_unregister_after_register(self): def test_reregister_after_register(self): # Create a valid cuda shared memory region and unregister after register - if _protocol == "http": - triton_client = httpclient.InferenceServerClient(_url, verbose=True) - else: - triton_client = grpcclient.InferenceServerClient(_url, verbose=True) shm_op0_handle = cshm.create_shared_memory_region("dummy_data", 8, 0) - triton_client.register_cuda_shared_memory( + self.triton_client.register_cuda_shared_memory( "dummy_data", cshm.get_raw_handle(shm_op0_handle), 0, 8 ) try: - triton_client.register_cuda_shared_memory( + self.triton_client.register_cuda_shared_memory( "dummy_data", cshm.get_raw_handle(shm_op0_handle), 0, 8 ) except Exception as ex: self.assertIn( "shared memory region 'dummy_data' already in manager", str(ex) ) - shm_status = triton_client.get_cuda_shared_memory_status() - if _protocol == "http": + shm_status = self.triton_client.get_cuda_shared_memory_status() + if self.protocol == "http": self.assertEqual(len(shm_status), 1) else: self.assertEqual(len(shm_status.regions), 1) cshm.destroy_shared_memory_region(shm_op0_handle) - def _configure_sever(self): - shm_ip0_handle = cshm.create_shared_memory_region("input0_data", 64, 0) - shm_ip1_handle = cshm.create_shared_memory_region("input1_data", 64, 0) - shm_op0_handle = cshm.create_shared_memory_region("output0_data", 64, 0) - shm_op1_handle = cshm.create_shared_memory_region("output1_data", 64, 0) + def _configure_server( + self, + create_byte_size=DEFAULT_SHM_BYTE_SIZE, + register_byte_size=DEFAULT_SHM_BYTE_SIZE, + device_id=0, + ): + """Creates and registers cuda shared memory regions for testing. + + Parameters + ---------- + create_byte_size: int + Size of each cuda shared memory region to create. + NOTE: This should be sufficiently large to hold the inputs/outputs + stored in shared memory. + + register_byte_size: int + Size of each cuda shared memory region to register with server. + NOTE: The register_byte_size should be less than or equal + to the create_byte_size. Otherwise an exception will be raised for + an invalid set of registration args. + + device_id: int + The GPU device ID of the cuda shared memory region to be created. + + """ + + shm_ip0_handle = cshm.create_shared_memory_region( + "input0_data", create_byte_size, device_id + ) + shm_ip1_handle = cshm.create_shared_memory_region( + "input1_data", create_byte_size, device_id + ) + shm_op0_handle = cshm.create_shared_memory_region( + "output0_data", create_byte_size, device_id + ) + shm_op1_handle = cshm.create_shared_memory_region( + "output1_data", create_byte_size, device_id + ) input0_data = np.arange(start=0, stop=16, dtype=np.int32) input1_data = np.ones(shape=16, dtype=np.int32) cshm.set_shared_memory_region(shm_ip0_handle, [input0_data]) cshm.set_shared_memory_region(shm_ip1_handle, [input1_data]) - if _protocol == "http": - triton_client = httpclient.InferenceServerClient(_url, verbose=True) - else: - triton_client = grpcclient.InferenceServerClient(_url, verbose=True) - triton_client.register_cuda_shared_memory( - "input0_data", cshm.get_raw_handle(shm_ip0_handle), 0, 64 + + self.triton_client.register_cuda_shared_memory( + "input0_data", + cshm.get_raw_handle(shm_ip0_handle), + device_id, + register_byte_size, ) - triton_client.register_cuda_shared_memory( - "input1_data", cshm.get_raw_handle(shm_ip1_handle), 0, 64 + self.triton_client.register_cuda_shared_memory( + "input1_data", + cshm.get_raw_handle(shm_ip1_handle), + device_id, + register_byte_size, ) - triton_client.register_cuda_shared_memory( - "output0_data", cshm.get_raw_handle(shm_op0_handle), 0, 64 + self.triton_client.register_cuda_shared_memory( + "output0_data", + cshm.get_raw_handle(shm_op0_handle), + device_id, + register_byte_size, ) - triton_client.register_cuda_shared_memory( - "output1_data", cshm.get_raw_handle(shm_op1_handle), 0, 64 + self.triton_client.register_cuda_shared_memory( + "output1_data", + cshm.get_raw_handle(shm_op1_handle), + device_id, + register_byte_size, ) return [shm_ip0_handle, shm_ip1_handle, shm_op0_handle, shm_op1_handle] @@ -160,79 +205,27 @@ def _cleanup_server(self, shm_handles): for shm_handle in shm_handles: cshm.destroy_shared_memory_region(shm_handle) - def _basic_inference( - self, - shm_ip0_handle, - shm_ip1_handle, - shm_op0_handle, - shm_op1_handle, - error_msg, - big_shm_name="", - big_shm_size=64, - ): - input0_data = np.arange(start=0, stop=16, dtype=np.int32) - input1_data = np.ones(shape=16, dtype=np.int32) - inputs = [] - outputs = [] - if _protocol == "http": - triton_client = httpclient.InferenceServerClient(_url, verbose=True) - inputs.append(httpclient.InferInput("INPUT0", [1, 16], "INT32")) - inputs.append(httpclient.InferInput("INPUT1", [1, 16], "INT32")) - outputs.append(httpclient.InferRequestedOutput("OUTPUT0", binary_data=True)) - outputs.append( - httpclient.InferRequestedOutput("OUTPUT1", binary_data=False) - ) - else: - triton_client = grpcclient.InferenceServerClient(_url, verbose=True) - inputs.append(grpcclient.InferInput("INPUT0", [1, 16], "INT32")) - inputs.append(grpcclient.InferInput("INPUT1", [1, 16], "INT32")) - outputs.append(grpcclient.InferRequestedOutput("OUTPUT0")) - outputs.append(grpcclient.InferRequestedOutput("OUTPUT1")) - inputs[0].set_shared_memory("input0_data", 64) - if type(shm_ip1_handle) == np.array: - inputs[1].set_data_from_numpy(input0_data, binary_data=True) - elif big_shm_name != "": - inputs[1].set_shared_memory(big_shm_name, big_shm_size) - else: - inputs[1].set_shared_memory("input1_data", 64) - outputs[0].set_shared_memory("output0_data", 64) - outputs[1].set_shared_memory("output1_data", 64) - - try: - results = triton_client.infer( - "simple", inputs, model_version="", outputs=outputs - ) - output = results.get_output("OUTPUT0") - if _protocol == "http": - output_datatype = output["datatype"] - output_shape = output["shape"] - else: - output_datatype = output.datatype - output_shape = output.shape - output_dtype = triton_to_np_dtype(output_datatype) - output_data = cshm.get_contents_as_numpy( - shm_op0_handle, output_dtype, output_shape - ) - self.assertTrue((output_data[0] == (input0_data + input1_data)).all()) - except Exception as ex: - error_msg.append(str(ex)) - def test_unregister_after_inference(self): # Unregister after inference error_msg = [] - shm_handles = self._configure_sever() - self._basic_inference( - shm_handles[0], shm_handles[1], shm_handles[2], shm_handles[3], error_msg + shm_handles = self._configure_server() + iu.shm_basic_infer( + self, + self.triton_client, + shm_handles[0], + shm_handles[1], + shm_handles[2], + shm_handles[3], + error_msg, + protocol=self.protocol, + use_cuda_shared_memory=True, ) if len(error_msg) > 0: raise Exception(str(error_msg)) - if _protocol == "http": - triton_client = httpclient.InferenceServerClient(_url, verbose=True) - else: - triton_client = grpcclient.InferenceServerClient(_url, verbose=True) - triton_client.unregister_cuda_shared_memory("output0_data") - shm_status = triton_client.get_cuda_shared_memory_status() - if _protocol == "http": + + self.triton_client.unregister_cuda_shared_memory("output0_data") + shm_status = self.triton_client.get_cuda_shared_memory_status() + if self.protocol == "http": self.assertEqual(len(shm_status), 3) else: self.assertEqual(len(shm_status.regions), 3) @@ -241,22 +234,26 @@ def test_unregister_after_inference(self): def test_register_after_inference(self): # Register after inference error_msg = [] - shm_handles = self._configure_sever() - if _protocol == "http": - triton_client = httpclient.InferenceServerClient(_url, verbose=True) - else: - triton_client = grpcclient.InferenceServerClient(_url, verbose=True) - self._basic_inference( - shm_handles[0], shm_handles[1], shm_handles[2], shm_handles[3], error_msg + shm_handles = self._configure_server() + iu.shm_basic_infer( + self, + self.triton_client, + shm_handles[0], + shm_handles[1], + shm_handles[2], + shm_handles[3], + error_msg, + protocol=self.protocol, + use_cuda_shared_memory=True, ) if len(error_msg) > 0: raise Exception(str(error_msg)) shm_ip2_handle = cshm.create_shared_memory_region("input2_data", 64, 0) - triton_client.register_cuda_shared_memory( + self.triton_client.register_cuda_shared_memory( "input2_data", cshm.get_raw_handle(shm_ip2_handle), 0, 64 ) - shm_status = triton_client.get_cuda_shared_memory_status() - if _protocol == "http": + shm_status = self.triton_client.get_cuda_shared_memory_status() + if self.protocol == "http": self.assertEqual(len(shm_status), 5) else: self.assertEqual(len(shm_status.regions), 5) @@ -266,23 +263,23 @@ def test_register_after_inference(self): def test_too_big_shm(self): # Shared memory input region larger than needed - Throws error error_msg = [] - shm_handles = self._configure_sever() + shm_handles = self._configure_server() shm_ip2_handle = cshm.create_shared_memory_region("input2_data", 128, 0) - if _protocol == "http": - triton_client = httpclient.InferenceServerClient(_url, verbose=True) - else: - triton_client = grpcclient.InferenceServerClient(_url, verbose=True) - triton_client.register_cuda_shared_memory( + self.triton_client.register_cuda_shared_memory( "input2_data", cshm.get_raw_handle(shm_ip2_handle), 0, 128 ) - self._basic_inference( + iu.shm_basic_infer( + self, + self.triton_client, shm_handles[0], shm_ip2_handle, shm_handles[2], shm_handles[3], error_msg, - "input2_data", - 128, + big_shm_name="input2_data", + big_shm_size=128, + protocol=self.protocol, + use_cuda_shared_memory=True, ) if len(error_msg) > 0: self.assertIn( @@ -295,40 +292,52 @@ def test_too_big_shm(self): def test_mixed_raw_shm(self): # Mix of shared memory and RAW inputs error_msg = [] - shm_handles = self._configure_sever() + shm_handles = self._configure_server() input1_data = np.ones(shape=16, dtype=np.int32) - self._basic_inference( - shm_handles[0], [input1_data], shm_handles[2], shm_handles[3], error_msg + iu.shm_basic_infer( + self, + self.triton_client, + shm_handles[0], + [input1_data], + shm_handles[2], + shm_handles[3], + error_msg, + protocol=self.protocol, + use_cuda_shared_memory=True, ) + if len(error_msg) > 0: raise Exception(error_msg[-1]) self._cleanup_server(shm_handles) def test_unregisterall(self): # Unregister all shared memory blocks - shm_handles = self._configure_sever() - if _protocol == "http": - triton_client = httpclient.InferenceServerClient(_url, verbose=True) - else: - triton_client = grpcclient.InferenceServerClient(_url, verbose=True) - status_before = triton_client.get_cuda_shared_memory_status() - if _protocol == "http": + shm_handles = self._configure_server() + status_before = self.triton_client.get_cuda_shared_memory_status() + if self.protocol == "http": self.assertEqual(len(status_before), 4) else: self.assertEqual(len(status_before.regions), 4) - triton_client.unregister_cuda_shared_memory() - status_after = triton_client.get_cuda_shared_memory_status() - if _protocol == "http": + self.triton_client.unregister_cuda_shared_memory() + status_after = self.triton_client.get_cuda_shared_memory_status() + if self.protocol == "http": self.assertEqual(len(status_after), 0) else: self.assertEqual(len(status_after.regions), 0) self._cleanup_server(shm_handles) + def test_register_out_of_bound(self): + create_byte_size = 64 + # Verify various edge cases of registered region size don't go out of bounds of the actual created shm region's size. + with self.assertRaisesRegex( + InferenceServerException, + "failed to register cuda memory region.*register size exceeds cuda shared memory region size", + ): + self._configure_server( + create_byte_size=create_byte_size, + register_byte_size=create_byte_size + 1, + ) + if __name__ == "__main__": - _protocol = os.environ.get("CLIENT_TYPE", "http") - if _protocol == "http": - _url = "localhost:8000" - else: - _url = "localhost:8001" unittest.main() diff --git a/qa/L0_cuda_shared_memory/test.sh b/qa/L0_cuda_shared_memory/test.sh index b011244174..1daa9724d4 100755 --- a/qa/L0_cuda_shared_memory/test.sh +++ b/qa/L0_cuda_shared_memory/test.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved. +# Copyright 2019-2024, NVIDIA CORPORATION. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions @@ -47,7 +47,8 @@ for i in \ test_register_after_inference \ test_too_big_shm \ test_mixed_raw_shm \ - test_unregisterall; do + test_unregisterall \ + test_register_out_of_bound; do for client_type in http grpc; do SERVER_ARGS="--model-repository=`pwd`/models --log-verbose=1" SERVER_LOG="./$i.$client_type.server.log" diff --git a/qa/common/infer_util.py b/qa/common/infer_util.py index 9a181c1d29..18512d9927 100755 --- a/qa/common/infer_util.py +++ b/qa/common/infer_util.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 -# Copyright 2018-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright 2018-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions @@ -1306,3 +1306,90 @@ def infer_zero( shm.destroy_shared_memory_region(shm_op_handles[io_num]) return results + + +# Perform basic inference for shared memory tests +def shm_basic_infer( + tester, + triton_client, + shm_ip0_handle, + shm_ip1_handle, + shm_op0_handle, + shm_op1_handle, + error_msg, + big_shm_name="", + big_shm_size=64, + default_shm_byte_size=64, + shm_output_offset=0, + shm_output_byte_size=64, + protocol="http", + use_system_shared_memory=False, + use_cuda_shared_memory=False, +): + # Lazy shm imports... + if use_system_shared_memory: + import tritonclient.utils.shared_memory as shm + elif use_cuda_shared_memory: + import tritonclient.utils.cuda_shared_memory as cudashm + else: + raise Exception("No shared memory type specified") + + input0_data = np.arange(start=0, stop=16, dtype=np.int32) + input1_data = np.ones(shape=16, dtype=np.int32) + inputs = [] + outputs = [] + if protocol == "http": + inputs.append(httpclient.InferInput("INPUT0", [1, 16], "INT32")) + inputs.append(httpclient.InferInput("INPUT1", [1, 16], "INT32")) + outputs.append(httpclient.InferRequestedOutput("OUTPUT0", binary_data=True)) + outputs.append(httpclient.InferRequestedOutput("OUTPUT1", binary_data=False)) + else: + inputs.append(grpcclient.InferInput("INPUT0", [1, 16], "INT32")) + inputs.append(grpcclient.InferInput("INPUT1", [1, 16], "INT32")) + outputs.append(grpcclient.InferRequestedOutput("OUTPUT0")) + outputs.append(grpcclient.InferRequestedOutput("OUTPUT1")) + + inputs[0].set_shared_memory("input0_data", default_shm_byte_size) + + if type(shm_ip1_handle) == np.array: + inputs[1].set_data_from_numpy(input0_data, binary_data=True) + elif big_shm_name != "": + inputs[1].set_shared_memory(big_shm_name, big_shm_size) + else: + inputs[1].set_shared_memory("input1_data", default_shm_byte_size) + + outputs[0].set_shared_memory( + "output0_data", shm_output_byte_size, offset=shm_output_offset + ) + outputs[1].set_shared_memory( + "output1_data", shm_output_byte_size, offset=shm_output_offset + ) + + try: + results = triton_client.infer( + "simple", inputs, model_version="", outputs=outputs + ) + output = results.get_output("OUTPUT0") + if protocol == "http": + output_datatype = output["datatype"] + output_shape = output["shape"] + else: + output_datatype = output.datatype + output_shape = output.shape + output_dtype = triton_to_np_dtype(output_datatype) + + if use_system_shared_memory: + output_data = shm.get_contents_as_numpy( + shm_op0_handle, output_dtype, output_shape + ) + elif use_cuda_shared_memory: + output_data = cudashm.get_contents_as_numpy( + shm_op0_handle, output_dtype, output_shape + ) + + tester.assertTrue( + (output_data[0] == (input0_data + input1_data)).all(), + "Model output does not match expected output", + ) + except Exception as ex: + error_msg.append(str(ex)) From ec819e04ca641d6d25923da96d88877ccb2991c3 Mon Sep 17 00:00:00 2001 From: krishung5 Date: Wed, 1 May 2024 19:51:57 -0700 Subject: [PATCH 03/12] Refactor --- qa/L0_shared_memory/shared_memory_test.py | 127 +++++++++------------- 1 file changed, 49 insertions(+), 78 deletions(-) diff --git a/qa/L0_shared_memory/shared_memory_test.py b/qa/L0_shared_memory/shared_memory_test.py index 828c714ec6..ca2f2e6abe 100755 --- a/qa/L0_shared_memory/shared_memory_test.py +++ b/qa/L0_shared_memory/shared_memory_test.py @@ -33,6 +33,7 @@ import os import unittest +import infer_util as iu import numpy as np import test_util as tu import tritonclient.grpc as grpcclient @@ -186,84 +187,20 @@ def _cleanup_server(self, shm_handles): for shm_handle in shm_handles: shm.destroy_shared_memory_region(shm_handle) - def _basic_inference( - self, - shm_ip0_handle, - shm_ip1_handle, - shm_op0_handle, - shm_op1_handle, - error_msg, - big_shm_name="", - big_shm_size=DEFAULT_SHM_BYTE_SIZE, - shm_output_offset=0, - shm_output_byte_size=DEFAULT_SHM_BYTE_SIZE, - default_shm_byte_size=DEFAULT_SHM_BYTE_SIZE, - ): - input0_data = np.arange(start=0, stop=16, dtype=np.int32) - input1_data = np.ones(shape=16, dtype=np.int32) - inputs = [] - outputs = [] - if self.protocol == "http": - inputs.append(httpclient.InferInput("INPUT0", [1, 16], "INT32")) - inputs.append(httpclient.InferInput("INPUT1", [1, 16], "INT32")) - outputs.append(httpclient.InferRequestedOutput("OUTPUT0", binary_data=True)) - outputs.append( - httpclient.InferRequestedOutput("OUTPUT1", binary_data=False) - ) - else: - inputs.append(grpcclient.InferInput("INPUT0", [1, 16], "INT32")) - inputs.append(grpcclient.InferInput("INPUT1", [1, 16], "INT32")) - outputs.append(grpcclient.InferRequestedOutput("OUTPUT0")) - outputs.append(grpcclient.InferRequestedOutput("OUTPUT1")) - - inputs[0].set_shared_memory("input0_data", default_shm_byte_size) - - if type(shm_ip1_handle) == np.array: - inputs[1].set_data_from_numpy(input0_data, binary_data=True) - elif big_shm_name != "": - inputs[1].set_shared_memory(big_shm_name, big_shm_size) - else: - inputs[1].set_shared_memory("input1_data", default_shm_byte_size) - - outputs[0].set_shared_memory( - "output0_data", shm_output_byte_size, offset=shm_output_offset - ) - outputs[1].set_shared_memory( - "output1_data", shm_output_byte_size, offset=shm_output_offset - ) - - try: - results = self.triton_client.infer( - "simple", inputs, model_version="", outputs=outputs - ) - output = results.get_output("OUTPUT0") - if self.protocol == "http": - output_datatype = output["datatype"] - output_shape = output["shape"] - else: - output_datatype = output.datatype - output_shape = output.shape - output_dtype = utils.triton_to_np_dtype(output_datatype) - output_data = shm.get_contents_as_numpy( - shm_op0_handle, output_dtype, output_shape - ) - self.assertTrue( - (output_data[0] == (input0_data + input1_data)).all(), - "Model output does not match expected output", - ) - except Exception as ex: - error_msg.append(str(ex)) - def test_unregister_after_inference(self): # Unregister after inference error_msg = [] shm_handles = self._configure_server() - self._basic_inference( + iu.shm_basic_infer( + self, + self.triton_client, shm_handles[0], shm_handles[1], shm_handles[2], shm_handles[3], error_msg, + protocol=self.protocol, + use_system_shared_memory=True, ) if len(error_msg) > 0: raise Exception(str(error_msg)) @@ -279,9 +216,19 @@ def test_register_after_inference(self): # Register after inference error_msg = [] shm_handles = self._configure_server() - self._basic_inference( - shm_handles[0], shm_handles[1], shm_handles[2], shm_handles[3], error_msg + + iu.shm_basic_infer( + self, + self.triton_client, + shm_handles[0], + shm_handles[1], + shm_handles[2], + shm_handles[3], + error_msg, + protocol=self.protocol, + use_system_shared_memory=True, ) + if len(error_msg) > 0: raise Exception(str(error_msg)) shm_ip2_handle = shm.create_shared_memory_region( @@ -308,14 +255,19 @@ def test_too_big_shm(self): self.triton_client.register_system_shared_memory( "input2_data", "/input2_data", 128 ) - self._basic_inference( + + iu.shm_basic_infer( + self, + self.triton_client, shm_handles[0], shm_ip2_handle, shm_handles[2], shm_handles[3], error_msg, - "input2_data", - 128, + big_shm_name="input2_data", + big_shm_size=128, + protocol=self.protocol, + use_system_shared_memory=True, ) if len(error_msg) > 0: self.assertTrue( @@ -330,8 +282,17 @@ def test_mixed_raw_shm(self): error_msg = [] shm_handles = self._configure_server() input1_data = np.ones(shape=16, dtype=np.int32) - self._basic_inference( - shm_handles[0], [input1_data], shm_handles[2], shm_handles[3], error_msg + + iu.shm_basic_infer( + self, + self.triton_client, + shm_handles[0], + [input1_data], + shm_handles[2], + shm_handles[3], + error_msg, + protocol=self.protocol, + use_system_shared_memory=True, ) if len(error_msg) > 0: raise Exception(error_msg[-1]) @@ -365,14 +326,20 @@ def test_infer_offset_out_of_bound(self): # gRPC will throw an error if > 2**63 - 1, so instead test for # exceeding shm region size by 1 byte, given its size is 64 bytes offset = 64 - self._basic_inference( + + iu.shm_basic_infer( + self, + self.triton_client, shm_handles[0], shm_handles[1], shm_handles[2], shm_handles[3], error_msg, shm_output_offset=offset, + protocol=self.protocol, + use_system_shared_memory=True, ) + self.assertEqual(len(error_msg), 1) self.assertIn("Invalid offset for shared memory region", error_msg[0]) self._cleanup_server(shm_handles) @@ -384,7 +351,9 @@ def test_infer_byte_size_out_of_bound(self): offset = 60 byte_size = self.DEFAULT_SHM_BYTE_SIZE - self._basic_inference( + iu.shm_basic_infer( + self, + self.triton_client, shm_handles[0], shm_handles[1], shm_handles[2], @@ -392,6 +361,8 @@ def test_infer_byte_size_out_of_bound(self): error_msg, shm_output_offset=offset, shm_output_byte_size=byte_size, + protocol=self.protocol, + use_system_shared_memory=True, ) self.assertEqual(len(error_msg), 1) self.assertIn( From 96e4c0445d1329ee006fac6c58fdb39cbbf3e8a0 Mon Sep 17 00:00:00 2001 From: krishung5 Date: Wed, 1 May 2024 19:56:55 -0700 Subject: [PATCH 04/12] Fix error message --- qa/L0_cuda_shared_memory/cuda_shared_memory_test.py | 2 +- src/shared_memory_manager.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/qa/L0_cuda_shared_memory/cuda_shared_memory_test.py b/qa/L0_cuda_shared_memory/cuda_shared_memory_test.py index 231d9444d3..ce097cb4c7 100755 --- a/qa/L0_cuda_shared_memory/cuda_shared_memory_test.py +++ b/qa/L0_cuda_shared_memory/cuda_shared_memory_test.py @@ -331,7 +331,7 @@ def test_register_out_of_bound(self): # Verify various edge cases of registered region size don't go out of bounds of the actual created shm region's size. with self.assertRaisesRegex( InferenceServerException, - "failed to register cuda memory region.*register size exceeds cuda shared memory region size", + r"failed to register CUDA shared memory region.*register size exceeds CUDA shared memory region size", ): self._configure_server( create_byte_size=create_byte_size, diff --git a/src/shared_memory_manager.cc b/src/shared_memory_manager.cc index f290dded5b..8b9ba8e4fa 100644 --- a/src/shared_memory_manager.cc +++ b/src/shared_memory_manager.cc @@ -298,7 +298,7 @@ CheckCudaSharedMemoryRegionSize( TRITONSERVER_ERROR_INVALID_ARG, std::string( "failed to register CUDA shared memory region '" + name + - "': register size exceeds shared memory region size") + "': register size exceeds CUDA shared memory region size") .c_str()); } From 81fc5d0cf23630c048783f77167428474c5e560b Mon Sep 17 00:00:00 2001 From: krishung5 Date: Fri, 3 May 2024 11:14:50 -0700 Subject: [PATCH 05/12] Address comment --- .../cuda_shared_memory_test.py | 2 +- src/shared_memory_manager.cc | 16 +++++----------- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/qa/L0_cuda_shared_memory/cuda_shared_memory_test.py b/qa/L0_cuda_shared_memory/cuda_shared_memory_test.py index ce097cb4c7..a19ca411df 100755 --- a/qa/L0_cuda_shared_memory/cuda_shared_memory_test.py +++ b/qa/L0_cuda_shared_memory/cuda_shared_memory_test.py @@ -331,7 +331,7 @@ def test_register_out_of_bound(self): # Verify various edge cases of registered region size don't go out of bounds of the actual created shm region's size. with self.assertRaisesRegex( InferenceServerException, - r"failed to register CUDA shared memory region.*register size exceeds CUDA shared memory region size", + "failed to register shared memory region.*invalid args", ): self._configure_server( create_byte_size=create_byte_size, diff --git a/src/shared_memory_manager.cc b/src/shared_memory_manager.cc index 8b9ba8e4fa..4c07ed9275 100644 --- a/src/shared_memory_manager.cc +++ b/src/shared_memory_manager.cc @@ -297,8 +297,8 @@ CheckCudaSharedMemoryRegionSize( return TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_INVALID_ARG, std::string( - "failed to register CUDA shared memory region '" + name + - "': register size exceeds CUDA shared memory region size") + "failed to register shared memory region '" + name + + "': invalid args") .c_str()); } @@ -406,25 +406,19 @@ SharedMemoryManager::RegisterCUDASharedMemory( TRITONSERVER_Error* err = OpenCudaIPCRegion(cuda_shm_handle, &mapped_addr, device_id); if (err != nullptr) { + std::string err_str = TRITONSERVER_ErrorMessage(err); + TRITONSERVER_ErrorDelete(err); return TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_INVALID_ARG, std::string( "failed to open CUDA shared memory region '" + name + - "': " + TRITONSERVER_ErrorMessage(err)) + "': " + err_str) .c_str()); } // Enforce that registered region is in-bounds of shm file object. RETURN_IF_ERR(CheckCudaSharedMemoryRegionSize( name, reinterpret_cast(mapped_addr), byte_size)); - if (err != nullptr) { - return TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INVALID_ARG, - std::string( - "failed to register CUDA shared memory region '" + name + - "': " + TRITONSERVER_ErrorMessage(err)) - .c_str()); - } shared_memory_map_.insert(std::make_pair( name, std::unique_ptr(new CUDASharedMemoryInfo( From 0d39ce390686a4e88a5720769dc9c86feb78f62a Mon Sep 17 00:00:00 2001 From: krishung5 Date: Fri, 3 May 2024 12:23:36 -0700 Subject: [PATCH 06/12] Address comment --- .../cuda_shared_memory_test.py | 2 +- src/shared_memory_manager.cc | 15 ++------------- 2 files changed, 3 insertions(+), 14 deletions(-) diff --git a/qa/L0_cuda_shared_memory/cuda_shared_memory_test.py b/qa/L0_cuda_shared_memory/cuda_shared_memory_test.py index a19ca411df..83c847ca4a 100755 --- a/qa/L0_cuda_shared_memory/cuda_shared_memory_test.py +++ b/qa/L0_cuda_shared_memory/cuda_shared_memory_test.py @@ -331,7 +331,7 @@ def test_register_out_of_bound(self): # Verify various edge cases of registered region size don't go out of bounds of the actual created shm region's size. with self.assertRaisesRegex( InferenceServerException, - "failed to register shared memory region.*invalid args", + "failed to register CUDA shared memory region.*invalid args", ): self._configure_server( create_byte_size=create_byte_size, diff --git a/src/shared_memory_manager.cc b/src/shared_memory_manager.cc index 4c07ed9275..18200a9e96 100644 --- a/src/shared_memory_manager.cc +++ b/src/shared_memory_manager.cc @@ -297,7 +297,7 @@ CheckCudaSharedMemoryRegionSize( return TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_INVALID_ARG, std::string( - "failed to register shared memory region '" + name + + "failed to register CUDA shared memory region '" + name + "': invalid args") .c_str()); } @@ -403,18 +403,7 @@ SharedMemoryManager::RegisterCUDASharedMemory( void* mapped_addr; // Get CUDA shared memory base address - TRITONSERVER_Error* err = - OpenCudaIPCRegion(cuda_shm_handle, &mapped_addr, device_id); - if (err != nullptr) { - std::string err_str = TRITONSERVER_ErrorMessage(err); - TRITONSERVER_ErrorDelete(err); - return TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INVALID_ARG, - std::string( - "failed to open CUDA shared memory region '" + name + - "': " + err_str) - .c_str()); - } + RETURN_IF_ERR(OpenCudaIPCRegion(cuda_shm_handle, &mapped_addr, device_id)); // Enforce that registered region is in-bounds of shm file object. RETURN_IF_ERR(CheckCudaSharedMemoryRegionSize( From 5079b24f463cbf56f93762a8787a41df435619ec Mon Sep 17 00:00:00 2001 From: krishung5 Date: Fri, 3 May 2024 15:39:58 -0700 Subject: [PATCH 07/12] Address comment --- qa/L0_cuda_shared_memory/cuda_shared_memory_test.py | 2 +- src/shared_memory_manager.cc | 9 ++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/qa/L0_cuda_shared_memory/cuda_shared_memory_test.py b/qa/L0_cuda_shared_memory/cuda_shared_memory_test.py index 83c847ca4a..a19ca411df 100755 --- a/qa/L0_cuda_shared_memory/cuda_shared_memory_test.py +++ b/qa/L0_cuda_shared_memory/cuda_shared_memory_test.py @@ -331,7 +331,7 @@ def test_register_out_of_bound(self): # Verify various edge cases of registered region size don't go out of bounds of the actual created shm region's size. with self.assertRaisesRegex( InferenceServerException, - "failed to register CUDA shared memory region.*invalid args", + "failed to register shared memory region.*invalid args", ): self._configure_server( create_byte_size=create_byte_size, diff --git a/src/shared_memory_manager.cc b/src/shared_memory_manager.cc index 18200a9e96..b419c7ddda 100644 --- a/src/shared_memory_manager.cc +++ b/src/shared_memory_manager.cc @@ -276,11 +276,10 @@ GetCudaSharedMemoryRegionSize(CUdeviceptr data_ptr, size_t& shm_region_size) return TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_INTERNAL, "Failed to get CUDA error string"); } + // Should not pass the detailed error message back to the client. + LOG_ERROR << "Failed to get CUDA address range: " << errorString; return TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INTERNAL, - std::string( - "Failed to get CUDA address range: " + std::string(errorString)) - .c_str()); + TRITONSERVER_ERROR_INTERNAL, "Failed to get CUDA address range"); } return nullptr; } @@ -297,7 +296,7 @@ CheckCudaSharedMemoryRegionSize( return TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_INVALID_ARG, std::string( - "failed to register CUDA shared memory region '" + name + + "failed to register shared memory region '" + name + "': invalid args") .c_str()); } From f66a7f1eddd0859080585e131d3d9972cc46d56c Mon Sep 17 00:00:00 2001 From: krishung5 Date: Mon, 6 May 2024 12:35:18 -0700 Subject: [PATCH 08/12] Make detailed error internal. Only pass the general error message to the client --- src/shared_memory_manager.cc | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/shared_memory_manager.cc b/src/shared_memory_manager.cc index b419c7ddda..ce0c6d93ef 100644 --- a/src/shared_memory_manager.cc +++ b/src/shared_memory_manager.cc @@ -276,10 +276,11 @@ GetCudaSharedMemoryRegionSize(CUdeviceptr data_ptr, size_t& shm_region_size) return TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_INTERNAL, "Failed to get CUDA error string"); } - // Should not pass the detailed error message back to the client. - LOG_ERROR << "Failed to get CUDA address range: " << errorString; return TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INTERNAL, "Failed to get CUDA address range"); + TRITONSERVER_ERROR_INTERNAL, + std::string( + "Failed to get CUDA address range: " + std::string(errorString)) + .c_str()); } return nullptr; } @@ -289,10 +290,15 @@ CheckCudaSharedMemoryRegionSize( const std::string& name, CUdeviceptr data_ptr, size_t byte_size) { size_t shm_region_size = 0; - RETURN_IF_ERR(GetCudaSharedMemoryRegionSize(data_ptr, shm_region_size)); + auto err = GetCudaSharedMemoryRegionSize(data_ptr, shm_region_size); // User-provided offset and byte_size should not go out-of-bounds. - if (byte_size > shm_region_size) { + if (err != nullptr || byte_size > shm_region_size) { + if (err != nullptr) { + // Should not pass the detailed error message back to the client. + LOG_ERROR << TRITONSERVER_ErrorMessage(err); + TRITONSERVER_ErrorDelete(err); + } return TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_INVALID_ARG, std::string( From 9f12ee85b65a77cea3efeb0a2976df5d01405755 Mon Sep 17 00:00:00 2001 From: krishung5 Date: Mon, 6 May 2024 17:19:15 -0700 Subject: [PATCH 09/12] Replace 64byte with DEFAULT_SHM_BYTE_SIZE --- qa/L0_cuda_shared_memory/cuda_shared_memory_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/qa/L0_cuda_shared_memory/cuda_shared_memory_test.py b/qa/L0_cuda_shared_memory/cuda_shared_memory_test.py index a19ca411df..0c877c8749 100755 --- a/qa/L0_cuda_shared_memory/cuda_shared_memory_test.py +++ b/qa/L0_cuda_shared_memory/cuda_shared_memory_test.py @@ -327,7 +327,7 @@ def test_unregisterall(self): self._cleanup_server(shm_handles) def test_register_out_of_bound(self): - create_byte_size = 64 + create_byte_size = self.DEFAULT_SHM_BYTE_SIZE # Verify various edge cases of registered region size don't go out of bounds of the actual created shm region's size. with self.assertRaisesRegex( InferenceServerException, From 04711b23eb9d9926b4ecdc7a7d63f5559e58ec01 Mon Sep 17 00:00:00 2001 From: krishung5 Date: Mon, 6 May 2024 18:33:19 -0700 Subject: [PATCH 10/12] Fix L0_grpc --- qa/L0_grpc/python_grpc_aio_test.py | 4 ++-- src/shared_memory_manager.cc | 9 +++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/qa/L0_grpc/python_grpc_aio_test.py b/qa/L0_grpc/python_grpc_aio_test.py index f342f19ad5..ba43b36abb 100755 --- a/qa/L0_grpc/python_grpc_aio_test.py +++ b/qa/L0_grpc/python_grpc_aio_test.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -# Copyright 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions @@ -113,7 +113,7 @@ async def test_get_cuda_shared_memory_status(self): async def test_register_cuda_shared_memory(self): with self.assertRaisesRegex( InferenceServerException, - "\[StatusCode\.INVALID_ARGUMENT\] failed to register CUDA shared memory region '': failed to open CUDA IPC handle: invalid argument", + "failed to register shared memory region.*invalid args", ): await self._triton_client.register_cuda_shared_memory("", b"", 0, 0) diff --git a/src/shared_memory_manager.cc b/src/shared_memory_manager.cc index ce0c6d93ef..52b5fc9d6f 100644 --- a/src/shared_memory_manager.cc +++ b/src/shared_memory_manager.cc @@ -255,11 +255,12 @@ OpenCudaIPCRegion( cudaError_t err = cudaIpcOpenMemHandle( data_ptr, *cuda_shm_handle, cudaIpcMemLazyEnablePeerAccess); if (err != cudaSuccess) { + // Should not pass the detailed error message back to the client. + LOG_ERROR << "failed to open CUDA IPC handle: " << cudaGetErrorString(err); return TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INTERNAL, std::string( - "failed to open CUDA IPC handle: " + - std::string(cudaGetErrorString(err))) - .c_str()); + TRITONSERVER_ERROR_INVALID_ARG, + std::string("failed to register shared memory region: invalid args") + .c_str()); } return nullptr; From b55d343f3089a579bde76d310cb7995f7455acf4 Mon Sep 17 00:00:00 2001 From: krishung5 Date: Tue, 7 May 2024 09:16:11 -0700 Subject: [PATCH 11/12] Update comments --- src/shared_memory_manager.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/shared_memory_manager.cc b/src/shared_memory_manager.cc index 52b5fc9d6f..aafe092e81 100644 --- a/src/shared_memory_manager.cc +++ b/src/shared_memory_manager.cc @@ -255,7 +255,7 @@ OpenCudaIPCRegion( cudaError_t err = cudaIpcOpenMemHandle( data_ptr, *cuda_shm_handle, cudaIpcMemLazyEnablePeerAccess); if (err != cudaSuccess) { - // Should not pass the detailed error message back to the client. + // Log detailed error message LOG_ERROR << "failed to open CUDA IPC handle: " << cudaGetErrorString(err); return TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_INVALID_ARG, @@ -296,7 +296,7 @@ CheckCudaSharedMemoryRegionSize( // User-provided offset and byte_size should not go out-of-bounds. if (err != nullptr || byte_size > shm_region_size) { if (err != nullptr) { - // Should not pass the detailed error message back to the client. + // Log detailed error message LOG_ERROR << TRITONSERVER_ErrorMessage(err); TRITONSERVER_ErrorDelete(err); } From 1277853e050ad336a527723f0ba25d29514e2e79 Mon Sep 17 00:00:00 2001 From: krishung5 Date: Tue, 7 May 2024 09:17:25 -0700 Subject: [PATCH 12/12] Update comments --- src/shared_memory_manager.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/shared_memory_manager.cc b/src/shared_memory_manager.cc index aafe092e81..8101a2e236 100644 --- a/src/shared_memory_manager.cc +++ b/src/shared_memory_manager.cc @@ -255,7 +255,7 @@ OpenCudaIPCRegion( cudaError_t err = cudaIpcOpenMemHandle( data_ptr, *cuda_shm_handle, cudaIpcMemLazyEnablePeerAccess); if (err != cudaSuccess) { - // Log detailed error message + // Log detailed error message and send generic error to client LOG_ERROR << "failed to open CUDA IPC handle: " << cudaGetErrorString(err); return TRITONSERVER_ErrorNew( TRITONSERVER_ERROR_INVALID_ARG, @@ -296,7 +296,7 @@ CheckCudaSharedMemoryRegionSize( // User-provided offset and byte_size should not go out-of-bounds. if (err != nullptr || byte_size > shm_region_size) { if (err != nullptr) { - // Log detailed error message + // Log detailed error message and send generic error to client LOG_ERROR << TRITONSERVER_ErrorMessage(err); TRITONSERVER_ErrorDelete(err); }