Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Exposing trace context to python backend (#6985) #6993

Merged
merged 1 commit into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions qa/L0_trace/opentelemetry_unittest.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
sys.path.append("../common")
import json
import queue
import re
import shutil
import subprocess
import time
Expand Down Expand Up @@ -104,6 +105,7 @@ def setUp(self):
self.simple_model_name = "simple"
self.ensemble_model_name = "ensemble_add_sub_int32_int32_int32"
self.bls_model_name = "bls_simple"
self.trace_context_model = "trace_context"
self.test_models = [
self.simple_model_name,
self.ensemble_model_name,
Expand Down Expand Up @@ -756,6 +758,27 @@ def test_sagemaker_invoke_trace_simple_model_context_propagation(self):
time.sleep(5)
self._test_simple_trace(headers=self.client_headers)

def test_trace_context_exposed_to_pbe(self):
    """
    Tests that the trace context is propagated to the python backend.

    Sends a request to the `trace_context` model (which echoes back the
    context it observed) and validates that the returned JSON carries a
    well-formed W3C `traceparent` entry.
    """
    triton_client_http = httpclient.InferenceServerClient(
        "localhost:8000", verbose=True
    )
    # Tracing is ON for this test, so the model must see a context;
    # `expect_none=False` tells the model to fail otherwise.
    expect_none = np.array([False], dtype=bool)
    inputs = httpclient.InferInput("expect_none", [1], "BOOL")
    inputs.set_data_from_numpy(expect_none)
    try:
        result = triton_client_http.infer(self.trace_context_model, inputs=[inputs])
    except InferenceServerException as e:
        self.fail(e.message())

    context = result.as_numpy("OUTPUT0")[()].decode("utf-8")
    context = json.loads(context)
    self.assertIn("traceparent", context.keys())
    # W3C traceparent: version "-" trace-id "-" parent-id "-" trace-flags.
    # All four fields are lowercase hex; version and flags are 2 hex
    # digits (e.g. flags "ff" is valid), so `\d` would be too strict.
    # Use fullmatch so trailing garbage is rejected as well.
    context_pattern = re.compile(r"[0-9a-f]{2}-[0-9a-f]{32}-[0-9a-f]{16}-[0-9a-f]{2}")
    self.assertIsNotNone(re.fullmatch(context_pattern, context["traceparent"]))


if __name__ == "__main__":
unittest.main()
61 changes: 59 additions & 2 deletions qa/L0_trace/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,16 @@ function update_trace_setting {
set -e
}

# Sends an inference request to model `$1`, telling the model via the
# boolean `expect_none` input (`$2`, "true"/"false") whether it should see
# an empty trace context. Leaves the HTTP status code in `code` and the
# response body in ./curl.out for assert_curl_success to inspect.
function check_pbe_trace_context {
    model_name="${1}"
    expect_none="${2}"
    data='{"inputs":[{"name":"expect_none","datatype":"BOOL","shape":[1],"data":['${expect_none}']}]}'
    rm -f ./curl.out
    set +e
    # Quote the payload and URL: an unquoted ${data} is word-split by the
    # shell and would break if the JSON ever gained whitespace.
    code=`curl -s -w %{http_code} -o ./curl.out -X POST "localhost:8000/v2/models/${model_name}/infer" -d "${data}"`
    set -e
}

function send_inference_requests {
log_file="${1}"
upper_bound="${2}"
Expand Down Expand Up @@ -711,7 +721,7 @@ rm collected_traces.json*
# Unittests then check that produced spans have expected format and events
OPENTELEMETRY_TEST=opentelemetry_unittest.py
OPENTELEMETRY_LOG="opentelemetry_unittest.log"
EXPECTED_NUM_TESTS="13"
EXPECTED_NUM_TESTS="14"

# Set up repo and args for SageMaker
export SAGEMAKER_TRITON_DEFAULT_MODEL_NAME="simple"
Expand All @@ -722,11 +732,15 @@ cp -r $DATADIR/$MODELBASE/* ${MODEL_PATH} && \
rm -r ${MODEL_PATH}/2 && rm -r ${MODEL_PATH}/3 && \
sed -i "s/onnx_int32_int32_int32/simple/" ${MODEL_PATH}/config.pbtxt

# Add model to test trace context exposed to python backend
mkdir -p $MODELSDIR/trace_context/1 && cp ./trace_context.py $MODELSDIR/trace_context/1/model.py


SERVER_ARGS="--allow-sagemaker=true --model-control-mode=explicit \
--load-model=simple --load-model=ensemble_add_sub_int32_int32_int32 \
--load-model=bls_simple --trace-config=level=TIMESTAMPS \
--trace-config=rate=1 --trace-config=count=-1 --trace-config=mode=opentelemetry \
--load-model=trace_context --trace-config=rate=1 \
--trace-config=count=-1 --trace-config=mode=opentelemetry \
--trace-config=opentelemetry,resource=test.key=test.value \
--trace-config=opentelemetry,resource=service.name=test_triton \
--trace-config=opentelemetry,url=localhost:$OTLP_PORT/v1/traces \
Expand Down Expand Up @@ -1025,4 +1039,47 @@ kill $SERVER_PID
wait $SERVER_PID
set +e

# Test that PBE returns None as trace context in Triton trace mode
SERVER_ARGS="--trace-config=level=TIMESTAMPS --trace-config=rate=1\
--trace-config=count=-1 --trace-config=mode=triton \
--model-repository=$MODELSDIR --log-verbose=1"
SERVER_LOG="./inference_server_triton_trace_context.log"

run_server
if [ "$SERVER_PID" == "0" ]; then
echo -e "\n***\n*** Failed to start $SERVER\n***"
cat $SERVER_LOG
exit 1
fi

check_pbe_trace_context "trace_context" true
assert_curl_success "PBE trace context is not None"

set -e
kill $SERVER_PID
wait $SERVER_PID
set +e

# Test that PBE returns None as trace context in OpenTelemetry trace mode
# when tracing is OFF.
SERVER_ARGS="--trace-config=level=OFF --trace-config=rate=1\
--trace-config=count=-1 --trace-config=mode=opentelemetry \
--model-repository=$MODELSDIR --log-verbose=1"
SERVER_LOG="./inference_server_triton_trace_context.log"

run_server
if [ "$SERVER_PID" == "0" ]; then
echo -e "\n***\n*** Failed to start $SERVER\n***"
cat $SERVER_LOG
exit 1
fi

check_pbe_trace_context "trace_context" true
assert_curl_success "PBE trace context is not None"

set -e
kill $SERVER_PID
wait $SERVER_PID
set +e

exit $RET
46 changes: 46 additions & 0 deletions qa/L0_trace/trace_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import numpy as np
import triton_python_backend_utils as pb_utils


class TritonPythonModel:
    """QA model that echoes back the trace context it receives."""

    @staticmethod
    def auto_complete_config(auto_complete_model_config):
        # Tensors this model requires; each is added only when the model
        # config does not already declare a tensor with the same name.
        required_inputs = [
            {"name": "expect_none", "data_type": "TYPE_BOOL", "dims": [1]}
        ]
        required_outputs = [
            {"name": "OUTPUT0", "data_type": "TYPE_STRING", "dims": [-1]}
        ]

        config = auto_complete_model_config.as_dict()
        declared_inputs = {entry["name"] for entry in config["input"]}
        declared_outputs = {entry["name"] for entry in config["output"]}

        for spec in required_inputs:
            if spec["name"] not in declared_inputs:
                auto_complete_model_config.add_input(spec)
        for spec in required_outputs:
            if spec["name"] not in declared_outputs:
                auto_complete_model_config.add_output(spec)

        return auto_complete_model_config

    def execute(self, requests):
        """Returns each request's trace context serialized into OUTPUT0.

        Raises a TritonModelException when the presence/absence of the
        context contradicts the request's `expect_none` flag.
        """
        responses = []
        for request in requests:
            expect_none = pb_utils.get_input_tensor_by_name(
                request, "expect_none"
            ).as_numpy()[0]
            context = request.trace().get_context()
            if expect_none and context is not None:
                raise pb_utils.TritonModelException("Context should be None")
            if not expect_none and context is None:
                raise pb_utils.TritonModelException("Context should NOT be None")

            # `context` is either None or a JSON string; both serialize
            # cleanly through a 0-d bytes array.
            output_tensor = pb_utils.Tensor(
                "OUTPUT0", np.array(context).astype(np.bytes_)
            )
            responses.append(pb_utils.InferenceResponse([output_tensor]))

        return responses
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,7 @@ if(${TRITON_ENABLE_TRACING})
tracing-library
PUBLIC
triton-common-logging # from repo-common
triton-common-json # from repo-common
triton-core-serverapi # from repo-core
triton-core-serverstub # from repo-core
)
Expand Down
33 changes: 32 additions & 1 deletion src/tracer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,9 @@ TraceManager::GetTraceStartOptions(
{
TraceManager::TraceStartOptions start_options;
GetTraceSetting(model_name, start_options.trace_setting);
if (start_options.trace_setting->mode_ == TRACE_MODE_OPENTELEMETRY) {
if (!start_options.trace_setting->level_ ==
TRITONSERVER_TRACE_LEVEL_DISABLED &&
start_options.trace_setting->mode_ == TRACE_MODE_OPENTELEMETRY) {
#ifndef _WIN32
auto prop =
otel_cntxt::propagation::GlobalTextMapPropagator::GetGlobalPropagator();
Expand Down Expand Up @@ -560,6 +562,9 @@ TraceManager::Trace::StartSpan(
if (std::string(request_id) != "") {
span->SetAttribute("triton.request_id", request_id);
}
triton::common::TritonJson::WriteBuffer buffer;
PrepareTraceContext(span, &buffer);
TRITONSERVER_InferenceTraceSetContext(trace, buffer.Contents().c_str());
}

otel_context_ = otel_context_.SetValue(span_key, span);
Expand Down Expand Up @@ -702,6 +707,32 @@ TraceManager::Trace::AddEvent(
span->AddEvent(event, time_offset_ + std::chrono::nanoseconds{timestamp});
}
}

void
TraceManager::Trace::PrepareTraceContext(
    opentelemetry::nostd::shared_ptr<otel_trace_api::Span> span,
    triton::common::TritonJson::WriteBuffer* buffer)
{
  triton::common::TritonJson::Value json(
      triton::common::TritonJson::ValueType::OBJECT);
  const auto span_context = span->GetContext();
  // Fixed-width lowercase-hex encodings of the span's identifiers; the
  // buffers are not NUL-terminated, so lengths are passed explicitly
  // when building std::strings below.
  char trace_id[32] = {0};
  char span_id[16] = {0};
  char trace_flags[2] = {0};
  span_context.trace_id().ToLowerBase16(trace_id);
  span_context.span_id().ToLowerBase16(span_id);
  span_context.trace_flags().ToLowerBase16(trace_flags);
  // W3C Trace Context header: version "00" - trace-id - parent-id - flags.
  const std::string traceparent = "00-" + std::string(trace_id, 32) + "-" +
                                  std::string(span_id, 16) + "-" +
                                  std::string(trace_flags, 2);
  const std::string tracestate = span_context.trace_state()->ToHeader();
  json.SetStringObject("traceparent", traceparent);
  if (!tracestate.empty()) {
    json.SetStringObject("tracestate", tracestate);
  }
  json.Write(buffer);
}
#endif

void
Expand Down
19 changes: 19 additions & 0 deletions src/tracer.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ namespace otel_cntxt = opentelemetry::context;
namespace otel_resource = opentelemetry::sdk::resource;
#endif
#include "triton/core/tritonserver.h"
#define TRITONJSON_STATUSTYPE TRITONSERVER_Error*
#define TRITONJSON_STATUSSUCCESS nullptr
#define TRITONJSON_STATUSRETURN(M) \
return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL, (M).c_str())
#include "triton/common/triton_json.h"

namespace triton { namespace server {

Expand Down Expand Up @@ -266,6 +271,20 @@ class TraceManager {
// OTel context to store spans, created in the current trace
opentelemetry::context::Context otel_context_;

/// Prepares trace context to propagate to TRITONSERVER_InferenceTrace.
/// Trace context follows W3C Trace Context specification.
/// Ref. https://www.w3.org/TR/trace-context/.
/// OpenTelemetry ref:
/// https://github.com/open-telemetry/opentelemetry-cpp/blob/4bd64c9a336fd438d6c4c9dad2e6b61b0585311f/api/include/opentelemetry/trace/propagation/http_trace_context.h#L94-L113
///
/// \param span An OpenTelemetry span, which is used to extract
/// OpenTelemetry's trace_id and span_id.
/// \param buffer Buffer used when writing JSON representation of
/// OpenTelemetry's context.
void PrepareTraceContext(
opentelemetry::nostd::shared_ptr<otel_trace_api::Span> span,
triton::common::TritonJson::WriteBuffer* buffer);

private:
// OpenTelemetry SDK relies on system's clock for event timestamps.
// Triton Tracing records timestamps using steady_clock. This is a
Expand Down
Loading