Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Add reference count tracking for shared memory regions #7567

Merged
merged 26 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
7dad71b
Add shm reference counter support
pskiran1 Aug 26, 2024
8fd831c
Update
pskiran1 Aug 26, 2024
58aa7fe
Update copyright
pskiran1 Aug 26, 2024
aa0b282
Fix pre-commit errors
pskiran1 Aug 26, 2024
7b057d2
Merge branch 'main' of https://github.com/triton-inference-server/ser…
pskiran1 Aug 27, 2024
20edd69
Merge branch 'spolisetty_oob_dos_issue_fix' of https://github.com/tri…
pskiran1 Aug 27, 2024
6c4424b
Merge branch 'main' of https://github.com/triton-inference-server/ser…
pskiran1 Aug 29, 2024
0d7bf1a
Enhancements
pskiran1 Aug 29, 2024
bf9e87c
Fix pre-commit errors
pskiran1 Aug 29, 2024
862d01d
Fix alert
pskiran1 Aug 29, 2024
8290e87
Fix pre-commit errors
pskiran1 Aug 29, 2024
d849132
Update
pskiran1 Aug 30, 2024
8266aed
Update
pskiran1 Aug 30, 2024
c56725f
Update
pskiran1 Aug 30, 2024
481b2ff
Undo formatting
pskiran1 Aug 30, 2024
d0fb4b9
Fix errors
pskiran1 Aug 30, 2024
ba4a1bc
Add copyright
pskiran1 Aug 31, 2024
cc1ce14
Merge branch 'main' of https://github.com/triton-inference-server/ser…
pskiran1 Aug 31, 2024
094d846
Enhancements
pskiran1 Sep 1, 2024
2215904
Fix pre-commit
pskiran1 Sep 1, 2024
474f344
Update names
pskiran1 Sep 3, 2024
813e3c9
Merge branch 'main' of https://github.com/triton-inference-server/ser…
pskiran1 Sep 4, 2024
8669859
Merge branch 'main' of https://github.com/triton-inference-server/ser…
pskiran1 Sep 6, 2024
cfb50cb
Update src/grpc/infer_handler.h
pskiran1 Sep 6, 2024
f42368f
Merge branch 'spolisetty_oob_dos_issue_fix' of https://github.com/tri…
pskiran1 Sep 6, 2024
0be4137
Update
pskiran1 Sep 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
312 changes: 240 additions & 72 deletions qa/L0_cuda_shared_memory/cuda_shared_memory_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,20 @@
sys.path.append("../common")

import os
import time
import unittest
from functools import partial

import infer_util as iu
import numpy as np
import test_util as tu
import tritonclient.grpc as grpcclient
import tritonclient.http as httpclient
import tritonshmutils.cuda_shared_memory as cshm
import tritonclient.utils.cuda_shared_memory as cshm
from tritonclient.utils import *


class CudaSharedMemoryTest(tu.TestResultCollector):
class CudaSharedMemoryTestBase(tu.TestResultCollector):
DEFAULT_SHM_BYTE_SIZE = 64

def setUp(self):
Expand All @@ -61,76 +63,6 @@ def _setup_client(self):
self.url, verbose=True
)

def test_invalid_create_shm(self):
    """Creating a CUDA shm region with a negative byte size must fail."""
    # assertRaises fails the test if no exception is raised at all; the
    # previous try/except pattern passed vacuously when creation succeeded.
    with self.assertRaises(Exception) as ex:
        shm_op0_handle = cshm.create_shared_memory_region("dummy_data", -1, 0)
        cshm.destroy_shared_memory_region(shm_op0_handle)
    self.assertEqual(str(ex.exception), "unable to create cuda shared memory handle")

def test_valid_create_set_register(self):
    """Create a valid CUDA shm region, fill it with data and register it."""
    handle = cshm.create_shared_memory_region("dummy_data", 8, 0)
    cshm.set_shared_memory_region(handle, [np.array([1, 2], dtype=np.float32)])
    self.triton_client.register_cuda_shared_memory(
        "dummy_data", cshm.get_raw_handle(handle), 0, 8
    )
    # Exactly one region should now be visible in the server status.
    status = self.triton_client.get_cuda_shared_memory_status()
    region_count = len(status) if self.protocol == "http" else len(status.regions)
    self.assertEqual(region_count, 1)
    cshm.destroy_shared_memory_region(handle)

