Skip to content

Commit

Permalink
fix: Add reference count tracking for shared memory regions (#7567)
Browse files Browse the repository at this point in the history
Co-authored-by: GuanLuo <[email protected]>
  • Loading branch information
pskiran1 and GuanLuo committed Sep 11, 2024
1 parent e39048d commit da52f1c
Show file tree
Hide file tree
Showing 14 changed files with 886 additions and 222 deletions.
312 changes: 240 additions & 72 deletions qa/L0_cuda_shared_memory/cuda_shared_memory_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,20 @@
sys.path.append("../common")

import os
import time
import unittest
from functools import partial

import infer_util as iu
import numpy as np
import test_util as tu
import tritonclient.grpc as grpcclient
import tritonclient.http as httpclient
import tritonshmutils.cuda_shared_memory as cshm
import tritonclient.utils.cuda_shared_memory as cshm
from tritonclient.utils import *


class CudaSharedMemoryTest(tu.TestResultCollector):
class CudaSharedMemoryTestBase(tu.TestResultCollector):
DEFAULT_SHM_BYTE_SIZE = 64

def setUp(self):
Expand All @@ -61,76 +63,6 @@ def _setup_client(self):
self.url, verbose=True
)

def test_invalid_create_shm(self):
# Raises error since tried to create invalid cuda shared memory region
try:
shm_op0_handle = cshm.create_shared_memory_region("dummy_data", -1, 0)
cshm.destroy_shared_memory_region(shm_op0_handle)
except Exception as ex:
self.assertEqual(str(ex), "unable to create cuda shared memory handle")

def test_valid_create_set_register(self):
# Create a valid cuda shared memory region, fill data in it and register
shm_op0_handle = cshm.create_shared_memory_region("dummy_data", 8, 0)
cshm.set_shared_memory_region(
shm_op0_handle, [np.array([1, 2], dtype=np.float32)]
)
self.triton_client.register_cuda_shared_memory(
"dummy_data", cshm.get_raw_handle(shm_op0_handle), 0, 8
)
shm_status = self.triton_client.get_cuda_shared_memory_status()
if self.protocol == "http":
self.assertEqual(len(shm_status), 1)
else:
self.assertEqual(len(shm_status.regions), 1)
cshm.destroy_shared_memory_region(shm_op0_handle)

def test_unregister_before_register(self):
# Create a valid cuda shared memory region and unregister before register
shm_op0_handle = cshm.create_shared_memory_region("dummy_data", 8, 0)
self.triton_client.unregister_cuda_shared_memory("dummy_data")
shm_status = self.triton_client.get_cuda_shared_memory_status()
if self.protocol == "http":
self.assertEqual(len(shm_status), 0)
else:
self.assertEqual(len(shm_status.regions), 0)
cshm.destroy_shared_memory_region(shm_op0_handle)

def test_unregister_after_register(self):
# Create a valid cuda shared memory region and unregister after register
shm_op0_handle = cshm.create_shared_memory_region("dummy_data", 8, 0)
self.triton_client.register_cuda_shared_memory(
"dummy_data", cshm.get_raw_handle(shm_op0_handle), 0, 8
)
self.triton_client.unregister_cuda_shared_memory("dummy_data")
shm_status = self.triton_client.get_cuda_shared_memory_status()
if self.protocol == "http":
self.assertEqual(len(shm_status), 0)
else:
self.assertEqual(len(shm_status.regions), 0)
cshm.destroy_shared_memory_region(shm_op0_handle)

def test_reregister_after_register(self):
# Create a valid cuda shared memory region and unregister after register
shm_op0_handle = cshm.create_shared_memory_region("dummy_data", 8, 0)
self.triton_client.register_cuda_shared_memory(
"dummy_data", cshm.get_raw_handle(shm_op0_handle), 0, 8
)
try:
self.triton_client.register_cuda_shared_memory(
"dummy_data", cshm.get_raw_handle(shm_op0_handle), 0, 8
)
except Exception as ex:
self.assertIn(
"shared memory region 'dummy_data' already in manager", str(ex)
)
shm_status = self.triton_client.get_cuda_shared_memory_status()
if self.protocol == "http":
self.assertEqual(len(shm_status), 1)
else:
self.assertEqual(len(shm_status.regions), 1)
cshm.destroy_shared_memory_region(shm_op0_handle)

