Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate system shared memory region size when registering a region #7093

Merged
merged 6 commits into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
230 changes: 142 additions & 88 deletions qa/L0_shared_memory/shared_memory_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,22 @@


class SharedMemoryTest(tu.TestResultCollector):
def setUp(self):
    """unittest fixture hook: build the protocol-specific client for this test.

    Delegates to ``_setup_client``, which sets ``self.protocol``,
    ``self.url`` and ``self.triton_client``.
    """
    self._setup_client()

def _setup_client(self):
    """Select the client protocol from the environment and connect.

    Reads the ``CLIENT_TYPE`` environment variable (defaulting to
    ``"http"``) and stores three attributes on the test instance:

    * ``self.protocol`` -- the raw ``CLIENT_TYPE`` value.
    * ``self.url`` -- ``localhost:8000`` for HTTP, ``localhost:8001``
      for anything else (handled as gRPC).
    * ``self.triton_client`` -- a verbose ``InferenceServerClient`` from
      the matching client module.
    """
    self.protocol = os.environ.get("CLIENT_TYPE", "http")
    use_http = self.protocol == "http"

    # Any value other than "http" falls through to the gRPC endpoint.
    self.url = "localhost:8000" if use_http else "localhost:8001"
    client_module = httpclient if use_http else grpcclient
    self.triton_client = client_module.InferenceServerClient(
        self.url, verbose=True
    )

def test_invalid_create_shm(self):
# Raises an error because the test tries to create an invalid system shared memory region
try:
Expand All @@ -54,99 +70,111 @@ def test_invalid_create_shm(self):

def test_valid_create_set_register(self):
# Create a valid system shared memory region, fill data in it and register
if _protocol == "http":
triton_client = httpclient.InferenceServerClient(_url, verbose=True)
else:
triton_client = grpcclient.InferenceServerClient(_url, verbose=True)
shm_op0_handle = shm.create_shared_memory_region("dummy_data", "/dummy_data", 8)
shm.set_shared_memory_region(
shm_op0_handle, [np.array([1, 2], dtype=np.float32)]
)
triton_client.register_system_shared_memory("dummy_data", "/dummy_data", 8)
shm_status = triton_client.get_system_shared_memory_status()
if _protocol == "http":
self.triton_client.register_system_shared_memory("dummy_data", "/dummy_data", 8)
shm_status = self.triton_client.get_system_shared_memory_status()
if self.protocol == "http":
self.assertTrue(len(shm_status) == 1)
else:
self.assertTrue(len(shm_status.regions) == 1)
shm.destroy_shared_memory_region(shm_op0_handle)

def test_unregister_before_register(self):
# Create a valid system shared memory region and unregister before register
if _protocol == "http":
triton_client = httpclient.InferenceServerClient(_url, verbose=True)
else:
triton_client = grpcclient.InferenceServerClient(_url, verbose=True)
shm_op0_handle = shm.create_shared_memory_region("dummy_data", "/dummy_data", 8)
triton_client.unregister_system_shared_memory("dummy_data")
shm_status = triton_client.get_system_shared_memory_status()
if _protocol == "http":
self.triton_client.unregister_system_shared_memory("dummy_data")
shm_status = self.triton_client.get_system_shared_memory_status()
if self.protocol == "http":
self.assertTrue(len(shm_status) == 0)
else:
self.assertTrue(len(shm_status.regions) == 0)
shm.destroy_shared_memory_region(shm_op0_handle)

def test_unregister_after_register(self):
# Create a valid system shared memory region and unregister after register
if _protocol == "http":
triton_client = httpclient.InferenceServerClient(_url, verbose=True)
else:
triton_client = grpcclient.InferenceServerClient(_url, verbose=True)
shm_op0_handle = shm.create_shared_memory_region("dummy_data", "/dummy_data", 8)
triton_client.register_system_shared_memory("dummy_data", "/dummy_data", 8)
triton_client.unregister_system_shared_memory("dummy_data")
shm_status = triton_client.get_system_shared_memory_status()
if _protocol == "http":
self.triton_client.register_system_shared_memory("dummy_data", "/dummy_data", 8)
self.triton_client.unregister_system_shared_memory("dummy_data")
shm_status = self.triton_client.get_system_shared_memory_status()
if self.protocol == "http":
self.assertTrue(len(shm_status) == 0)
else:
self.assertTrue(len(shm_status.regions) == 0)
shm.destroy_shared_memory_region(shm_op0_handle)

