Enhance OTEL testing to capture and verify Cancellation Requests and Non-Decoupled model inference. #7132

Merged 18 commits on May 24, 2024
Changes from 6 commits
128 changes: 128 additions & 0 deletions qa/L0_trace/opentelemetry_unittest.py
@@ -82,6 +82,11 @@
client.infer("bls_simple", inputs, headers=headers)


class UserData:
def __init__(self):
self._completed_requests = queue.Queue()


class OpenTelemetryTest(tu.TestResultCollector):
def setUp(self):
self.collector_subprocess = subprocess.Popen(
@@ -106,12 +111,22 @@
self.ensemble_model_name = "ensemble_add_sub_int32_int32_int32"
self.bls_model_name = "bls_simple"
self.trace_context_model = "trace_context"
self.non_decoupled_model_name_ = "repeat_int32"
self.test_models = [
self.simple_model_name,
self.ensemble_model_name,
self.bls_model_name,
self.non_decoupled_model_name_,
]
self.root_span = "InferRequest"
self._user_data = UserData()
self._callback = partial(callback, self._user_data)
self._outputs = []
self.input_data = {
"IN": np.array([1], dtype=np.int32),
"DELAY": np.array([0], dtype=np.uint32),
"WAIT": np.array([0], dtype=np.uint32),
}

def tearDown(self):
self.collector_subprocess.kill()
@@ -199,6 +214,38 @@
)
self.assertFalse(all(entry in events for entry in compute_events))

def _check_events_cancel(self, events):
"""
Helper function that verifies that events, collected for the root span
of a cancelled request, contain the expected entries for the client
protocol in use (HTTP or GRPC).

Args:
events (List[str]): list of event names collected for the root span.
"""
root_events_grpc = [
"GRPC_WAITREAD_START",
"GRPC_WAITREAD_END",
"INFER_RESPONSE_COMPLETE",
"GRPC_SEND_START",
"GRPC_SEND_END",
]
root_events_http = [
"HTTP_RECV_START",
"HTTP_RECV_END",
"INFER_RESPONSE_COMPLETE",
"HTTP_SEND_START",
"HTTP_SEND_END",
]

if "HTTP" in events:
self.assertTrue(all(entry in events for entry in root_events_http))
self.assertFalse(all(entry in events for entry in root_events_grpc))
elif "GRPC" in events:
self.assertTrue(all(entry in events for entry in root_events_grpc))
self.assertFalse(all(entry in events for entry in root_events_http))
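
For illustration only (not part of the diff): a minimal sketch of how the helper above is expected to classify an event list, using the GRPC event names listed in root_events_grpc.

# Hypothetical usage sketch of _check_events_cancel.
grpc_cancel_events = [
    "GRPC_WAITREAD_START",
    "GRPC_WAITREAD_END",
    "INFER_RESPONSE_COMPLETE",
    "GRPC_SEND_START",
    "GRPC_SEND_END",
]

# At least one name contains "GRPC", so the helper takes the GRPC branch:
# every entry of root_events_grpc must be present, while the HTTP-only
# entries (e.g. "HTTP_RECV_START") must not all be present.
assert any("GRPC" in event for event in grpc_cancel_events)
assert "HTTP_RECV_START" not in grpc_cancel_events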

def _test_resource_attributes(self, attributes):
"""
Helper function that verifies passed span attributes.
@@ -270,6 +317,11 @@
"Unexpeced number of " + name + " spans collected",
)

def _verify_contents_cancel(self, span):
"""
Helper function that extracts event names from the provided root span
and verifies them via `_check_events_cancel`.
"""
span_events = span["events"]
event_names_only = [event["name"] for event in span_events]
self._check_events_cancel(event_names_only)

def _verify_nesting(self, spans, expected_parent_span_dict):
"""
Helper function that checks parent-child relationships between
@@ -339,6 +391,18 @@
),
)

def _test_trace_cancel(self):
"""
Helper function that collects the trace for a cancelled request and
verifies the events recorded on the root `InferRequest` span.
"""
time.sleep(COLLECTOR_TIMEOUT)
traces = self._parse_trace_log(self.filename)

parsed_spans = traces[0]["resourceSpans"][0]["scopeSpans"][0]["spans"]
root_span = [
entry for entry in parsed_spans if entry["name"] == "InferRequest"
][0]
self._verify_contents_cancel(root_span)
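
For reference (an assumption about the collector output, not shown in this diff): _test_trace_cancel indexes into OTLP-style JSON parsed from self.filename. A minimal sketch of the layout it relies on, showing only the fields the test touches:

# Assumed OTLP-JSON layout of one collected trace entry.
example_trace = {
    "resourceSpans": [
        {
            "scopeSpans": [
                {
                    "spans": [
                        {
                            "name": "InferRequest",
                            "events": [
                                {"name": "GRPC_WAITREAD_START"},
                                {"name": "GRPC_WAITREAD_END"},
                                {"name": "INFER_RESPONSE_COMPLETE"},
                                {"name": "GRPC_SEND_START"},
                                {"name": "GRPC_SEND_END"},
                            ],
                        }
                    ]
                }
            ]
        }
    ]
}

# Mirrors the lookup performed in _test_trace_cancel above.
parsed_spans = example_trace["resourceSpans"][0]["scopeSpans"][0]["spans"]
root_span = [s for s in parsed_spans if s["name"] == "InferRequest"][0]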

def _test_trace(
self,
headers,
@@ -420,6 +484,24 @@
expected_parent_span_dict=expected_parent_span_dict,
)

def _test_non_decoupled_trace(self, headers=None):
"""
Helper function that collects and verifies the trace for the non-decoupled `repeat_int32` model.
"""
expected_number_of_spans = 3
expected_counts = dict(
{"compute": 1, self.non_decoupled_model_name_: 1, self.root_span: 1}
)
expected_parent_span_dict = dict(
{"InferRequest": ["repeat_int32"], "repeat_int32": ["compute"]}
)
self._test_trace(
headers=headers,
expected_number_of_spans=expected_number_of_spans,
expected_counts=expected_counts,
expected_parent_span_dict=expected_parent_span_dict,
)

def _test_bls_trace(self, headers=None):
"""
Helper function, that specifies expected parameters to evaluate trace,
@@ -527,6 +609,52 @@

self._test_simple_trace()

def test_grpc_trace_simple_model_cancel(self):
"""
Tests the trace collected from executing one inference request and then
cancelling it, for the `simple` model and GRPC client.
"""
triton_client_grpc = grpcclient.InferenceServerClient(
"localhost:8001", verbose=True
)
inputs = prepare_data(grpcclient)
future = triton_client_grpc.async_infer(
model_name=self.simple_model_name,
inputs=inputs,
callback=self._callback,
outputs=self._outputs,
)
time.sleep(2) # ensure the inference has started
future.cancel()
time.sleep(0.1) # context switch
self._test_trace_cancel()
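
The cancellation test above relies on the prepare_data and callback helpers defined earlier in opentelemetry_unittest.py, outside this diff. A minimal sketch of what they are assumed to look like, following the usual Triton async-client pattern:

import numpy as np


def prepare_data(client_module):
    # Assumed helper: builds the two INT32 input tensors expected by the
    # "simple" add/sub model.
    input_data = np.arange(16, dtype=np.int32)
    inputs = [
        client_module.InferInput("INPUT0", [1, 16], "INT32"),
        client_module.InferInput("INPUT1", [1, 16], "INT32"),
    ]
    inputs[0].set_data_from_numpy(np.expand_dims(input_data, axis=0))
    inputs[1].set_data_from_numpy(np.expand_dims(input_data, axis=0))
    return inputs


def callback(user_data, result, error):
    # Assumed helper: async_infer invokes it with either a result or an error;
    # both are queued on UserData so the test can inspect the outcome
    # (e.g. a cancellation error after future.cancel()).
    if error:
        user_data._completed_requests.put(error)
    else:
        user_data._completed_requests.put(result)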

def test_non_decoupled(self):
"""
Tests the trace collected from executing one inference request
for the non-decoupled `repeat_int32` model over the GRPC client.
"""
inputs = [
grpcclient.InferInput("IN", [1], "INT32").set_data_from_numpy(
self.input_data["IN"]
),
grpcclient.InferInput("DELAY", [1], "UINT32").set_data_from_numpy(
self.input_data["DELAY"]
),
grpcclient.InferInput("WAIT", [1], "UINT32").set_data_from_numpy(
self.input_data["WAIT"]
),
]

triton_client = grpcclient.InferenceServerClient(
url="localhost:8001", verbose=True
)
# Expect the inference to succeed; the response itself is not inspected.
triton_client.infer(
model_name=self.non_decoupled_model_name_, inputs=inputs
)
self._test_non_decoupled_trace()

def test_grpc_trace_simple_model_context_propagation(self):
"""
Tests trace, collected from executing one inference request
8 changes: 6 additions & 2 deletions qa/L0_trace/test.sh
@@ -58,7 +58,6 @@ MODELSDIR=`pwd`/trace_models

SERVER=/opt/tritonserver/bin/tritonserver
source ../common/util.sh

rm -f *.log
rm -f *.log.*
rm -fr $MODELSDIR && mkdir -p $MODELSDIR
@@ -79,6 +78,10 @@ cp -r $DATADIR/$MODELBASE $MODELSDIR/simple && \
sed -i "s/model_name:.*/model_name: \"simple\"/" config.pbtxt) && \
mkdir -p $MODELSDIR/bls_simple/1 && cp $BLSDIR/bls_simple.py $MODELSDIR/bls_simple/1/model.py

# set up repeat_int32 model
cp -r ../L0_decoupled/models/repeat_int32 $MODELSDIR
sed -i "s/decoupled: True/decoupled: False/" $MODELSDIR/repeat_int32/config.pbtxt

RET=0

# Helpers =======================================
@@ -740,7 +743,7 @@ rm collected_traces.json*
# Unittests then check that produced spans have expected format and events
OPENTELEMETRY_TEST=opentelemetry_unittest.py
OPENTELEMETRY_LOG="opentelemetry_unittest.log"
EXPECTED_NUM_TESTS="14"
EXPECTED_NUM_TESTS="16"

# Set up repo and args for SageMaker
export SAGEMAKER_TRITON_DEFAULT_MODEL_NAME="simple"
@@ -757,6 +760,7 @@ mkdir -p $MODELSDIR/trace_context/1 && cp ./trace_context.py $MODELSDIR/trace_co

SERVER_ARGS="--allow-sagemaker=true --model-control-mode=explicit \
--load-model=simple --load-model=ensemble_add_sub_int32_int32_int32 \
--load-model=repeat_int32 \
--load-model=bls_simple --trace-config=level=TIMESTAMPS \
--load-model=trace_context --trace-config=rate=1 \
--trace-config=count=-1 --trace-config=mode=opentelemetry \
Expand Down
Loading