Enhance OTEL testing to capture and verify Cancellation Requests and Non-Decoupled model inference. (#7132)

* Enhance OTEL testing
indrajit96 authored May 24, 2024
1 parent b8b0cad commit 76e53c9
Showing 4 changed files with 298 additions and 9 deletions.
49 changes: 49 additions & 0 deletions qa/L0_trace/models/input_all_required/1/model.py
@@ -0,0 +1,49 @@
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import json
import time

import numpy as np
import triton_python_backend_utils as pb_utils


class TritonPythonModel:
def initialize(self, args):
self.model_config = json.loads(args["model_config"])

def execute(self, requests):
"""This function is called on inference request."""
# Less than collector timeout which is 10
time.sleep(2)
responses = []
for _ in requests:
            # Return a single-element FP32 output tensor; the value itself is not
            # important for these trace tests.
out_0 = np.array([1], dtype=np.float32)
out_tensor_0 = pb_utils.Tensor("OUTPUT0", out_0)
responses.append(pb_utils.InferenceResponse([out_tensor_0]))

return responses
55 changes: 55 additions & 0 deletions qa/L0_trace/models/input_all_required/config.pbtxt
@@ -0,0 +1,55 @@
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

name: "input_all_required"
backend: "python"
input [
{
name: "INPUT0"
data_type: TYPE_FP32
dims: [ -1 ]
},
{
name: "INPUT1"
data_type: TYPE_FP32
dims: [ -1 ]
},
{
name: "INPUT2"
data_type: TYPE_FP32
dims: [ -1 ]
}
]

output [
{
name: "OUTPUT0"
data_type: TYPE_FP32
dims: [ 1 ]
}
]

instance_group [{ kind: KIND_CPU }]
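For reference, a minimal client call that exercises this model might look like the sketch below. It is illustrative only (it assumes a Triton server listening on localhost:8001, as in the test environment); all three inputs are required, and the ~2 second sleep in the model's execute() is what gives the cancellation test its window to cancel mid-compute.

    import numpy as np
    import tritonclient.grpc as grpcclient

    client = grpcclient.InferenceServerClient("localhost:8001")
    inputs = []
    for name in ("INPUT0", "INPUT1", "INPUT2"):
        inp = grpcclient.InferInput(name, [1], "FP32")
        inp.set_data_from_numpy(np.arange(1, dtype=np.float32))
        inputs.append(inp)
    # Blocks for roughly 2 seconds while the model sleeps in execute(),
    # then returns the single-element OUTPUT0 tensor.
    result = client.infer("input_all_required", inputs)
    print(result.as_numpy("OUTPUT0"))  # -> [1.]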
179 changes: 172 additions & 7 deletions qa/L0_trace/opentelemetry_unittest.py
@@ -27,6 +27,7 @@
import sys

sys.path.append("../common")
import concurrent.futures
import json
import queue
import re
@@ -82,6 +83,11 @@ def send_bls_request(model_name="simple", headers=None):
client.infer("bls_simple", inputs, headers=headers)


class UserData:
def __init__(self):
self._completed_requests = queue.Queue()


class OpenTelemetryTest(tu.TestResultCollector):
def setUp(self):
self.collector_subprocess = subprocess.Popen(
@@ -104,14 +110,27 @@ def setUp(self):
)
self.simple_model_name = "simple"
self.ensemble_model_name = "ensemble_add_sub_int32_int32_int32"
self.input_all_required_model_name = "input_all_required"
self.cancel_queue_model_name = "dynamic_batch"
self.bls_model_name = "bls_simple"
self.trace_context_model = "trace_context"
self.non_decoupled_model_name_ = "repeat_int32"
self.test_models = [
self.simple_model_name,
self.ensemble_model_name,
self.bls_model_name,
self.non_decoupled_model_name_,
self.cancel_queue_model_name,
]
self.root_span = "InferRequest"
self._user_data = UserData()
self._callback = partial(callback, self._user_data)
self._outputs = []
self.input_data = {
"IN": np.array([1], dtype=np.int32),
"DELAY": np.array([0], dtype=np.uint32),
"WAIT": np.array([0], dtype=np.uint32),
}

def tearDown(self):
self.collector_subprocess.kill()
@@ -120,6 +139,22 @@ def tearDown(self):
test_name = unittest.TestCase.id(self).split(".")[-1]
shutil.copyfile(self.filename, self.filename + "_" + test_name + ".log")

def _get_inputs(self, batch_size):
shape = [batch_size, 8]
inputs = [grpcclient.InferInput("INPUT0", shape, "FP32")]
inputs[0].set_data_from_numpy(np.ones(shape, dtype=np.float32))
return inputs

def _generate_callback_and_response_pair(self):
response = {"responded": False, "result": None, "error": None}

def callback_queue(result, error):
response["responded"] = True
response["result"] = result
response["error"] = error

return callback_queue, response
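A sketch of how this callback/response pair is intended to be used, mirroring the queue-cancellation test further down (the client object and timing here are illustrative):

    callback, response = self._generate_callback_and_response_pair()
    future = triton_client_grpc.async_infer(
        self.cancel_queue_model_name, self._get_inputs(batch_size=1), callback
    )
    time.sleep(0.2)  # give the request time to reach the scheduler queue
    future.cancel()
    # Once the callback fires, response["responded"] is True and either
    # response["result"] or response["error"] holds the outcome, so the test
    # can inspect the result without blocking on the future.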

def _parse_trace_log(self, trace_log):
"""
        Helper function that parses the file containing collected traces.
@@ -138,7 +173,7 @@ def _parse_trace_log(self, trace_log):

return traces

def _check_events(self, span_name, events):
def _check_events(self, span_name, events, is_cancelled):
"""
Helper function that verifies passed events contain expected entries.
@@ -160,6 +195,14 @@ def _check_events(self, span_name, events):
"GRPC_SEND_START",
"GRPC_SEND_END",
]
cancel_root_events_http = [
"HTTP_RECV_START",
"HTTP_RECV_END",
]
cancel_root_events_grpc = [
"GRPC_WAITREAD_START",
"GRPC_WAITREAD_END",
]
request_events = ["REQUEST_START", "QUEUE_START", "REQUEST_END"]
compute_events = [
"COMPUTE_START",
@@ -180,15 +223,21 @@ def _check_events(self, span_name, events):
elif span_name == self.root_span:
# Check that root span has INFER_RESPONSE_COMPLETE, _RECV/_WAITREAD
# and _SEND events (and only them)
if is_cancelled == True:
root_events_http = cancel_root_events_http
root_events_grpc = cancel_root_events_grpc

if "HTTP" in events:
self.assertTrue(all(entry in events for entry in root_events_http))
self.assertFalse(all(entry in events for entry in root_events_grpc))

elif "GRPC" in events:
self.assertTrue(all(entry in events for entry in root_events_grpc))
self.assertFalse(all(entry in events for entry in root_events_http))
self.assertFalse(all(entry in events for entry in request_events))
self.assertFalse(all(entry in events for entry in compute_events))

if is_cancelled == False:
self.assertFalse(all(entry in events for entry in request_events))
self.assertFalse(all(entry in events for entry in compute_events))

elif span_name in self.test_models:
# Check that all request related events (and only them)
@@ -232,7 +281,7 @@ def _test_resource_attributes(self, attributes):
),
)

def _verify_contents(self, spans, expected_counts):
def _verify_contents(self, spans, expected_counts, is_cancelled):
"""
Helper function that:
* iterates over `spans` and for every span it verifies that proper events are collected
@@ -247,6 +296,7 @@ def _verify_contents(self, spans, expected_counts):
and `events` are required.
expected_counts (dict): dictionary, containing expected spans in the form:
span_name : #expected_number_of_entries
        is_cancelled (bool): True if the trace was collected from a cancelled request workflow
"""

span_names = []
@@ -256,7 +306,7 @@ def _verify_contents(self, spans, expected_counts):
span_names.append(span_name)
span_events = span["events"]
event_names_only = [event["name"] for event in span_events]
self._check_events(span_name, event_names_only)
self._check_events(span_name, event_names_only, is_cancelled)

self.assertEqual(
len(span_names),
@@ -339,6 +389,24 @@ def _verify_headers_propagated_from_client_if_any(self, root_span, headers):
),
)

def _test_trace_cancel(self, is_queued):
        # We want to capture the cancellation request's traces WHILE the inference
        # is still in the COMPUTE stage. The model "input_all_required" sleeps in its
        # compute phase, so the cancellation request can be sent while the request is
        # still being computed. Wait for the collector to flush before reading the
        # traces from the file.
time.sleep(2 * COLLECTOR_TIMEOUT)
traces = self._parse_trace_log(self.filename)
if is_queued == False:
expected_counts = dict(
{"compute": 1, self.input_all_required_model_name: 1, self.root_span: 1}
)
else:
            # No compute span is expected, since the request was cancelled while still in the queue
expected_counts = dict(
{"compute": 0, self.cancel_queue_model_name: 1, self.root_span: 1}
)
parsed_spans = traces[0]["resourceSpans"][0]["scopeSpans"][0]["spans"]
self._verify_contents(parsed_spans, expected_counts, is_cancelled=True)
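For context, the OTLP JSON written by the collector nests spans as resourceSpans → scopeSpans → spans, which is exactly what the indexing above walks. A trimmed, illustrative record for the compute-stage cancellation case might look roughly like this (field values are examples, not captured output):

    trace = {
        "resourceSpans": [{
            "scopeSpans": [{
                "spans": [
                    {"name": "InferRequest",
                     "events": [{"name": "GRPC_WAITREAD_START"},
                                {"name": "GRPC_WAITREAD_END"}]},
                    {"name": "input_all_required",
                     "events": [{"name": "REQUEST_START"},
                                {"name": "QUEUE_START"},
                                {"name": "REQUEST_END"}]},
                    {"name": "compute",
                     "events": [{"name": "COMPUTE_START"}]},
                ]
            }]
        }]
    }
    spans = trace["resourceSpans"][0]["scopeSpans"][0]["spans"]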

def _test_trace(
self,
headers,
@@ -396,8 +464,7 @@ def _test_trace(
entry for entry in parsed_spans if entry["name"] == "InferRequest"
][0]
self.assertEqual(len(parsed_spans), expected_number_of_spans)

self._verify_contents(parsed_spans, expected_counts)
self._verify_contents(parsed_spans, expected_counts, is_cancelled=False)
self._verify_nesting(parsed_spans, expected_parent_span_dict)
self._verify_headers_propagated_from_client_if_any(root_span, headers)

@@ -420,6 +487,24 @@ def _test_simple_trace(self, headers=None):
expected_parent_span_dict=expected_parent_span_dict,
)

def _test_non_decoupled_trace(self, headers=None):
"""
        Helper function that collects the trace for a non-decoupled model and verifies it.
"""
expected_number_of_spans = 3
expected_counts = dict(
{"compute": 1, self.non_decoupled_model_name_: 1, self.root_span: 1}
)
expected_parent_span_dict = dict(
{"InferRequest": ["repeat_int32"], "repeat_int32": ["compute"]}
)
self._test_trace(
headers=headers,
expected_number_of_spans=expected_number_of_spans,
expected_counts=expected_counts,
expected_parent_span_dict=expected_parent_span_dict,
)

def _test_bls_trace(self, headers=None):
"""
Helper function, that specifies expected parameters to evaluate trace,
@@ -527,6 +612,86 @@ def test_grpc_trace_simple_model(self):

self._test_simple_trace()

def test_grpc_trace_all_input_required_model_cancel(self):
"""
        Tests the trace collected from executing one inference request and cancelling it
        mid-compute via the GRPC client. Expects only the 2 GRPC wait/read stage events on the root span.
"""
triton_client_grpc = grpcclient.InferenceServerClient(
"localhost:8001", verbose=True
)
inputs = []
inputs.append(grpcclient.InferInput("INPUT0", [1], "FP32"))
inputs[0].set_data_from_numpy(np.arange(1, dtype=np.float32))
inputs.append(grpcclient.InferInput("INPUT1", [1], "FP32"))
inputs[1].set_data_from_numpy(np.arange(1, dtype=np.float32))
inputs.append(grpcclient.InferInput("INPUT2", [1], "FP32"))
inputs[2].set_data_from_numpy(np.arange(1, dtype=np.float32))
future = triton_client_grpc.async_infer(
model_name=self.input_all_required_model_name,
inputs=inputs,
callback=self._callback,
outputs=self._outputs,
)
time.sleep(2) # ensure the inference has started
future.cancel()
time.sleep(0.1) # context switch
self._test_trace_cancel(is_queued=False)
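Under this flow the cancelled request's root span is expected to carry only the two GRPC wait/read events (cancel_root_events_grpc above); a minimal sketch of what _verify_contents ends up asserting for that span:

    root_span_events = ["GRPC_WAITREAD_START", "GRPC_WAITREAD_END"]  # illustrative
    assert all(e in root_span_events for e in ["GRPC_WAITREAD_START", "GRPC_WAITREAD_END"])
    # No GRPC_SEND_* events are expected, since the request was cancelled
    # before a response was sent back to the client.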

    # Test that queued requests on the dynamic batch scheduler can be cancelled
def test_grpc_trace_model_cancel_in_queue(self):
"""
        Tests the trace collected from executing one inference request and cancelling it
        via the GRPC client while the request is still in the queue. Expects 0 compute stage traces.
"""
model_name = self.cancel_queue_model_name
triton_client_grpc = grpcclient.InferenceServerClient(
"localhost:8001", verbose=True
)
with concurrent.futures.ThreadPoolExecutor() as pool:
# Saturate the slots on the model
saturate_thread = pool.submit(
triton_client_grpc.infer, model_name, self._get_inputs(batch_size=1)
)
time.sleep(2) # ensure the slots are filled
# The next request should be queued
callback, response = self._generate_callback_and_response_pair()
future = triton_client_grpc.async_infer(
model_name, self._get_inputs(batch_size=1), callback
)
time.sleep(0.2) # ensure the request is queued
future.cancel()
# Join saturating thread
saturate_thread.result()
self._test_trace_cancel(is_queued=True)

def test_non_decoupled(self):
"""
        Tests the trace collected from executing one inference request against a non-decoupled model.
"""
inputs = [
grpcclient.InferInput("IN", [1], "INT32").set_data_from_numpy(
self.input_data["IN"]
),
grpcclient.InferInput("DELAY", [1], "UINT32").set_data_from_numpy(
self.input_data["DELAY"]
),
grpcclient.InferInput("WAIT", [1], "UINT32").set_data_from_numpy(
self.input_data["WAIT"]
),
]

triton_client = grpcclient.InferenceServerClient(
url="localhost:8001", verbose=True
)
# Expect the inference is successful
res = triton_client.infer(
model_name=self.non_decoupled_model_name_, inputs=inputs
)
self._test_non_decoupled_trace()
self.assertEqual(1, res.as_numpy("OUT")[0])
self.assertEqual(0, res.as_numpy("IDX")[0])

def test_grpc_trace_simple_model_context_propagation(self):
"""
Tests trace, collected from executing one inference request