def test_reregister_after_register(self):
# Create a valid system shared memory region and try to re-register it after it is already registered
if _protocol == "http":
triton_client = httpclient.InferenceServerClient(_url, verbose=True)
else:
triton_client = grpcclient.InferenceServerClient(_url, verbose=True)
shm_op0_handle = shm.create_shared_memory_region("dummy_data", "/dummy_data", 8)
triton_client.register_system_shared_memory("dummy_data", "/dummy_data", 8)
self.triton_client.register_system_shared_memory("dummy_data", "/dummy_data", 8)
try:
triton_client.register_system_shared_memory("dummy_data", "/dummy_data", 8)
self.triton_client.register_system_shared_memory(
"dummy_data", "/dummy_data", 8
)
except Exception as ex:
self.assertTrue(
"shared memory region 'dummy_data' already in manager" in str(ex)
)
shm_status = triton_client.get_system_shared_memory_status()
if _protocol == "http":
shm_status = self.triton_client.get_system_shared_memory_status()
if self.protocol == "http":
self.assertTrue(len(shm_status) == 1)
else:
self.assertTrue(len(shm_status.regions) == 1)
shm.destroy_shared_memory_region(shm_op0_handle)

def _configure_sever(self):
def _configure_server(
self, create_byte_size=64, register_byte_size=64, register_offset=0
):
"""Creates and registers shared memory regions for testing.

Parameters
----------
create_byte_size: int
Size of each system shared memory region to create.
NOTE: This should be sufficiently large to hold the inputs/outputs
stored in shared memory.

register_byte_size: int
Size of each system shared memory region to register with server.
NOTE: The (offset + register_byte_size) should be less than or equal
to the create_byte_size. Otherwise an exception will be raised for
an invalid set of registration args.

register_offset: int
Offset into the shared memory object to start the registered region.

"""
shm_ip0_handle = shm.create_shared_memory_region(
"input0_data", "/input0_data", 64
"input0_data", "/input0_data", create_byte_size
)
shm_ip1_handle = shm.create_shared_memory_region(
"input1_data", "/input1_data", 64
"input1_data", "/input1_data", create_byte_size
)
shm_op0_handle = shm.create_shared_memory_region(
"output0_data", "/output0_data", 64
"output0_data", "/output0_data", create_byte_size
)
shm_op1_handle = shm.create_shared_memory_region(
"output1_data", "/output1_data", 64
"output1_data", "/output1_data", create_byte_size
)
# Implicit assumption that input and output byte_sizes are 64 bytes for now
input0_data = np.arange(start=0, stop=16, dtype=np.int32)
input1_data = np.ones(shape=16, dtype=np.int32)
shm.set_shared_memory_region(shm_ip0_handle, [input0_data])
shm.set_shared_memory_region(shm_ip1_handle, [input1_data])
if _protocol == "http":
triton_client = httpclient.InferenceServerClient(_url, verbose=True)
else:
triton_client = grpcclient.InferenceServerClient(_url, verbose=True)
triton_client.register_system_shared_memory("input0_data", "/input0_data", 64)
triton_client.register_system_shared_memory("input1_data", "/input1_data", 64)
triton_client.register_system_shared_memory("output0_data", "/output0_data", 64)
triton_client.register_system_shared_memory("output1_data", "/output1_data", 64)
self.triton_client.register_system_shared_memory(
"input0_data", "/input0_data", register_byte_size, offset=register_offset
)
self.triton_client.register_system_shared_memory(
"input1_data", "/input1_data", register_byte_size, offset=register_offset
)
self.triton_client.register_system_shared_memory(
"output0_data", "/output0_data", register_byte_size, offset=register_offset
)
self.triton_client.register_system_shared_memory(
"output1_data", "/output1_data", register_byte_size, offset=register_offset
)
return [shm_ip0_handle, shm_ip1_handle, shm_op0_handle, shm_op1_handle]