def test_unregister_before_register(self):
    """Unregistering a region that was never registered leaves no regions."""
    handle = cshm.create_shared_memory_region("dummy_data", 8, 0)
    self.triton_client.unregister_cuda_shared_memory("dummy_data")
    # The server should report zero registered regions.
    status = self.triton_client.get_cuda_shared_memory_status()
    region_count = len(status) if self.protocol == "http" else len(status.regions)
    self.assertEqual(region_count, 0)
    cshm.destroy_shared_memory_region(handle)

def test_unregister_after_register(self):
    """Register a region, unregister it, and verify nothing remains."""
    handle = cshm.create_shared_memory_region("dummy_data", 8, 0)
    self.triton_client.register_cuda_shared_memory(
        "dummy_data", cshm.get_raw_handle(handle), 0, 8
    )
    self.triton_client.unregister_cuda_shared_memory("dummy_data")
    # After the unregister the server should report zero regions.
    status = self.triton_client.get_cuda_shared_memory_status()
    region_count = len(status) if self.protocol == "http" else len(status.regions)
    self.assertEqual(region_count, 0)
    cshm.destroy_shared_memory_region(handle)

def test_reregister_after_register(self):
    """Registering the same region name twice must be rejected by the server."""
    handle = cshm.create_shared_memory_region("dummy_data", 8, 0)
    self.triton_client.register_cuda_shared_memory(
        "dummy_data", cshm.get_raw_handle(handle), 0, 8
    )
    # The duplicate registration must raise; assertRaises (unlike the
    # previous try/except) fails the test when no exception occurs.
    with self.assertRaises(Exception) as ex:
        self.triton_client.register_cuda_shared_memory(
            "dummy_data", cshm.get_raw_handle(handle), 0, 8
        )
    self.assertIn(
        "shared memory region 'dummy_data' already in manager", str(ex.exception)
    )
    # The original registration must still be the only one present.
    shm_status = self.triton_client.get_cuda_shared_memory_status()
    if self.protocol == "http":
        self.assertEqual(len(shm_status), 1)
    else:
        self.assertEqual(len(shm_status.regions), 1)
    cshm.destroy_shared_memory_region(handle)

