From f17f40be812346450aa94b8cc361b452021a9778 Mon Sep 17 00:00:00 2001 From: Olga Andreeva <124622579+oandreeva-nv@users.noreply.github.com> Date: Fri, 15 Mar 2024 14:52:29 -0700 Subject: [PATCH] Exposing trace context to python backend (#6985) * Added TRITONSERVER_InferenceTraceSetContext logic --- qa/L0_trace/opentelemetry_unittest.py | 23 ++++++++++ qa/L0_trace/test.sh | 61 ++++++++++++++++++++++++++- qa/L0_trace/trace_context.py | 46 ++++++++++++++++++++ src/CMakeLists.txt | 1 + src/tracer.cc | 33 ++++++++++++++- src/tracer.h | 19 +++++++++ 6 files changed, 180 insertions(+), 3 deletions(-) create mode 100644 qa/L0_trace/trace_context.py diff --git a/qa/L0_trace/opentelemetry_unittest.py b/qa/L0_trace/opentelemetry_unittest.py index 41ec1d0024..04a82d157c 100644 --- a/qa/L0_trace/opentelemetry_unittest.py +++ b/qa/L0_trace/opentelemetry_unittest.py @@ -29,6 +29,7 @@ sys.path.append("../common") import json import queue +import re import shutil import subprocess import time @@ -104,6 +105,7 @@ def setUp(self): self.simple_model_name = "simple" self.ensemble_model_name = "ensemble_add_sub_int32_int32_int32" self.bls_model_name = "bls_simple" + self.trace_context_model = "trace_context" self.test_models = [ self.simple_model_name, self.ensemble_model_name, @@ -756,6 +758,27 @@ def test_sagemaker_invoke_trace_simple_model_context_propagation(self): time.sleep(5) self._test_simple_trace(headers=self.client_headers) + def test_trace_context_exposed_to_pbe(self): + """ + Tests trace context, propagated to python backend. 
+ """ + triton_client_http = httpclient.InferenceServerClient( + "localhost:8000", verbose=True + ) + expect_none = np.array([False], dtype=bool) + inputs = httpclient.InferInput("expect_none", [1], "BOOL") + inputs.set_data_from_numpy(expect_none) + try: + result = triton_client_http.infer(self.trace_context_model, inputs=[inputs]) + except InferenceServerException as e: + self.fail(e.message()) + + context = result.as_numpy("OUTPUT0")[()].decode("utf-8") + context = json.loads(context) + self.assertIn("traceparent", context.keys()) + context_pattern = re.compile(r"\d{2}-[0-9a-f]{32}-[0-9a-f]{16}-\d{2}") + self.assertIsNotNone(re.match(context_pattern, context["traceparent"])) + if __name__ == "__main__": unittest.main() diff --git a/qa/L0_trace/test.sh b/qa/L0_trace/test.sh index b4a17bcd95..3a5e976594 100755 --- a/qa/L0_trace/test.sh +++ b/qa/L0_trace/test.sh @@ -131,6 +131,16 @@ function update_trace_setting { set -e } +function check_pbe_trace_context { + model_name="${1}" + expect_none="${2}" + data='{"inputs":[{"name":"expect_none","datatype":"BOOL","shape":[1],"data":['${expect_none}']}]}' + rm -f ./curl.out + set +e + code=`curl -s -w %{http_code} -o ./curl.out -X POST localhost:8000/v2/models/${model_name}/infer -d ${data}` + set -e +} + function send_inference_requests { log_file="${1}" upper_bound="${2}" @@ -711,7 +721,7 @@ rm collected_traces.json* # Unittests then check that produced spans have expected format and events OPENTELEMETRY_TEST=opentelemetry_unittest.py OPENTELEMETRY_LOG="opentelemetry_unittest.log" -EXPECTED_NUM_TESTS="13" +EXPECTED_NUM_TESTS="14" # Set up repo and args for SageMaker export SAGEMAKER_TRITON_DEFAULT_MODEL_NAME="simple" @@ -722,11 +732,15 @@ cp -r $DATADIR/$MODELBASE/* ${MODEL_PATH} && \ rm -r ${MODEL_PATH}/2 && rm -r ${MODEL_PATH}/3 && \ sed -i "s/onnx_int32_int32_int32/simple/" ${MODEL_PATH}/config.pbtxt +# Add model to test trace context exposed to python backend +mkdir -p $MODELSDIR/trace_context/1 && cp 
./trace_context.py $MODELSDIR/trace_context/1/model.py
+
 SERVER_ARGS="--allow-sagemaker=true --model-control-mode=explicit \
     --load-model=simple --load-model=ensemble_add_sub_int32_int32_int32 \
     --load-model=bls_simple --trace-config=level=TIMESTAMPS \
-    --trace-config=rate=1 --trace-config=count=-1 --trace-config=mode=opentelemetry \
+    --load-model=trace_context --trace-config=rate=1 \
+    --trace-config=count=-1 --trace-config=mode=opentelemetry \
     --trace-config=opentelemetry,resource=test.key=test.value \
     --trace-config=opentelemetry,resource=service.name=test_triton \
     --trace-config=opentelemetry,url=localhost:$OTLP_PORT/v1/traces \
@@ -1025,4 +1039,47 @@ kill $SERVER_PID
 wait $SERVER_PID
 set +e
 
+# Test that PBE returns None as trace context in trace mode Triton
+SERVER_ARGS="--trace-config=level=TIMESTAMPS --trace-config=rate=1\
+    --trace-config=count=-1 --trace-config=mode=triton \
+    --model-repository=$MODELSDIR --log-verbose=1"
+SERVER_LOG="./inference_server_triton_trace_context.log"
+
+run_server
+if [ "$SERVER_PID" == "0" ]; then
+    echo -e "\n***\n*** Failed to start $SERVER\n***"
+    cat $SERVER_LOG
+    exit 1
+fi
+
+check_pbe_trace_context "trace_context" true
+assert_curl_success "PBE trace context is not None"
+
+set -e
+kill $SERVER_PID
+wait $SERVER_PID
+set +e
+
+# Test that PBE returns None as trace context in trace mode OpenTelemetry,
+# but tracing is OFF. 
+SERVER_ARGS="--trace-config=level=OFF --trace-config=rate=1\ + --trace-config=count=-1 --trace-config=mode=opentelemetry \ + --model-repository=$MODELSDIR --log-verbose=1" +SERVER_LOG="./inference_server_triton_trace_context.log" + +run_server +if [ "$SERVER_PID" == "0" ]; then + echo -e "\n***\n*** Failed to start $SERVER\n***" + cat $SERVER_LOG + exit 1 +fi + +check_pbe_trace_context "trace_context" true +assert_curl_success "PBE trace context is not None" + +set -e +kill $SERVER_PID +wait $SERVER_PID +set +e + exit $RET diff --git a/qa/L0_trace/trace_context.py b/qa/L0_trace/trace_context.py new file mode 100644 index 0000000000..f47db92a58 --- /dev/null +++ b/qa/L0_trace/trace_context.py @@ -0,0 +1,46 @@ +import numpy as np +import triton_python_backend_utils as pb_utils + + +class TritonPythonModel: + @staticmethod + def auto_complete_config(auto_complete_model_config): + inputs = [{"name": "expect_none", "data_type": "TYPE_BOOL", "dims": [1]}] + outputs = [{"name": "OUTPUT0", "data_type": "TYPE_STRING", "dims": [-1]}] + + config = auto_complete_model_config.as_dict() + input_names = [] + output_names = [] + for input in config["input"]: + input_names.append(input["name"]) + for output in config["output"]: + output_names.append(output["name"]) + + for input in inputs: + if input["name"] not in input_names: + auto_complete_model_config.add_input(input) + for output in outputs: + if output["name"] not in output_names: + auto_complete_model_config.add_output(output) + + return auto_complete_model_config + + def execute(self, requests): + responses = [] + for request in requests: + expect_none = pb_utils.get_input_tensor_by_name( + request, "expect_none" + ).as_numpy()[0] + context = request.trace().get_context() + if expect_none and context is not None: + raise pb_utils.TritonModelException("Context should be None") + if not expect_none and context is None: + raise pb_utils.TritonModelException("Context should NOT be None") + + output_tensor = pb_utils.Tensor( + 
"OUTPUT0", np.array(context).astype(np.bytes_) + ) + inference_response = pb_utils.InferenceResponse([output_tensor]) + responses.append(inference_response) + + return responses diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index fce60106c3..783275d8d7 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -525,6 +525,7 @@ if(${TRITON_ENABLE_TRACING}) tracing-library PUBLIC triton-common-logging # from repo-common + triton-common-json # from repo-common triton-core-serverapi # from repo-core triton-core-serverstub # from repo-core ) diff --git a/src/tracer.cc b/src/tracer.cc index 6d7110a9f3..3c8f4dcddd 100644 --- a/src/tracer.cc +++ b/src/tracer.cc @@ -300,7 +300,9 @@ TraceManager::GetTraceStartOptions( { TraceManager::TraceStartOptions start_options; GetTraceSetting(model_name, start_options.trace_setting); - if (start_options.trace_setting->mode_ == TRACE_MODE_OPENTELEMETRY) { + if (!start_options.trace_setting->level_ == + TRITONSERVER_TRACE_LEVEL_DISABLED && + start_options.trace_setting->mode_ == TRACE_MODE_OPENTELEMETRY) { #ifndef _WIN32 auto prop = otel_cntxt::propagation::GlobalTextMapPropagator::GetGlobalPropagator(); @@ -560,6 +562,9 @@ TraceManager::Trace::StartSpan( if (std::string(request_id) != "") { span->SetAttribute("triton.request_id", request_id); } + triton::common::TritonJson::WriteBuffer buffer; + PrepareTraceContext(span, &buffer); + TRITONSERVER_InferenceTraceSetContext(trace, buffer.Contents().c_str()); } otel_context_ = otel_context_.SetValue(span_key, span); @@ -702,6 +707,32 @@ TraceManager::Trace::AddEvent( span->AddEvent(event, time_offset_ + std::chrono::nanoseconds{timestamp}); } } + +void +TraceManager::Trace::PrepareTraceContext( + opentelemetry::nostd::shared_ptr span, + triton::common::TritonJson::WriteBuffer* buffer) +{ + triton::common::TritonJson::Value json( + triton::common::TritonJson::ValueType::OBJECT); + char trace_id[32] = {0}; + char span_id[16] = {0}; + char trace_flags[2] = {0}; + 
span->GetContext().span_id().ToLowerBase16(span_id); + span->GetContext().trace_id().ToLowerBase16(trace_id); + span->GetContext().trace_flags().ToLowerBase16(trace_flags); + std::string kTraceParent = std::string("traceparent"); + std::string kTraceState = std::string("tracestate"); + std::string traceparent = std::string("00-") + std::string(trace_id, 32) + + std::string("-") + std::string(span_id, 16) + + std::string("-") + std::string(trace_flags, 2); + std::string tracestate = span->GetContext().trace_state()->ToHeader(); + json.SetStringObject(kTraceParent.c_str(), traceparent); + if (!tracestate.empty()) { + json.SetStringObject(kTraceState.c_str(), tracestate); + } + json.Write(buffer); +} #endif void diff --git a/src/tracer.h b/src/tracer.h index a591f7bd98..cae636ca9b 100644 --- a/src/tracer.h +++ b/src/tracer.h @@ -54,6 +54,11 @@ namespace otel_cntxt = opentelemetry::context; namespace otel_resource = opentelemetry::sdk::resource; #endif #include "triton/core/tritonserver.h" +#define TRITONJSON_STATUSTYPE TRITONSERVER_Error* +#define TRITONJSON_STATUSSUCCESS nullptr +#define TRITONJSON_STATUSRETURN(M) \ + return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL, (M).c_str()) +#include "triton/common/triton_json.h" namespace triton { namespace server { @@ -266,6 +271,20 @@ class TraceManager { // OTel context to store spans, created in the current trace opentelemetry::context::Context otel_context_; + /// Prepares trace context to propagate to TRITONSERVER_InferenceTrace. + /// Trace context follows W3C Trace Context specification. + /// Ref. https://www.w3.org/TR/trace-context/. + /// OpenTelemetry ref: + /// https://github.com/open-telemetry/opentelemetry-cpp/blob/4bd64c9a336fd438d6c4c9dad2e6b61b0585311f/api/include/opentelemetry/trace/propagation/http_trace_context.h#L94-L113 + /// + /// \param span An OpenTelemetry span, which is used to extract + /// OpenTelemetry's trace_id and span_id. 
+ /// \param buffer Buffer used when writing JSON representation of + /// OpenTelemetry's context. + void PrepareTraceContext( + opentelemetry::nostd::shared_ptr span, + triton::common::TritonJson::WriteBuffer* buffer); + private: // OpenTelemetry SDK relies on system's clock for event timestamps. // Triton Tracing records timestamps using steady_clock. This is a