Skip to content

Commit

Permalink
Exposing trace context to python backend (#6985)
Browse files Browse the repository at this point in the history
* Added TRITONSERVER_InferenceTraceSetContext logic
  • Loading branch information
oandreeva-nv committed Mar 15, 2024
1 parent 3d817de commit f17f40b
Show file tree
Hide file tree
Showing 6 changed files with 180 additions and 3 deletions.
23 changes: 23 additions & 0 deletions qa/L0_trace/opentelemetry_unittest.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
sys.path.append("../common")
import json
import queue
import re
import shutil
import subprocess
import time
Expand Down Expand Up @@ -104,6 +105,7 @@ def setUp(self):
self.simple_model_name = "simple"
self.ensemble_model_name = "ensemble_add_sub_int32_int32_int32"
self.bls_model_name = "bls_simple"
self.trace_context_model = "trace_context"
self.test_models = [
self.simple_model_name,
self.ensemble_model_name,
Expand Down Expand Up @@ -756,6 +758,27 @@ def test_sagemaker_invoke_trace_simple_model_context_propagation(self):
time.sleep(5)
self._test_simple_trace(headers=self.client_headers)

def test_trace_context_exposed_to_pbe(self):
    """
    Tests that a valid W3C trace context is propagated to the python
    backend when OpenTelemetry tracing is enabled.

    Sends `expect_none=False` so the `trace_context` model asserts a
    context is present, then validates the returned `traceparent` header
    against the W3C Trace Context format.
    """
    triton_client_http = httpclient.InferenceServerClient(
        "localhost:8000", verbose=True
    )
    expect_none = np.array([False], dtype=bool)
    inputs = httpclient.InferInput("expect_none", [1], "BOOL")
    inputs.set_data_from_numpy(expect_none)
    try:
        result = triton_client_http.infer(self.trace_context_model, inputs=[inputs])
    except InferenceServerException as e:
        self.fail(e.message())

    context = result.as_numpy("OUTPUT0")[()].decode("utf-8")
    context = json.loads(context)
    self.assertIn("traceparent", context)
    # W3C traceparent: version "-" trace-id "-" parent-id "-" trace-flags,
    # where every field is lowercase hex (not just digits: e.g. trace-flags
    # may be "ff"). Ref. https://www.w3.org/TR/trace-context/#traceparent-header
    context_pattern = re.compile(r"[0-9a-f]{2}-[0-9a-f]{32}-[0-9a-f]{16}-[0-9a-f]{2}")
    self.assertIsNotNone(context_pattern.match(context["traceparent"]))


# Discover and run all test cases in this module when executed directly
# (test.sh runs this file and checks the reported test count).
if __name__ == "__main__":
    unittest.main()
61 changes: 59 additions & 2 deletions qa/L0_trace/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,16 @@ function update_trace_setting {
set -e
}

# Sends an inference request to a python-backend model that validates its
# trace context.
#   $1: model_name  -- model to send the request to.
#   $2: expect_none -- "true" if the python backend should see no trace
#                      context, "false" otherwise.
# Side effects: stores the HTTP status code in `code` and the response body
# in ./curl.out for inspection by the assert_curl_* helpers.
function check_pbe_trace_context {
    model_name="${1}"
    expect_none="${2}"
    data='{"inputs":[{"name":"expect_none","datatype":"BOOL","shape":[1],"data":['${expect_none}']}]}'
    rm -f ./curl.out
    set +e
    # Quote the payload so the JSON reaches curl as a single argument even if
    # it ever contains whitespace.
    code=$(curl -s -w %{http_code} -o ./curl.out -X POST "localhost:8000/v2/models/${model_name}/infer" -d "${data}")
    set -e
}

function send_inference_requests {
log_file="${1}"
upper_bound="${2}"
Expand Down Expand Up @@ -711,7 +721,7 @@ rm collected_traces.json*
# Unittests then check that produced spans have expected format and events
OPENTELEMETRY_TEST=opentelemetry_unittest.py
OPENTELEMETRY_LOG="opentelemetry_unittest.log"
EXPECTED_NUM_TESTS="13"
EXPECTED_NUM_TESTS="14"

# Set up repo and args for SageMaker
export SAGEMAKER_TRITON_DEFAULT_MODEL_NAME="simple"
Expand All @@ -722,11 +732,15 @@ cp -r $DATADIR/$MODELBASE/* ${MODEL_PATH} && \
rm -r ${MODEL_PATH}/2 && rm -r ${MODEL_PATH}/3 && \
sed -i "s/onnx_int32_int32_int32/simple/" ${MODEL_PATH}/config.pbtxt

# Add model to test trace context exposed to python backend
mkdir -p $MODELSDIR/trace_context/1 && cp ./trace_context.py $MODELSDIR/trace_context/1/model.py


SERVER_ARGS="--allow-sagemaker=true --model-control-mode=explicit \
--load-model=simple --load-model=ensemble_add_sub_int32_int32_int32 \
--load-model=bls_simple --trace-config=level=TIMESTAMPS \
--trace-config=rate=1 --trace-config=count=-1 --trace-config=mode=opentelemetry \
--load-model=trace_context --trace-config=rate=1 \
--trace-config=count=-1 --trace-config=mode=opentelemetry \
--trace-config=opentelemetry,resource=test.key=test.value \
--trace-config=opentelemetry,resource=service.name=test_triton \
--trace-config=opentelemetry,url=localhost:$OTLP_PORT/v1/traces \
Expand Down Expand Up @@ -1025,4 +1039,47 @@ kill $SERVER_PID
wait $SERVER_PID
set +e

# Test that PBE returns None as trace context in trace mode Triton:
# trace context is only produced in OpenTelemetry mode.
SERVER_ARGS="--trace-config=level=TIMESTAMPS --trace-config=rate=1\
    --trace-config=count=-1 --trace-config=mode=triton \
    --model-repository=$MODELSDIR --log-verbose=1"
SERVER_LOG="./inference_server_triton_trace_context.log"

run_server
if [ "$SERVER_PID" == "0" ]; then
    echo -e "\n***\n*** Failed to start $SERVER\n***"
    cat $SERVER_LOG
    exit 1
fi

# `true` asks the trace_context model to fail unless its context is None.
check_pbe_trace_context "trace_context" true
assert_curl_success "PBE trace context is not None"

set -e
kill $SERVER_PID
wait $SERVER_PID
set +e

# Test that PBE returns None as trace context in trace mode OpenTelemetry,
# but tracing is OFF.
SERVER_ARGS="--trace-config=level=OFF --trace-config=rate=1\
    --trace-config=count=-1 --trace-config=mode=opentelemetry \
    --model-repository=$MODELSDIR --log-verbose=1"
SERVER_LOG="./inference_server_triton_trace_context.log"

run_server
if [ "$SERVER_PID" == "0" ]; then
    echo -e "\n***\n*** Failed to start $SERVER\n***"
    cat $SERVER_LOG
    exit 1
fi

# Tracing is disabled (level=OFF), so the delivered context must be None.
check_pbe_trace_context "trace_context" true
assert_curl_success "PBE trace context is not None"

set -e
kill $SERVER_PID
wait $SERVER_PID
set +e

exit $RET
46 changes: 46 additions & 0 deletions qa/L0_trace/trace_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import numpy as np
import triton_python_backend_utils as pb_utils


class TritonPythonModel:
    """Test model verifying the trace context exposed to the python backend.

    Reads a boolean input `expect_none` and checks that the trace context
    attached to the request matches that expectation: it raises when the
    context's presence contradicts `expect_none`. The (possibly None)
    context is echoed back as a string in `OUTPUT0` so callers can inspect
    its contents.
    """

    @staticmethod
    def auto_complete_config(auto_complete_model_config):
        """Register the `expect_none` input and `OUTPUT0` output if absent.

        Leaves any inputs/outputs already present in the model config
        untouched and returns the (possibly updated) config object.
        """
        required_inputs = [
            {"name": "expect_none", "data_type": "TYPE_BOOL", "dims": [1]}
        ]
        required_outputs = [
            {"name": "OUTPUT0", "data_type": "TYPE_STRING", "dims": [-1]}
        ]

        config = auto_complete_model_config.as_dict()
        # Set comprehensions avoid shadowing the `input` builtin and give
        # O(1) membership checks below.
        existing_inputs = {entry["name"] for entry in config["input"]}
        existing_outputs = {entry["name"] for entry in config["output"]}

        for spec in required_inputs:
            if spec["name"] not in existing_inputs:
                auto_complete_model_config.add_input(spec)
        for spec in required_outputs:
            if spec["name"] not in existing_outputs:
                auto_complete_model_config.add_output(spec)

        return auto_complete_model_config

    def execute(self, requests):
        """Validate each request's trace context and return it as OUTPUT0.

        Raises pb_utils.TritonModelException when the context's presence
        contradicts the request's `expect_none` input.
        """
        responses = []
        for request in requests:
            expect_none = pb_utils.get_input_tensor_by_name(
                request, "expect_none"
            ).as_numpy()[0]
            context = request.trace().get_context()
            if expect_none and context is not None:
                raise pb_utils.TritonModelException("Context should be None")
            if not expect_none and context is None:
                raise pb_utils.TritonModelException("Context should NOT be None")

            # np.array(None) still serializes (as b"None"), so OUTPUT0 is
            # always produced even when no context is available.
            output_tensor = pb_utils.Tensor(
                "OUTPUT0", np.array(context).astype(np.bytes_)
            )
            responses.append(pb_utils.InferenceResponse([output_tensor]))

        return responses
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,7 @@ if(${TRITON_ENABLE_TRACING})
tracing-library
PUBLIC
triton-common-logging # from repo-common
triton-common-json # from repo-common
triton-core-serverapi # from repo-core
triton-core-serverstub # from repo-core
)
Expand Down
33 changes: 32 additions & 1 deletion src/tracer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,9 @@ TraceManager::GetTraceStartOptions(
{
TraceManager::TraceStartOptions start_options;
GetTraceSetting(model_name, start_options.trace_setting);
if (start_options.trace_setting->mode_ == TRACE_MODE_OPENTELEMETRY) {
if (!start_options.trace_setting->level_ ==
TRITONSERVER_TRACE_LEVEL_DISABLED &&
start_options.trace_setting->mode_ == TRACE_MODE_OPENTELEMETRY) {
#ifndef _WIN32
auto prop =
otel_cntxt::propagation::GlobalTextMapPropagator::GetGlobalPropagator();
Expand Down Expand Up @@ -560,6 +562,9 @@ TraceManager::Trace::StartSpan(
if (std::string(request_id) != "") {
span->SetAttribute("triton.request_id", request_id);
}
triton::common::TritonJson::WriteBuffer buffer;
PrepareTraceContext(span, &buffer);
TRITONSERVER_InferenceTraceSetContext(trace, buffer.Contents().c_str());
}

otel_context_ = otel_context_.SetValue(span_key, span);
Expand Down Expand Up @@ -702,6 +707,32 @@ TraceManager::Trace::AddEvent(
span->AddEvent(event, time_offset_ + std::chrono::nanoseconds{timestamp});
}
}

void
TraceManager::Trace::PrepareTraceContext(
    opentelemetry::nostd::shared_ptr<otel_trace_api::Span> span,
    triton::common::TritonJson::WriteBuffer* buffer)
{
  // Serialize the span's context as JSON following the W3C Trace Context
  // specification: a `traceparent` entry of the form
  // "00-<trace_id>-<span_id>-<trace_flags>" and, when non-empty, a
  // `tracestate` entry. Ref. https://www.w3.org/TR/trace-context/.
  const auto span_context = span->GetContext();
  char trace_id[32] = {0};
  char span_id[16] = {0};
  char trace_flags[2] = {0};
  span_context.trace_id().ToLowerBase16(trace_id);
  span_context.span_id().ToLowerBase16(span_id);
  span_context.trace_flags().ToLowerBase16(trace_flags);

  const std::string traceparent = std::string("00-") +
                                  std::string(trace_id, 32) +
                                  std::string("-") + std::string(span_id, 16) +
                                  std::string("-") + std::string(trace_flags, 2);

  triton::common::TritonJson::Value json(
      triton::common::TritonJson::ValueType::OBJECT);
  json.SetStringObject("traceparent", traceparent);

  const std::string tracestate = span_context.trace_state()->ToHeader();
  if (!tracestate.empty()) {
    json.SetStringObject("tracestate", tracestate);
  }
  json.Write(buffer);
}
#endif

void
Expand Down
19 changes: 19 additions & 0 deletions src/tracer.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ namespace otel_cntxt = opentelemetry::context;
namespace otel_resource = opentelemetry::sdk::resource;
#endif
#include "triton/core/tritonserver.h"
#define TRITONJSON_STATUSTYPE TRITONSERVER_Error*
#define TRITONJSON_STATUSSUCCESS nullptr
#define TRITONJSON_STATUSRETURN(M) \
return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL, (M).c_str())
#include "triton/common/triton_json.h"

namespace triton { namespace server {

Expand Down Expand Up @@ -266,6 +271,20 @@ class TraceManager {
// OTel context to store spans, created in the current trace
opentelemetry::context::Context otel_context_;

/// Prepares trace context to propagate to TRITONSERVER_InferenceTrace.
/// Trace context follows W3C Trace Context specification.
/// Ref. https://www.w3.org/TR/trace-context/.
/// OpenTelemetry ref:
/// https://github.com/open-telemetry/opentelemetry-cpp/blob/4bd64c9a336fd438d6c4c9dad2e6b61b0585311f/api/include/opentelemetry/trace/propagation/http_trace_context.h#L94-L113
///
/// \param span An OpenTelemetry span, which is used to extract
/// OpenTelemetry's trace_id and span_id.
/// \param buffer Buffer used when writing JSON representation of
/// OpenTelemetry's context.
void PrepareTraceContext(
opentelemetry::nostd::shared_ptr<otel_trace_api::Span> span,
triton::common::TritonJson::WriteBuffer* buffer);

private:
// OpenTelemetry SDK relies on system's clock for event timestamps.
// Triton Tracing records timestamps using steady_clock. This is a
Expand Down

0 comments on commit f17f40b

Please sign in to comment.