def _configure_server(
self,
create_byte_size=DEFAULT_SHM_BYTE_SIZE,
Expand Down Expand Up @@ -205,6 +137,78 @@ def _cleanup_server(self, shm_handles):
for shm_handle in shm_handles:
cshm.destroy_shared_memory_region(shm_handle)


class CudaSharedMemoryTest(CudaSharedMemoryTestBase):
def test_invalid_create_shm(self):
# Raises error since tried to create invalid cuda shared memory region
try:
shm_op0_handle = cshm.create_shared_memory_region("dummy_data", -1, 0)
cshm.destroy_shared_memory_region(shm_op0_handle)
except Exception as ex:
self.assertEqual(str(ex), "unable to create cuda shared memory handle")

def test_valid_create_set_register(self):
# Create a valid cuda shared memory region, fill data in it and register
shm_op0_handle = cshm.create_shared_memory_region("dummy_data", 8, 0)
cshm.set_shared_memory_region(
shm_op0_handle, [np.array([1, 2], dtype=np.float32)]
)
self.triton_client.register_cuda_shared_memory(
"dummy_data", cshm.get_raw_handle(shm_op0_handle), 0, 8
)
shm_status = self.triton_client.get_cuda_shared_memory_status()
if self.protocol == "http":
self.assertEqual(len(shm_status), 1)
else:
self.assertEqual(len(shm_status.regions), 1)
cshm.destroy_shared_memory_region(shm_op0_handle)

def test_unregister_before_register(self):
# Create a valid cuda shared memory region and unregister before register
shm_op0_handle = cshm.create_shared_memory_region("dummy_data", 8, 0)
self.triton_client.unregister_cuda_shared_memory("dummy_data")
shm_status = self.triton_client.get_cuda_shared_memory_status()
if self.protocol == "http":
self.assertEqual(len(shm_status), 0)
else:
self.assertEqual(len(shm_status.regions), 0)
cshm.destroy_shared_memory_region(shm_op0_handle)

def test_unregister_after_register(self):
# Create a valid cuda shared memory region and unregister after register
shm_op0_handle = cshm.create_shared_memory_region("dummy_data", 8, 0)
self.triton_client.register_cuda_shared_memory(
"dummy_data", cshm.get_raw_handle(shm_op0_handle), 0, 8
)
self.triton_client.unregister_cuda_shared_memory("dummy_data")
shm_status = self.triton_client.get_cuda_shared_memory_status()
if self.protocol == "http":
self.assertEqual(len(shm_status), 0)
else:
self.assertEqual(len(shm_status.regions), 0)
cshm.destroy_shared_memory_region(shm_op0_handle)

def test_reregister_after_register(self):
# Create a valid cuda shared memory region and unregister after register
shm_op0_handle = cshm.create_shared_memory_region("dummy_data", 8, 0)
self.triton_client.register_cuda_shared_memory(
"dummy_data", cshm.get_raw_handle(shm_op0_handle), 0, 8
)
try:
self.triton_client.register_cuda_shared_memory(
"dummy_data", cshm.get_raw_handle(shm_op0_handle), 0, 8
)
except Exception as ex:
self.assertIn(
"shared memory region 'dummy_data' already in manager", str(ex)
)
shm_status = self.triton_client.get_cuda_shared_memory_status()
if self.protocol == "http":
self.assertEqual(len(shm_status), 1)
else:
self.assertEqual(len(shm_status.regions), 1)
cshm.destroy_shared_memory_region(shm_op0_handle)

def test_unregister_after_inference(self):
# Unregister after inference
error_msg = []
Expand Down Expand Up @@ -396,5 +400,169 @@ def test_infer_byte_size_out_of_bound(self):
self._cleanup_server(shm_handles)


class TestCudaSharedMemoryUnregister(CudaSharedMemoryTestBase):
    """Verify regions held by an in-flight inference cannot be unregistered."""

    # Region names created by _configure_server and probed by the helpers below.
    _REGION_NAMES = ("input0_data", "input1_data", "output0_data", "output1_data")

    def _test_unregister_shm_fail(self):
        """From a second client, every unregister attempt must fail while inference is in flight."""
        second_client = httpclient.InferenceServerClient("localhost:8000", verbose=True)

        # Unregistering everything at once must report all in-use regions.
        with self.assertRaises(InferenceServerException) as ex:
            second_client.unregister_cuda_shared_memory()
        self.assertIn(
            "Failed to unregister the following cuda shared memory regions: input0_data ,input1_data ,output0_data ,output1_data",
            str(ex.exception),
        )

        # Unregistering each region individually must fail as well.
        for region in self._REGION_NAMES:
            with self.assertRaises(InferenceServerException) as ex:
                second_client.unregister_cuda_shared_memory(region)
            self.assertIn(
                f"Cannot unregister shared memory region '{region}', it is currently in use.",
                str(ex.exception),
            )

    def _test_shm_not_found(self):
        """After unregistration, a status query for each region must fail."""
        second_client = httpclient.InferenceServerClient("localhost:8000", verbose=True)

        for region in self._REGION_NAMES:
            with self.assertRaises(InferenceServerException) as ex:
                second_client.get_cuda_shared_memory_status(region)
            self.assertIn(
                f"Unable to find cuda shared memory region: '{region}'",
                str(ex.exception),
            )

    def test_unregister_shm_during_inference_http(self):
        """Unregister attempts during an async HTTP inference must be rejected."""
        # Defined up front so the finally-block cleanup cannot raise a
        # NameError (masking the real failure) if setup below fails.
        shm_handles = []
        try:
            self.triton_client.unregister_cuda_shared_memory()
            shm_handles = self._configure_server()

            inputs = [
                httpclient.InferInput("INPUT0", [1, 16], "INT32"),
                httpclient.InferInput("INPUT1", [1, 16], "INT32"),
            ]
            outputs = [
                httpclient.InferRequestedOutput("OUTPUT0", binary_data=True),
                httpclient.InferRequestedOutput("OUTPUT1", binary_data=False),
            ]

            inputs[0].set_shared_memory("input0_data", self.DEFAULT_SHM_BYTE_SIZE)
            inputs[1].set_shared_memory("input1_data", self.DEFAULT_SHM_BYTE_SIZE)
            outputs[0].set_shared_memory("output0_data", self.DEFAULT_SHM_BYTE_SIZE)
            outputs[1].set_shared_memory("output1_data", self.DEFAULT_SHM_BYTE_SIZE)

            async_request = self.triton_client.async_infer(
                model_name="simple", inputs=inputs, outputs=outputs
            )

            # Give the (deliberately delayed) model time to pick up the request.
            time.sleep(2)

            # The in-flight request holds the regions: unregister must fail.
            self._test_unregister_shm_fail()

            # Block until the inference completes and releases the regions.
            async_request.get_result()

            # Now unregistration must succeed and the regions must be gone.
            self.triton_client.unregister_cuda_shared_memory()
            self._test_shm_not_found()

        finally:
            self._cleanup_server(shm_handles)

    def test_unregister_shm_during_inference_grpc(self):
        """Unregister attempts during an async gRPC inference must be rejected."""
        # Defined up front so the finally-block cleanup cannot raise a
        # NameError (masking the real failure) if setup below fails.
        shm_handles = []
        try:
            self.triton_client.unregister_cuda_shared_memory()
            shm_handles = self._configure_server()

            inputs = [
                grpcclient.InferInput("INPUT0", [1, 16], "INT32"),
                grpcclient.InferInput("INPUT1", [1, 16], "INT32"),
            ]
            outputs = [
                grpcclient.InferRequestedOutput("OUTPUT0"),
                grpcclient.InferRequestedOutput("OUTPUT1"),
            ]

            inputs[0].set_shared_memory("input0_data", self.DEFAULT_SHM_BYTE_SIZE)
            inputs[1].set_shared_memory("input1_data", self.DEFAULT_SHM_BYTE_SIZE)
            outputs[0].set_shared_memory("output0_data", self.DEFAULT_SHM_BYTE_SIZE)
            outputs[1].set_shared_memory("output1_data", self.DEFAULT_SHM_BYTE_SIZE)

            user_data = []

            def callback(user_data, result, error):
                # Deliver either the error or the result; the poll loop
                # below only waits for something to arrive.
                user_data.append(error if error else result)

            self.triton_client.async_infer(
                model_name="simple",
                inputs=inputs,
                outputs=outputs,
                callback=partial(callback, user_data),
            )

            # Give the (deliberately delayed) model time to pick up the request.
            time.sleep(2)

            # The in-flight request holds the regions: unregister must fail.
            self._test_unregister_shm_fail()

            # Poll (up to 20s) until the callback has fired, then allow a
            # short grace period for the server to drop its references.
            time_out = 20
            while (len(user_data) == 0) and time_out > 0:
                time_out = time_out - 1
                time.sleep(1)
            time.sleep(2)

            # The request must have completed, and completed successfully,
            # before asserting the regions are releasable.
            self.assertEqual(len(user_data), 1)
            self.assertNotIsInstance(user_data[0], InferenceServerException)

            # Now unregistration must succeed and the regions must be gone.
            self.triton_client.unregister_cuda_shared_memory()
            self._test_shm_not_found()

        finally:
            self._cleanup_server(shm_handles)


if __name__ == "__main__":
unittest.main()
41 changes: 41 additions & 0 deletions qa/L0_cuda_shared_memory/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,47 @@ for i in \
done
done

# Build a model repository containing a delayed-execution copy of the "simple"
# model, so an inference is guaranteed to still be in flight when the Python
# test tries to unregister its CUDA shared memory regions.
mkdir -p python_models/simple/1/
cp ../python_models/execute_delayed_model/model.py ./python_models/simple/1/
cp ../python_models/execute_delayed_model/config.pbtxt ./python_models/simple/
# Run the delayed model on GPU so it exercises CUDA shared memory.
sed -i 's/KIND_CPU/KIND_GPU/g' ./python_models/simple/config.pbtxt

# Run the unregister-during-inference test once per client protocol.
for client_type in http grpc; do
SERVER_ARGS="--model-repository=`pwd`/python_models --log-verbose=1 ${SERVER_ARGS_EXTRA}"
SERVER_LOG="./unregister_shm.$client_type.server.log"
run_server
if [ "$SERVER_PID" == "0" ]; then
echo -e "\n***\n*** Failed to start $SERVER\n***"
cat $SERVER_LOG
exit 1
fi

# NOTE(review): CLIENT_TYPE is presumably read by the test's setUp to pick the
# http or grpc client — setUp is not visible in this diff; confirm.
export CLIENT_TYPE=$client_type
CLIENT_LOG="./unregister_shm.$client_type.client.log"
set +e
# Run only the protocol-specific unregister-during-inference test case.
python3 $SHM_TEST TestCudaSharedMemoryUnregister.test_unregister_shm_during_inference_$client_type >>$CLIENT_LOG 2>&1
if [ $? -ne 0 ]; then
cat $CLIENT_LOG
echo -e "\n***\n*** Test Failed\n***"
RET=1
else
# Expect exactly one test case to have run and passed.
check_test_results $TEST_RESULT_FILE 1
if [ $? -ne 0 ]; then
cat $TEST_RESULT_FILE
echo -e "\n***\n*** Test Result Verification Failed\n***"
RET=1
fi
fi

kill $SERVER_PID
wait $SERVER_PID
if [ $? -ne 0 ]; then
echo -e "\n***\n*** Test Server shut down non-gracefully\n***"
RET=1
fi
set -e
done

if [ $RET -eq 0 ]; then
echo -e "\n***\n*** Test Passed\n***"
else
Expand Down
Loading

0 comments on commit da52f1c

Please sign in to comment.