add test for shape validation #7195

Merged
merged 8 commits into from
May 24, 2024
Changes from 3 commits
158 changes: 158 additions & 0 deletions qa/L0_input_validation/input_shape_validation_test.py
@@ -0,0 +1,158 @@
#!/usr/bin/env python
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import asyncio
from pathlib import Path
from subprocess import Popen
from tempfile import TemporaryDirectory
from typing import Optional

import numpy as np
import pytest
import torch
from tritonclient.grpc.aio import InferenceServerClient, InferInput
from tritonclient.utils import np_to_triton_dtype

GRPC_PORT = 9653
FIXED_LAST_DIM = 8


@pytest.fixture
def repo_dir():
Contributor

Like the idea of moving all logic within Python, can this be common utils?

Contributor

Agreed, but I will push back for now. This function needs to be made generic first; then we can refactor the other tests.

    with TemporaryDirectory() as model_repo:
        (Path(model_repo) / "pt_identity" / "1").mkdir(parents=True, exist_ok=True)

        torch.jit.save(
            torch.jit.script(torch.nn.Identity()),
            model_repo + "/pt_identity/1/model.pt",
        )

        pbtxt = f"""
        name: "pt_identity"
        backend: "pytorch"
        max_batch_size: 8

        input [
          {{
            name: "INPUT0"
            data_type: TYPE_FP32
            dims: [ {FIXED_LAST_DIM} ]
          }}
        ]
        output [
          {{
            name: "OUTPUT0"
            data_type: TYPE_FP32
            dims: [ {FIXED_LAST_DIM} ]
          }}
        ]
        # ensure we batch requests together
        dynamic_batching {{
            max_queue_delay_microseconds: {int(5e6)}
        }}
        """
        with open(model_repo + "/pt_identity/config.pbtxt", "w") as f:
            f.write(pbtxt)

        yield model_repo
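
A reviewer asks whether this fixture could live in common utilities. A minimal sketch of what a more generic helper might look like follows, using only the standard library. The name `write_identity_config` and its parameters are hypothetical; they are not part of this PR or of Triton's test utilities. The config fields mirror the ones written by the fixture above.

```python
from pathlib import Path
from tempfile import TemporaryDirectory


def write_identity_config(
    model_repo: str,
    model_name: str,
    last_dim: int,
    backend: str = "pytorch",
    max_batch_size: int = 8,
    max_queue_delay_us: int = 5_000_000,
) -> Path:
    """Hypothetical helper: create <repo>/<model>/1/ and write config.pbtxt."""
    model_dir = Path(model_repo) / model_name
    (model_dir / "1").mkdir(parents=True, exist_ok=True)

    config = f"""
name: "{model_name}"
backend: "{backend}"
max_batch_size: {max_batch_size}

input [
  {{
    name: "INPUT0"
    data_type: TYPE_FP32
    dims: [ {last_dim} ]
  }}
]
output [
  {{
    name: "OUTPUT0"
    data_type: TYPE_FP32
    dims: [ {last_dim} ]
  }}
]
dynamic_batching {{
  max_queue_delay_microseconds: {max_queue_delay_us}
}}
"""
    config_path = model_dir / "config.pbtxt"
    config_path.write_text(config)
    return config_path


# Usage: build a throwaway repository and read back the generated config.
with TemporaryDirectory() as repo:
    path = write_identity_config(repo, "pt_identity", last_dim=8)
    version_dir_exists = (Path(repo) / "pt_identity" / "1").is_dir()
    config_text = path.read_text()
```

The model file itself (for example the TorchScript `model.pt` saved above) would still be written by the caller, which keeps the helper independent of any one backend.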


async def poll_readiness(client: InferenceServerClient, server_proc):
    while True:
        if server_proc is not None and (ret_code := server_proc.poll()) is not None:
            _, stderr = server_proc.communicate()
            print(stderr)
            raise Exception(f"Tritonserver died with return code {ret_code}")
        try:
            if await client.is_server_ready():
                break
        except Exception:
            # Ignore errors (for example, the server not accepting
            # connections yet) and retry after a short pause.
            pass
        await asyncio.sleep(0.5)
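
The loop above polls until the server is ready or the process exits, with no upper bound on the wait. A minimal sketch of the same polling pattern with a deadline is shown below, assuming a bounded wait is wanted. `wait_until_ready` and `fake_probe` are hypothetical names used for illustration and are not part of this PR.

```python
import asyncio
import time
from typing import Awaitable, Callable


async def wait_until_ready(
    is_ready: Callable[[], Awaitable[bool]],
    timeout_s: float = 60.0,
    interval_s: float = 0.5,
) -> float:
    """Poll `is_ready` until it returns True; return the elapsed seconds.

    Raises TimeoutError once `timeout_s` seconds have passed without success.
    """
    start = time.monotonic()
    while True:
        try:
            if await is_ready():
                return time.monotonic() - start
        except Exception:
            # Treat probe errors as "not ready yet" and keep polling.
            pass
        if time.monotonic() - start >= timeout_s:
            raise TimeoutError(f"not ready after {timeout_s} seconds")
        await asyncio.sleep(interval_s)


# Usage with a stand-in probe that fails twice and then reports ready.
calls = {"n": 0}


async def fake_probe() -> bool:
    calls["n"] += 1
    if calls["n"] < 3:
        raise ConnectionError("server not up yet")
    return True


elapsed = asyncio.run(wait_until_ready(fake_probe, timeout_s=5.0, interval_s=0.01))
```

In the test above, `client.is_server_ready` could be passed as the probe, since it is already awaited in `poll_readiness`.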


async def server_terminated(client: InferenceServerClient, server_proc):
    if server_proc is not None and (ret_code := server_proc.poll()) is not None:
        _, stderr = server_proc.communicate()
        print(stderr)
        raise Exception(f"Tritonserver died with return code {ret_code}")


@pytest.mark.asyncio
async def test_shape_overlapped(repo_dir: str):
Contributor

Can you add a summary of how the test operates?
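
As a rough summary of what the test body exercises: Alice serializes `FIXED_LAST_DIM + 2` FP32 values but then declares the shape `(1, FIXED_LAST_DIM)`, so her payload is larger than her declared shape implies, while Bob sends a well-formed `(1, FIXED_LAST_DIM)` request to the same model. A minimal sketch of the kind of consistency check this targets, comparing the declared shape against the payload size, might look like the following. The function names are hypothetical, and this is not Triton's implementation.

```python
import math
import struct
from typing import Sequence

FP32_SIZE = struct.calcsize("<f")  # 4 bytes per FP32 element
FIXED_LAST_DIM = 8


def expected_byte_size(shape: Sequence[int], item_size: int = FP32_SIZE) -> int:
    """Number of bytes a dense tensor of `shape` should occupy."""
    return math.prod(shape) * item_size


def check_payload(shape: Sequence[int], payload: bytes) -> None:
    """Raise ValueError if the payload size disagrees with the declared shape."""
    expected = expected_byte_size(shape)
    if len(payload) != expected:
        raise ValueError(
            f"shape {tuple(shape)} expects {expected} bytes, got {len(payload)}"
        )


# Alice: 10 values in the payload, but the declared shape is (1, 8).
alice_values = [float(i) for i in range(FIXED_LAST_DIM + 2)]
alice_payload = struct.pack(f"<{len(alice_values)}f", *alice_values)

# Bob: 8 values in the payload and a matching declared shape of (1, 8).
bob_values = [100.0 + i for i in range(FIXED_LAST_DIM)]
bob_payload = struct.pack(f"<{len(bob_values)}f", *bob_values)

check_payload((1, FIXED_LAST_DIM), bob_payload)  # consistent, no error

try:
    check_payload((1, FIXED_LAST_DIM), alice_payload)
    alice_rejected = False
except ValueError as err:
    alice_rejected = True
    alice_error = str(err)
```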

    with Popen(
        [
            "/opt/tritonserver/bin/tritonserver",
            "--model-repository",
            repo_dir,
            "--grpc-port",
            str(GRPC_PORT),
        ]
    ) as server:
        await poll_readiness(
            InferenceServerClient("localhost:" + str(GRPC_PORT)), server
        )

        alice = InferenceServerClient("localhost:" + str(GRPC_PORT))
        bob = InferenceServerClient("localhost:" + str(GRPC_PORT))

        input_data_1 = np.arange(FIXED_LAST_DIM + 2)[None].astype(np.float32)
        print(f"{input_data_1=}")
        inputs_1 = [
            InferInput(
                "INPUT0", input_data_1.shape, np_to_triton_dtype(input_data_1.dtype)
            ),
        ]
        inputs_1[0].set_data_from_numpy(input_data_1)
        # Compromised input shape
        inputs_1[0].set_shape((1, FIXED_LAST_DIM))

        input_data_2 = 100 + np.arange(FIXED_LAST_DIM)[None].astype(np.float32)
        print(f"{input_data_2=}")
        inputs_2 = [
            InferInput(
                "INPUT0",
                shape=input_data_2.shape,
                datatype=np_to_triton_dtype(input_data_2.dtype),
            )
        ]
        inputs_2[0].set_data_from_numpy(input_data_2)
        with pytest.raises(Exception) as e_info:
            # server_terminated() is a coroutine, so it must be awaited to run.
            await server_terminated(
                InferenceServerClient("localhost:" + str(GRPC_PORT)), server
            )
        t1 = asyncio.create_task(
            alice.infer("pt_identity", inputs_1)
        )  # should fail here
        t2 = asyncio.create_task(bob.infer("pt_identity", inputs_2))

# alice_result, bob_result = await asyncio.gather(t1, t2)
# print(f"{alice_result.as_numpy('OUTPUT0')=}")
# print(f"{bob_result.as_numpy('OUTPUT0')=}")
# server.terminate()
# assert np.allclose(
# bob_result.as_numpy("OUTPUT0"), input_data_2
# ), "Bob's result should be the same as input"
Contributor

Uncomment?

Contributor Author

@jbkyang-nvi May 23, 2024

remove? I think anything past the assert is not needed
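
The commented-out assertions above await both requests together. One way to check that Alice's request fails while Bob's still returns his own data is `asyncio.gather(..., return_exceptions=True)`, which hands back exceptions as results instead of raising the first one. The sketch below uses a stand-in coroutine rather than a real Triton client; `fake_infer` and the error it raises are hypothetical.

```python
import asyncio
import math
from typing import List, Tuple


async def fake_infer(declared_shape: Tuple[int, ...], data: List[float]) -> List[float]:
    """Stand-in for an identity-model call that rejects mismatched shapes."""
    await asyncio.sleep(0)  # yield control so both tasks interleave
    if math.prod(declared_shape) != len(data):
        raise ValueError("declared shape does not match payload size")
    return list(data)


async def run_both():
    alice_data = [float(i) for i in range(10)]  # 10 values, declared as (1, 8)
    bob_data = [100.0 + i for i in range(8)]  # 8 values, declared as (1, 8)

    t1 = asyncio.create_task(fake_infer((1, 8), alice_data))
    t2 = asyncio.create_task(fake_infer((1, 8), bob_data))

    # With return_exceptions=True, a failed task yields its exception object
    # in the result list and the other task's result is still returned.
    alice_result, bob_result = await asyncio.gather(t1, t2, return_exceptions=True)
    return alice_result, bob_result, bob_data


alice_result, bob_result, bob_data = asyncio.run(run_both())
```

With a real client, the same pattern would let the test assert on Bob's output even when Alice's request is rejected.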

11 changes: 10 additions & 1 deletion qa/L0_input_validation/test.sh
@@ -44,6 +44,7 @@ RET=0

CLIENT_LOG="./input_validation_client.log"
TEST_PY=./input_validation_test.py
SHAPE_TEST_PY=./input_shape_validation_test.py
TEST_RESULT_FILE='./test_results.txt'

export CUDA_VISIBLE_DEVICES=0
@@ -64,14 +65,22 @@ set +e
python3 -m pytest --junitxml="input_validation.report.xml" $TEST_PY >> $CLIENT_LOG 2>&1

if [ $? -ne 0 ]; then
-    echo -e "\n***\n*** python_unittest.py FAILED. \n***"
+    echo -e "\n***\n*** input_validation_test.py FAILED. \n***"
    RET=1
fi
set -e

kill $SERVER_PID
wait $SERVER_PID

pip install torch
Contributor

Why PyTorch? Using the Python backend would be more flexible.

Contributor Author

This is taken from this reproducer

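# Disable exit-on-error so a failing test is reported instead of aborting the script.
set +e
python3 -m pytest $SHAPE_TEST_PY >> $CLIENT_LOG 2>&1
if [ $? -ne 0 ]; then
    echo -e "\n***\n*** input_shape_validation_test.py FAILED. \n***"
    RET=1
fi
set -e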

if [ $RET -eq 0 ]; then
    echo -e "\n***\n*** Input Validation Test Passed\n***"
else