def _cleanup_server(self, shm_handles):
Expand All @@ -168,16 +196,14 @@ def _basic_inference(
input1_data = np.ones(shape=16, dtype=np.int32)
inputs = []
outputs = []
if _protocol == "http":
triton_client = httpclient.InferenceServerClient(_url, verbose=True)
if self.protocol == "http":
inputs.append(httpclient.InferInput("INPUT0", [1, 16], "INT32"))
inputs.append(httpclient.InferInput("INPUT1", [1, 16], "INT32"))
outputs.append(httpclient.InferRequestedOutput("OUTPUT0", binary_data=True))
outputs.append(
httpclient.InferRequestedOutput("OUTPUT1", binary_data=False)
)
else:
triton_client = grpcclient.InferenceServerClient(_url, verbose=True)
inputs.append(grpcclient.InferInput("INPUT0", [1, 16], "INT32"))
inputs.append(grpcclient.InferInput("INPUT1", [1, 16], "INT32"))
outputs.append(grpcclient.InferRequestedOutput("OUTPUT0"))
Expand All @@ -196,11 +222,11 @@ def _basic_inference(
outputs[1].set_shared_memory("output1_data", 64, offset=shm_output_offset)

try:
results = triton_client.infer(
results = self.triton_client.infer(
"simple", inputs, model_version="", outputs=outputs
)
output = results.get_output("OUTPUT0")
if _protocol == "http":
if self.protocol == "http":
output_datatype = output["datatype"]
output_shape = output["shape"]
else:
Expand All @@ -220,19 +246,15 @@ def _basic_inference(
def test_unregister_after_inference(self):
# Unregister after inference
error_msg = []
shm_handles = self._configure_sever()
shm_handles = self._configure_server()
self._basic_inference(
shm_handles[0], shm_handles[1], shm_handles[2], shm_handles[3], error_msg
)
if len(error_msg) > 0:
raise Exception(str(error_msg))
if _protocol == "http":
triton_client = httpclient.InferenceServerClient(_url, verbose=True)
else:
triton_client = grpcclient.InferenceServerClient(_url, verbose=True)
triton_client.unregister_system_shared_memory("output0_data")
shm_status = triton_client.get_system_shared_memory_status()
if _protocol == "http":
self.triton_client.unregister_system_shared_memory("output0_data")
shm_status = self.triton_client.get_system_shared_memory_status()
if self.protocol == "http":
self.assertTrue(len(shm_status) == 3)
else:
self.assertTrue(len(shm_status.regions) == 3)
Expand All @@ -241,11 +263,7 @@ def test_unregister_after_inference(self):
def test_register_after_inference(self):
# Register after inference
error_msg = []
shm_handles = self._configure_sever()
if _protocol == "http":
triton_client = httpclient.InferenceServerClient(_url, verbose=True)
else:
triton_client = grpcclient.InferenceServerClient(_url, verbose=True)
shm_handles = self._configure_server()
self._basic_inference(
shm_handles[0], shm_handles[1], shm_handles[2], shm_handles[3], error_msg
)
Expand All @@ -254,9 +272,11 @@ def test_register_after_inference(self):
shm_ip2_handle = shm.create_shared_memory_region(
"input2_data", "/input2_data", 64
)
triton_client.register_system_shared_memory("input2_data", "/input2_data", 64)
shm_status = triton_client.get_system_shared_memory_status()
if _protocol == "http":
self.triton_client.register_system_shared_memory(
"input2_data", "/input2_data", 64
)
shm_status = self.triton_client.get_system_shared_memory_status()
if self.protocol == "http":
self.assertTrue(len(shm_status) == 5)
else:
self.assertTrue(len(shm_status.regions) == 5)
Expand All @@ -266,15 +286,13 @@ def test_register_after_inference(self):
def test_too_big_shm(self):
# Shared memory input region larger than needed - Throws error
error_msg = []
shm_handles = self._configure_sever()
shm_handles = self._configure_server()
shm_ip2_handle = shm.create_shared_memory_region(
"input2_data", "/input2_data", 128
)
if _protocol == "http":
triton_client = httpclient.InferenceServerClient(_url, verbose=True)
else:
triton_client = grpcclient.InferenceServerClient(_url, verbose=True)
triton_client.register_system_shared_memory("input2_data", "/input2_data", 128)
self.triton_client.register_system_shared_memory(
"input2_data", "/input2_data", 128
)
self._basic_inference(
shm_handles[0],
shm_ip2_handle,
Expand All @@ -295,7 +313,7 @@ def test_too_big_shm(self):
def test_mixed_raw_shm(self):
# Mix of shared memory and RAW inputs
error_msg = []
shm_handles = self._configure_sever()
shm_handles = self._configure_server()
input1_data = np.ones(shape=16, dtype=np.int32)
self._basic_inference(
shm_handles[0], [input1_data], shm_handles[2], shm_handles[3], error_msg
Expand All @@ -306,19 +324,15 @@ def test_mixed_raw_shm(self):

def test_unregisterall(self):
# Unregister all shared memory blocks
shm_handles = self._configure_sever()
if _protocol == "http":
triton_client = httpclient.InferenceServerClient(_url, verbose=True)
else:
triton_client = grpcclient.InferenceServerClient(_url, verbose=True)
status_before = triton_client.get_system_shared_memory_status()
if _protocol == "http":
shm_handles = self._configure_server()
status_before = self.triton_client.get_system_shared_memory_status()
if self.protocol == "http":
self.assertTrue(len(status_before) == 4)
else:
self.assertTrue(len(status_before.regions) == 4)
triton_client.unregister_system_shared_memory()
status_after = triton_client.get_system_shared_memory_status()
if _protocol == "http":
self.triton_client.unregister_system_shared_memory()
status_after = self.triton_client.get_system_shared_memory_status()
if self.protocol == "http":
self.assertTrue(len(status_after) == 0)
else:
self.assertTrue(len(status_after.regions) == 0)
Expand All @@ -327,8 +341,8 @@ def test_unregisterall(self):
def test_infer_offset_out_of_bound(self):
# Shared memory offset outside output region - Throws error
error_msg = []
shm_handles = self._configure_sever()
if _protocol == "http":
shm_handles = self._configure_server()
if self.protocol == "http":
# -32 when placed in an int64 signed type, to get a negative offset
# by overflowing
offset = 2**64 - 32
Expand All @@ -348,11 +362,51 @@ def test_infer_offset_out_of_bound(self):
self.assertIn("Invalid offset for shared memory region", error_msg[0])
self._cleanup_server(shm_handles)

def test_register_out_of_bound(self):
    """Check that out-of-bounds region registrations raise an error.

    Each case asks ``_configure_server`` to create 64-byte shared memory
    objects and then register a window (offset + byte_size) that does not
    fit inside them. Every case is expected to raise
    ``utils.InferenceServerException`` with a message matching
    "failed to register shared memory region.*invalid args".
    """
    create_byte_size = 64
    expected_error = "failed to register shared memory region.*invalid args"

    def expect_rejected(register_byte_size, register_offset):
        # NOTE(review): when registration raises, _configure_server never
        # reaches its return statement, so the handles it created are not
        # available here to destroy -- looks like they are left behind;
        # confirm whether cleanup is handled elsewhere.
        with self.assertRaisesRegex(
            utils.InferenceServerException, expected_error
        ):
            self._configure_server(
                create_byte_size=create_byte_size,
                register_byte_size=register_byte_size,
                register_offset=register_offset,
            )

    # Window starts at 0 but is one byte longer than the created object.
    expect_rejected(create_byte_size + 1, 0)

    # Full-size window shifted by one byte, so its end overruns the object.
    expect_rejected(create_byte_size, 1)

    # One-byte window that begins exactly at the end of the object.
    expect_rejected(1, create_byte_size)

    # Zero-length window whose offset is already past the end of the object.
    expect_rejected(0, create_byte_size + 1)


if __name__ == "__main__":
_protocol = os.environ.get("CLIENT_TYPE", "http")
if _protocol == "http":
_url = "localhost:8000"
else:
_url = "localhost:8001"
unittest.main()
Loading
Loading