def _configure_server(
self,
create_byte_size=DEFAULT_SHM_BYTE_SIZE,
Expand Down Expand Up @@ -205,6 +137,78 @@ def _cleanup_server(self, shm_handles):
for shm_handle in shm_handles:
cshm.destroy_shared_memory_region(shm_handle)


class CudaSharedMemoryTest(CudaSharedMemoryTestBase):
def test_invalid_create_shm(self):
    """A negative byte size is invalid and region creation must raise."""
    # Use assertRaises so the test fails when no exception occurs; the
    # bare try/except it replaces passed silently on success.
    with self.assertRaises(Exception) as ex:
        shm_op0_handle = cshm.create_shared_memory_region("dummy_data", -1, 0)
        cshm.destroy_shared_memory_region(shm_op0_handle)
    self.assertEqual(str(ex.exception), "unable to create cuda shared memory handle")

def test_valid_create_set_register(self):
    """Create a CUDA shm region, write data into it and register it."""
    region_handle = cshm.create_shared_memory_region("dummy_data", 8, 0)
    payload = np.array([1, 2], dtype=np.float32)
    cshm.set_shared_memory_region(region_handle, [payload])
    self.triton_client.register_cuda_shared_memory(
        "dummy_data", cshm.get_raw_handle(region_handle), 0, 8
    )
    # The server must now report exactly one registered region.
    shm_status = self.triton_client.get_cuda_shared_memory_status()
    if self.protocol == "http":
        self.assertEqual(len(shm_status), 1)
    else:
        self.assertEqual(len(shm_status.regions), 1)
    cshm.destroy_shared_memory_region(region_handle)

def test_unregister_before_register(self):
    """Unregistering before any registration must leave zero regions."""
    region_handle = cshm.create_shared_memory_region("dummy_data", 8, 0)
    self.triton_client.unregister_cuda_shared_memory("dummy_data")
    shm_status = self.triton_client.get_cuda_shared_memory_status()
    if self.protocol == "http":
        count = len(shm_status)
    else:
        count = len(shm_status.regions)
    self.assertEqual(count, 0)
    cshm.destroy_shared_memory_region(region_handle)

def test_unregister_after_register(self):
    """Register then unregister a region; the server must report none left."""
    region_handle = cshm.create_shared_memory_region("dummy_data", 8, 0)
    self.triton_client.register_cuda_shared_memory(
        "dummy_data", cshm.get_raw_handle(region_handle), 0, 8
    )
    self.triton_client.unregister_cuda_shared_memory("dummy_data")
    shm_status = self.triton_client.get_cuda_shared_memory_status()
    if self.protocol == "http":
        count = len(shm_status)
    else:
        count = len(shm_status.regions)
    self.assertEqual(count, 0)
    cshm.destroy_shared_memory_region(region_handle)

def test_reregister_after_register(self):
    """Re-registering an already-registered region name must be rejected."""
    region_handle = cshm.create_shared_memory_region("dummy_data", 8, 0)
    self.triton_client.register_cuda_shared_memory(
        "dummy_data", cshm.get_raw_handle(region_handle), 0, 8
    )
    # The second registration must raise; assertRaises fails the test if
    # it does not, unlike the try/except pattern it replaces.
    with self.assertRaises(Exception) as ex:
        self.triton_client.register_cuda_shared_memory(
            "dummy_data", cshm.get_raw_handle(region_handle), 0, 8
        )
    self.assertIn(
        "shared memory region 'dummy_data' already in manager", str(ex.exception)
    )
    # The first registration must still be the only region present.
    shm_status = self.triton_client.get_cuda_shared_memory_status()
    if self.protocol == "http":
        self.assertEqual(len(shm_status), 1)
    else:
        self.assertEqual(len(shm_status.regions), 1)
    cshm.destroy_shared_memory_region(region_handle)

def test_unregister_after_inference(self):
# Unregister after inference
error_msg = []
Expand Down Expand Up @@ -396,5 +400,169 @@ def test_infer_byte_size_out_of_bound(self):
self._cleanup_server(shm_handles)


class TestCudaSharedMemoryUnregister(CudaSharedMemoryTestBase):
    """Verify CUDA shm regions held by an in-flight inference cannot be
    unregistered, and can be unregistered once the inference completes."""

    # Region names created by _configure_server(); used by the helper checks.
    _REGION_NAMES = ("input0_data", "input1_data", "output0_data", "output1_data")

    def _test_unregister_shm_fail(self):
        """From a second client, unregistering in-use regions must fail."""
        second_client = httpclient.InferenceServerClient("localhost:8000", verbose=True)

        # A bulk unregister must report every region it failed to release.
        with self.assertRaises(InferenceServerException) as ex:
            second_client.unregister_cuda_shared_memory()
        self.assertIn(
            "Failed to unregister the following cuda shared memory regions: input0_data ,input1_data ,output0_data ,output1_data",
            str(ex.exception),
        )

        # Each region individually is still pinned by the in-flight request.
        for region in self._REGION_NAMES:
            with self.assertRaises(InferenceServerException) as ex:
                second_client.unregister_cuda_shared_memory(region)
            self.assertIn(
                f"Cannot unregister shared memory region '{region}', it is currently in use.",
                str(ex.exception),
            )

    def _test_shm_not_found(self):
        """Status queries for unregistered regions must raise not-found errors."""
        second_client = httpclient.InferenceServerClient("localhost:8000", verbose=True)

        for region in self._REGION_NAMES:
            with self.assertRaises(InferenceServerException) as ex:
                second_client.get_cuda_shared_memory_status(region)
            self.assertIn(
                f"Unable to find cuda shared memory region: '{region}'",
                str(ex.exception),
            )

    def test_unregister_shm_during_inference_http(self):
        """HTTP: unregistering fails during inference, succeeds afterwards."""
        # Pre-bind before the try block so the finally-cleanup cannot hit a
        # NameError (masking the real failure) when unregister/configure
        # raises before shm_handles is assigned.
        shm_handles = []
        try:
            self.triton_client.unregister_cuda_shared_memory()
            shm_handles = self._configure_server()

            inputs = [
                httpclient.InferInput("INPUT0", [1, 16], "INT32"),
                httpclient.InferInput("INPUT1", [1, 16], "INT32"),
            ]
            outputs = [
                httpclient.InferRequestedOutput("OUTPUT0", binary_data=True),
                httpclient.InferRequestedOutput("OUTPUT1", binary_data=False),
            ]

            inputs[0].set_shared_memory("input0_data", self.DEFAULT_SHM_BYTE_SIZE)
            inputs[1].set_shared_memory("input1_data", self.DEFAULT_SHM_BYTE_SIZE)
            outputs[0].set_shared_memory("output0_data", self.DEFAULT_SHM_BYTE_SIZE)
            outputs[1].set_shared_memory("output1_data", self.DEFAULT_SHM_BYTE_SIZE)

            async_request = self.triton_client.async_infer(
                model_name="simple", inputs=inputs, outputs=outputs
            )

            # Give the (deliberately slow) model time to start executing.
            time.sleep(2)

            # While the inference holds the regions, unregistering must fail.
            self._test_unregister_shm_fail()

            # Block until the inference completes.
            async_request.get_result()

            # Now the regions are free and unregistration must succeed.
            self.triton_client.unregister_cuda_shared_memory()
            self._test_shm_not_found()

        finally:
            self._cleanup_server(shm_handles)

    def test_unregister_shm_during_inference_grpc(self):
        """GRPC: unregistering fails during inference, succeeds afterwards."""
        shm_handles = []  # pre-bind so the finally block is always safe
        try:
            self.triton_client.unregister_cuda_shared_memory()
            shm_handles = self._configure_server()

            inputs = [
                grpcclient.InferInput("INPUT0", [1, 16], "INT32"),
                grpcclient.InferInput("INPUT1", [1, 16], "INT32"),
            ]
            outputs = [
                grpcclient.InferRequestedOutput("OUTPUT0"),
                grpcclient.InferRequestedOutput("OUTPUT1"),
            ]

            inputs[0].set_shared_memory("input0_data", self.DEFAULT_SHM_BYTE_SIZE)
            inputs[1].set_shared_memory("input1_data", self.DEFAULT_SHM_BYTE_SIZE)
            outputs[0].set_shared_memory("output0_data", self.DEFAULT_SHM_BYTE_SIZE)
            outputs[1].set_shared_memory("output1_data", self.DEFAULT_SHM_BYTE_SIZE)

            def callback(user_data, result, error):
                # Collect either the result or the error; the wait loop
                # below only checks that *something* arrived.
                if error:
                    user_data.append(error)
                else:
                    user_data.append(result)

            user_data = []

            self.triton_client.async_infer(
                model_name="simple",
                inputs=inputs,
                outputs=outputs,
                callback=partial(callback, user_data),
            )

            # Give the (deliberately slow) model time to start executing.
            time.sleep(2)

            # While the inference holds the regions, unregistering must fail.
            self._test_unregister_shm_fail()

            # Poll until the callback delivers a result (or ~20 s timeout).
            time_out = 20
            while (len(user_data) == 0) and time_out > 0:
                time_out = time_out - 1
                time.sleep(1)
            # Small grace period for server-side release bookkeeping.
            time.sleep(2)

            # Now the regions are free and unregistration must succeed.
            self.triton_client.unregister_cuda_shared_memory()
            self._test_shm_not_found()

        finally:
            self._cleanup_server(shm_handles)


if __name__ == "__main__":
unittest.main()
41 changes: 41 additions & 0 deletions qa/L0_cuda_shared_memory/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,47 @@ for i in \
done
done

# Stage a copy of the delayed-execution Python model under the name "simple"
# so an inference stays in flight long enough to exercise unregistration.
mkdir -p python_models/simple/1/
cp ../python_models/execute_delayed_model/model.py ./python_models/simple/1/
cp ../python_models/execute_delayed_model/config.pbtxt ./python_models/simple/
# CUDA shared memory requires the model to run on a GPU instance.
sed -i 's/KIND_CPU/KIND_GPU/g' ./python_models/simple/config.pbtxt

# Run the unregister-during-inference test once per client protocol.
for client_type in http grpc; do
    SERVER_ARGS="--model-repository=`pwd`/python_models --log-verbose=1 ${SERVER_ARGS_EXTRA}"
    SERVER_LOG="./unregister_shm.$client_type.server.log"
    run_server
    if [ "$SERVER_PID" == "0" ]; then
        echo -e "\n***\n*** Failed to start $SERVER\n***"
        cat $SERVER_LOG
        exit 1
    fi

    # CLIENT_TYPE selects the protocol inside the Python test module.
    export CLIENT_TYPE=$client_type
    CLIENT_LOG="./unregister_shm.$client_type.client.log"
    set +e
    python3 $SHM_TEST TestCudaSharedMemoryUnregister.test_unregister_shm_during_inference_$client_type >>$CLIENT_LOG 2>&1
    if [ $? -ne 0 ]; then
        cat $CLIENT_LOG
        echo -e "\n***\n*** Test Failed\n***"
        RET=1
    else
        # Expect exactly one test case to have run and passed.
        check_test_results $TEST_RESULT_FILE 1
        if [ $? -ne 0 ]; then
            cat $TEST_RESULT_FILE
            echo -e "\n***\n*** Test Result Verification Failed\n***"
            RET=1
        fi
    fi

    kill $SERVER_PID
    wait $SERVER_PID
    if [ $? -ne 0 ]; then
        echo -e "\n***\n*** Test Server shut down non-gracefully\n***"
        RET=1
    fi
    set -e
done

if [ $RET -eq 0 ]; then
echo -e "\n***\n*** Test Passed\n***"
else
Expand Down
Loading
Loading