Exposing trace context to python backend #6985

Merged Mar 15, 2024 (8 commits)
Changes from 4 commits
19 changes: 18 additions & 1 deletion qa/L0_trace/opentelemetry_unittest.py
@@ -29,6 +29,7 @@
sys.path.append("../common")
import json
import queue
import re
import shutil
import subprocess
import time
@@ -40,7 +41,7 @@
import test_util as tu
import tritonclient.grpc as grpcclient
import tritonclient.http as httpclient
from tritonclient.utils import InferenceServerException
from tritonclient.utils import InferenceServerException, np_to_triton_dtype

NO_PARENT_SPAN_ID = ""
COLLECTOR_TIMEOUT = 10
@@ -756,6 +757,22 @@ def test_sagemaker_invoke_trace_simple_model_context_propagation(self):
        time.sleep(5)
        self._test_simple_trace(headers=self.client_headers)

    def test_trace_context_exposed_to_pbe(self):
        """
        Tests that the trace context is propagated to the python backend.
        """
        triton_client_http = httpclient.InferenceServerClient(
            "localhost:8000", verbose=True
        )
        xs = np.zeros((1, 1)).astype(np.float32)
        inputs = httpclient.InferInput("INPUT0", xs.shape, np_to_triton_dtype(xs.dtype))
        inputs.set_data_from_numpy(xs)
        result = triton_client_http.infer(
            self.ensemble_model_name, [inputs], headers=self.client_headers
        )
        # OUTPUT0 carries the JSON trace context; look for a W3C traceparent in it.
        context = result.as_numpy("OUTPUT0").flatten()[0].decode("utf-8")
        context_pattern = re.compile(r"[0-9a-f]{2}-[0-9a-f]{32}-[0-9a-f]{16}-[0-9a-f]{2}")
        self.assertIsNotNone(re.search(context_pattern, context))


if __name__ == "__main__":
    unittest.main()
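
For reference, the pattern in the test above targets the W3C traceparent field, version-trace_id-parent_id-trace_flags in lowercase hex. A minimal standalone check of the regex, using the example traceparent from the W3C specification rather than real server output:

    import re

    traceparent = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01"  # W3C spec example
    pattern = re.compile(r"[0-9a-f]{2}-[0-9a-f]{32}-[0-9a-f]{16}-[0-9a-f]{2}")
    assert pattern.search(traceparent) is not None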
3 changes: 3 additions & 0 deletions qa/L0_trace/test.sh
@@ -722,6 +722,9 @@ cp -r $DATADIR/$MODELBASE/* ${MODEL_PATH} && \
rm -r ${MODEL_PATH}/2 && rm -r ${MODEL_PATH}/3 && \
sed -i "s/onnx_int32_int32_int32/simple/" ${MODEL_PATH}/config.pbtxt

# Add model to test trace context exposed to python backend
mkdir -p $MODELSDIR/trace_context/1 && cp ./trace_context.py $MODELSDIR/trace_context/1/model.py


SERVER_ARGS="--allow-sagemaker=true --model-control-mode=explicit \
--load-model=simple --load-model=ensemble_add_sub_int32_int32_int32 \
48 changes: 48 additions & 0 deletions qa/L0_trace/trace_context.py
@@ -0,0 +1,48 @@
import json

import numpy as np

Contributor: @oandreeva-nv do we need a copyright statement for this file?

Contributor Author: Good catch, will follow up shortly

import triton_python_backend_utils as pb_utils


class TritonPythonModel:
    @staticmethod
    def auto_complete_config(auto_complete_model_config):
        inputs = [{"name": "INPUT0", "data_type": "TYPE_FP32", "dims": [1]}]
        outputs = [{"name": "OUTPUT0", "data_type": "TYPE_STRING", "dims": [-1]}]

        config = auto_complete_model_config.as_dict()
        input_names = []
        output_names = []
        for input in config["input"]:
            input_names.append(input["name"])
        for output in config["output"]:
            output_names.append(output["name"])

        for input in inputs:
            if input["name"] not in input_names:
                auto_complete_model_config.add_input(input)
        for output in outputs:
            if output["name"] not in output_names:
                auto_complete_model_config.add_output(output)

        auto_complete_model_config.set_max_batch_size(1)

        return auto_complete_model_config

    def initialize(self, args):
        self.model_config = json.loads(args["model_config"])
        output_config = pb_utils.get_output_config_by_name(self.model_config, "OUTPUT0")
        self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"])

    def execute(self, requests):
        responses = []
        for request in requests:
            context = request.trace().get_context(mode="opentelemetry")
            output_tensor = pb_utils.Tensor(
                "OUTPUT0", np.array(context).astype(np.bytes_)
            )
            inference_response = pb_utils.InferenceResponse([output_tensor])
            responses.append(inference_response)

        return responses
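
As a usage sketch beyond this test model (not part of the PR): the string returned by request.trace().get_context(mode="opentelemetry") is a JSON carrier in W3C Trace Context form, so a Python backend could parent its own spans under Triton's request span. This assumes the opentelemetry-api package is installed in the backend environment (it is not a Triton dependency), and the helper name span_from_triton_context is hypothetical:

    import json

    from opentelemetry import trace
    from opentelemetry.trace.propagation.tracecontext import (
        TraceContextTextMapPropagator,
    )

    def span_from_triton_context(context_json, name="python_backend_work"):
        # The carrier is a plain dict with "traceparent" (and maybe "tracestate").
        carrier = json.loads(context_json)
        ctx = TraceContextTextMapPropagator().extract(carrier=carrier)
        tracer = trace.get_tracer(__name__)
        # Spans started with this context become children of Triton's request span.
        return tracer.start_as_current_span(name, context=ctx)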
1 change: 1 addition & 0 deletions src/CMakeLists.txt
@@ -525,6 +525,7 @@ if(${TRITON_ENABLE_TRACING})
tracing-library
PUBLIC
triton-common-logging # from repo-common
triton-common-json # from repo-common
triton-core-serverapi # from repo-core
triton-core-serverstub # from repo-core
)
33 changes: 32 additions & 1 deletion src/tracer.cc
@@ -300,7 +300,9 @@ TraceManager::GetTraceStartOptions(
{
  TraceManager::TraceStartOptions start_options;
  GetTraceSetting(model_name, start_options.trace_setting);
  if (start_options.trace_setting->mode_ == TRACE_MODE_OPENTELEMETRY) {
  if (start_options.trace_setting->level_ !=
          TRITONSERVER_TRACE_LEVEL_DISABLED &&
      start_options.trace_setting->mode_ == TRACE_MODE_OPENTELEMETRY) {
#ifndef _WIN32
    auto prop =
        otel_cntxt::propagation::GlobalTextMapPropagator::GetGlobalPropagator();
@@ -560,6 +562,9 @@ TraceManager::Trace::StartSpan(
    if (std::string(request_id) != "") {
      span->SetAttribute("triton.request_id", request_id);
    }
    triton::common::TritonJson::WriteBuffer buffer;
    PrepareTraceContext(span, &buffer);
    TRITONSERVER_InferenceTraceSetContext(trace, buffer.Contents().c_str());
  }

  otel_context_ = otel_context_.SetValue(span_key, span);
@@ -702,6 +707,32 @@ TraceManager::Trace::AddEvent(
    span->AddEvent(event, time_offset_ + std::chrono::nanoseconds{timestamp});
  }
}

void
TraceManager::Trace::PrepareTraceContext(
    opentelemetry::nostd::shared_ptr<otel_trace_api::Span> span,
    triton::common::TritonJson::WriteBuffer* buffer)
{
  triton::common::TritonJson::Value json(
      triton::common::TritonJson::ValueType::OBJECT);
  char trace_id[32] = {0};
  char span_id[16] = {0};
  char trace_flags[2] = {0};
  span->GetContext().span_id().ToLowerBase16(span_id);
  span->GetContext().trace_id().ToLowerBase16(trace_id);
  span->GetContext().trace_flags().ToLowerBase16(trace_flags);
  std::string kTraceParent = std::string("traceparent");
  std::string kTraceState = std::string("tracestate");
  std::string traceparent = std::string("00-") + std::string(trace_id, 32) +
                            std::string("-") + std::string(span_id, 16) +
                            std::string("-") + std::string(trace_flags, 2);
  std::string tracestate = span->GetContext().trace_state()->ToHeader();
  json.SetStringObject(kTraceParent.c_str(), traceparent);
  if (!tracestate.empty()) {
    json.SetStringObject(kTraceState.c_str(), tracestate);
  }
  json.Write(buffer);
}
#endif

void
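For illustration, if the span's IDs matched the W3C specification's example values (an assumption for this sketch), the buffer written by PrepareTraceContext would hold:

    {"traceparent": "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01"}

with an additional "tracestate" key only when the span's trace state is non-empty. This JSON string is what TRITONSERVER_InferenceTraceSetContext stores and what the python backend later receives from get_context.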
19 changes: 19 additions & 0 deletions src/tracer.h
@@ -54,6 +54,11 @@ namespace otel_cntxt = opentelemetry::context;
namespace otel_resource = opentelemetry::sdk::resource;
#endif
#include "triton/core/tritonserver.h"
#define TRITONJSON_STATUSTYPE TRITONSERVER_Error*
#define TRITONJSON_STATUSSUCCESS nullptr
#define TRITONJSON_STATUSRETURN(M) \
  return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL, (M).c_str())
#include "triton/common/triton_json.h"

namespace triton { namespace server {

@@ -266,6 +271,20 @@ class TraceManager {
  // OTel context to store spans, created in the current trace
  opentelemetry::context::Context otel_context_;

  /// Prepares the trace context to propagate to
  /// TRITONSERVER_InferenceTrace. The trace context follows the W3C Trace
  /// Context specification. Ref. https://www.w3.org/TR/trace-context/.
  /// OpenTelemetry ref:
  /// https://github.com/open-telemetry/opentelemetry-cpp/blob/4bd64c9a336fd438d6c4c9dad2e6b61b0585311f/api/include/opentelemetry/trace/propagation/http_trace_context.h#L94-L113
  ///
  /// \param span An OpenTelemetry span, which is used to extract
  /// OpenTelemetry's trace_id and span_id.
  /// \param buffer Buffer used when writing the JSON representation of
  /// OpenTelemetry's context.
  void PrepareTraceContext(
      opentelemetry::nostd::shared_ptr<otel_trace_api::Span> span,
      triton::common::TritonJson::WriteBuffer* buffer);

 private:
  // OpenTelemetry SDK relies on system's clock for event timestamps.
  // Triton Tracing records timestamps using steady_clock. This is a