From 24c3d60285a98a45a442bf1fac3be4296565c0e3 Mon Sep 17 00:00:00 2001 From: Olga Andreeva <124622579+oandreeva-nv@users.noreply.github.com> Date: Wed, 17 Jan 2024 10:32:34 -0800 Subject: [PATCH] Support for Context Propagation for OTel trace mode (#6785) Added support for OTel context propagation --------- Co-authored-by: Markus Hennerbichler Co-authored-by: Ryan McCormick --- qa/L0_trace/opentelemetry_unittest.py | 798 +++++++++++++++++++++----- qa/L0_trace/test.sh | 206 +++++-- qa/L0_trace/trace-config.yaml | 12 +- src/CMakeLists.txt | 4 +- src/grpc/infer_handler.cc | 6 +- src/grpc/infer_handler.h | 29 + src/grpc/stream_infer_handler.cc | 6 +- src/http_server.cc | 6 +- src/http_server.h | 30 + src/sagemaker_server.cc | 17 +- src/tracer.cc | 138 +++-- src/tracer.h | 37 +- 12 files changed, 999 insertions(+), 290 deletions(-) diff --git a/qa/L0_trace/opentelemetry_unittest.py b/qa/L0_trace/opentelemetry_unittest.py index 5055f4e88a..0e211b76b6 100644 --- a/qa/L0_trace/opentelemetry_unittest.py +++ b/qa/L0_trace/opentelemetry_unittest.py @@ -28,78 +28,122 @@ sys.path.append("../common") import json -import re +import queue +import shutil +import subprocess +import time import unittest +from functools import partial import numpy as np +import requests import test_util as tu import tritonclient.grpc as grpcclient import tritonclient.http as httpclient +from tritonclient.utils import InferenceServerException -EXPECTED_NUM_SPANS = 16 -# OpenTelemetry OStream exporter sets `parent_span_id` to "0000000000000000", -# if current span is a root span, i.e. there is no parent span. -# https://github.com/open-telemetry/opentelemetry-cpp/blob/b7fd057185c4ed2dff507b859cbe058b7609fb4a/exporters/ostream/src/span_exporter.cc#L78C54-L78C68 -NO_PARENT_SPAN = "0000000000000000" +NO_PARENT_SPAN_ID = "" +COLLECTOR_TIMEOUT = 10 + + +def callback(user_data, result, error): + if error: + user_data.put(error) + else: + user_data.put(result) + + +def prepare_data(client, is_binary=True): + inputs = [] + dim = 16 + input_data = np.arange(dim, dtype=np.int32) + inputs.append(client.InferInput("INPUT0", [1, dim], "INT32")) + inputs.append(client.InferInput("INPUT1", [1, dim], "INT32")) + + # Initialize the data + input_data = np.expand_dims(input_data, axis=0) + + if is_binary: + inputs[0].set_data_from_numpy(input_data) + inputs[1].set_data_from_numpy(input_data) + else: + inputs[0].set_data_from_numpy(input_data, binary_data=is_binary) + inputs[1].set_data_from_numpy(input_data, binary_data=is_binary) + + return inputs + + +def send_bls_request(model_name="simple", headers=None): + with httpclient.InferenceServerClient("localhost:8000") as client: + inputs = prepare_data(httpclient) + inputs.append(httpclient.InferInput("MODEL_NAME", [1], "BYTES")) + inputs[-1].set_data_from_numpy(np.array([model_name], dtype=np.object_)) + client.infer("bls_simple", inputs, headers=headers) class OpenTelemetryTest(tu.TestResultCollector): def setUp(self): - # Extracted spans are in json-like format, thus data needs to be - # post-processed, so that `json` could accept it for further - # processing - with open("trace_collector.log", "rt") as f: - data = f.read() - # Removing new lines and tabs around `{` - json_string = re.sub("\n\t{\n\t", "{", data) - # `resources` field is a dictionary, so adding `{` and`}` - # in the next 2 transformations, `instr-lib` is a next field, - # so whatever goes before it, belongs to `resources`. 
- json_string = re.sub( - "resources : \n\t", "resources : {\n\t", json_string - ) - json_string = re.sub( - "\n instr-lib :", "}\n instr-lib :", json_string - ) - # `json`` expects "key":"value" format, some fields in the - # data have empty string as value, so need to add `"",` - json_string = re.sub(": \n\t", ':"",', json_string) - json_string = re.sub(": \n", ':"",', json_string) - # Extracted data missing `,' after each key-value pair, - # which `json` exppects - json_string = re.sub("\n|\n\t", ",", json_string) - # Removing tabs - json_string = re.sub("\t", "", json_string) - # `json` expects each key and value have `"`'s, so adding them to - # every word/number/alpha-numeric entry - json_string = re.sub(r"\b([\w.-]+)\b", r'"\1"', json_string) - # `span kind`` represents one key - json_string = re.sub('"span" "kind"', '"span kind"', json_string) - # Removing extra `,` - json_string = re.sub("{,", "{", json_string) - json_string = re.sub(",}", "}", json_string) - # Adding `,` between dictionary entries - json_string = re.sub("}{", "},{", json_string) - # `events` is a list of dictionaries, `json` will accept it in the - # form of "events" : [{....}, {.....}, ...] - json_string = re.sub( - '"events" : {', '"events" : [{', json_string - ) - # Closing `events`' list of dictionaries - json_string = re.sub('}, "links"', '}], "links"', json_string) - # Last 2 symbols are not needed - json_string = json_string[:-2] - # Since now `json_string` is a string, which represents dictionaries, - # we put it into one dictionary, so that `json` could read it as one. - json_string = '{ "spans" :[' + json_string + "] }" - self.spans = json.loads(json_string)["spans"] - + self.collector_subprocess = subprocess.Popen( + ["./otelcol", "--config", "./trace-config.yaml"] + ) + time.sleep(5) + self.filename = "collected_traces.json" + # This simulates OTel context being injected on client side. + # Format explained here: https://www.w3.org/TR/trace-context/#design-overview + # OTel code reference for extraction: + # https://github.com/open-telemetry/opentelemetry-cpp/blob/c4f39f2be8109fd1a3e047677c09cf47954b92db/api/include/opentelemetry/trace/propagation/http_trace_context.h#L165 + # Essentially, this is what will be injected to headers/metadata + # on the client side. Code reference: + # https://github.com/open-telemetry/opentelemetry-cpp/blob/c4f39f2be8109fd1a3e047677c09cf47954b92db/api/include/opentelemetry/trace/propagation/http_trace_context.h#L91 + # Format is: 00-traceId-spanId-traceFlags + # By simply adding this header during tests, we imitate + # that on client side OTel Propagator injected it to request. + self.client_headers = dict( + {"traceparent": "00-0af7651916cd43dd8448eb211c12666c-b7ad6b7169242424-01"} + ) self.simple_model_name = "simple" self.ensemble_model_name = "ensemble_add_sub_int32_int32_int32" self.bls_model_name = "bls_simple" + self.test_models = [ + self.simple_model_name, + self.ensemble_model_name, + self.bls_model_name, + ] self.root_span = "InferRequest" + def tearDown(self): + self.collector_subprocess.kill() + self.collector_subprocess.wait() + time.sleep(5) + test_name = unittest.TestCase.id(self).split(".")[-1] + shutil.copyfile(self.filename, self.filename + "_" + test_name + ".log") + + def _parse_trace_log(self, trace_log): + """ + Helper function that parses file, containing collected traces. + + Args: + trace_log (str): Name of a file, containing all traces. + + Returns: + traces (List[dict]): List of json objects, representing each span. 
+ """ + traces = [] + with open(trace_log) as f: + for json_obj in f: + entry = json.loads(json_obj) + traces.append(entry) + + return traces + def _check_events(self, span_name, events): + """ + Helper function that verifies passed events contain expected entries. + + Args: + span_name (str): name of a span. + events (List[str]): list of event names, collected for the span with the name `span_name`. + """ root_events_http = [ "HTTP_RECV_START", "HTTP_RECV_END", @@ -144,7 +188,7 @@ def _check_events(self, span_name, events): self.assertFalse(all(entry in events for entry in request_events)) self.assertFalse(all(entry in events for entry in compute_events)) - elif span_name == self.simple_model_name: + elif span_name in self.test_models: # Check that all request related events (and only them) # are recorded in request span self.assertTrue(all(entry in events for entry in request_events)) @@ -153,121 +197,559 @@ def _check_events(self, span_name, events): ) self.assertFalse(all(entry in events for entry in compute_events)) - def _check_parent(self, child_span, parent_span): - # Check that child and parent span have the same trace_id - # and child's `parent_span_id` is the same as parent's `span_id` - self.assertEqual(child_span["trace_id"], parent_span["trace_id"]) - self.assertNotEqual( - child_span["parent_span_id"], - NO_PARENT_SPAN, - "child span does not have parent span id specified", + def _test_resource_attributes(self, attributes): + """ + Helper function that verifies passed span attributes. + Currently only test 2 attributes, specified upon tritonserver start: + + --trace-config=opentelemetry,resource=test.key=test.value + and + --trace-config=opentelemetry,resource=service.name=test_triton + + Args: + attributes (List[dict]): list of attributes, collected for a span. + """ + expected_service_name = dict( + {"key": "service.name", "value": {"stringValue": "test_triton"}} + ) + expected_test_key_value = dict( + {"key": "test.key", "value": {"stringValue": "test.value"}} ) + self.assertIn( + expected_service_name, + attributes, + "Expected entry: {}, was not found in the set of collected attributes: {}".format( + expected_service_name, attributes + ), + ) + self.assertIn( + expected_test_key_value, + attributes, + "Expected entry: {}, was not found in the set of collected attributes: {}".format( + expected_test_key_value, attributes + ), + ) + + def _verify_contents(self, spans, expected_counts): + """ + Helper function that: + * iterates over `spans` and for every span it verifies that proper events are collected + * verifies that `spans` has expected number of total spans collected + * verifies that `spans` contains expected number different spans, + specified in `expected_counts` in the form: + span_name : #expected_number_of_entries + + Args: + spans (List[dict]): list of json objects, extracted from the trace and + containing span info. For this test `name` + and `events` are required. 
+ expected_counts (dict): dictionary, containing expected spans in the form: + span_name : #expected_number_of_entries + """ + + span_names = [] + for span in spans: + # Check that collected spans have proper events recorded + span_name = span[0]["name"] + span_names.append(span_name) + span_events = span[0]["events"] + event_names_only = [event["name"] for event in span_events] + self._check_events(span_name, event_names_only) + self.assertEqual( - child_span["parent_span_id"], - parent_span["span_id"], - "child {} , parent {}".format(child_span, parent_span), - ) - - def test_spans(self): - parsed_spans = [] - - # Check that collected spans have proper events recorded - for span in self.spans: - span_name = span["name"] - self._check_events(span_name, str(span["events"])) - parsed_spans.append(span_name) - - # There should be 16 spans in total: - # 3 for http request, 3 for grpc request, 4 for ensemble, 6 for bls - self.assertEqual(len(self.spans), EXPECTED_NUM_SPANS) - # We should have 5 compute spans - self.assertEqual(parsed_spans.count("compute"), 5) - # 7 request spans - # (4 named simple - same as our model name, 2 ensemble, 1 bls) - self.assertEqual(parsed_spans.count(self.simple_model_name), 4) - self.assertEqual(parsed_spans.count(self.ensemble_model_name), 2) - self.assertEqual(parsed_spans.count(self.bls_model_name), 1) - # 4 root spans - self.assertEqual(parsed_spans.count(self.root_span), 4) - - def test_nested_spans(self): - # First 3 spans in `self.spans` belong to HTTP request - # They are recorded in the following order: - # compute_span [idx 0] , request_span [idx 1], root_span [idx 2]. - # compute_span should be a child of request_span - # request_span should be a child of root_span - for child, parent in zip(self.spans[:3], self.spans[1:3]): - self._check_parent(child, parent) - - # Next 3 spans in `self.spans` belong to GRPC request - # Order of spans and their relationship described earlier - for child, parent in zip(self.spans[3:6], self.spans[4:6]): - self._check_parent(child, parent) - - # Next 4 spans in `self.spans` belong to ensemble request - # Order of spans: compute span - request span - request span - root span - for child, parent in zip(self.spans[6:10], self.spans[7:10]): - self._check_parent(child, parent) - - # Final 6 spans in `self.spans` belong to bls with ensemble request - # Order of spans: - # compute span - request span (simple) - request span (ensemble)- - # - compute (for bls) - request (bls) - root span - # request span (ensemble) and compute (for bls) are children of - # request (bls) - children = self.spans[10:] - parents = (self.spans[11:13], self.spans[14], self.spans[14:]) - for child, parent in zip(children, parents[0]): - self._check_parent(child, parent) - - def test_resource_attributes(self): - for span in self.spans: - self.assertIn("test.key", span["resources"]) - self.assertEqual("test.value", span["resources"]["test.key"]) - self.assertIn("service.name", span["resources"]) - self.assertEqual("test_triton", span["resources"]["service.name"]) - - -def prepare_data(client): - inputs = [] - input0_data = np.full(shape=(1, 16), fill_value=-1, dtype=np.int32) - input1_data = np.full(shape=(1, 16), fill_value=-1, dtype=np.int32) + len(span_names), + sum(expected_counts.values()), + "Unexpeced number of span names collected", + ) + for name, count in expected_counts.items(): + self.assertEqual( + span_names.count(name), + count, + "Unexpeced number of " + name + " spans collected", + ) - inputs.append(client.InferInput("INPUT0", [1, 16], 
"INT32")) - inputs.append(client.InferInput("INPUT1", [1, 16], "INT32")) + def _verify_nesting(self, spans, expected_parent_span_dict): + """ + Helper function that checks parent-child relationships between + collected spans are the same as in `expected_parent_span_dict`. + + Args: + spans (List[dict]): list of json objects, extracted from the trace and + containing span info. For this test `name` + and `events` are required. + expected_parent_span_dict (dict): dictionary, containing expected + parents and children in the dictionary form: + (str) : (List[str]) + """ + seen_spans = {} + for span in spans: + cur_span = span[0]["spanId"] + seen_spans[cur_span] = span[0]["name"] + + parent_child_dict = {} + for span in spans: + cur_parent = span[0]["parentSpanId"] + cur_span = span[0]["name"] + if cur_parent in seen_spans.keys(): + parent_name = seen_spans[cur_parent] + if parent_name not in parent_child_dict: + parent_child_dict[parent_name] = [] + parent_child_dict[parent_name].append(cur_span) + + for key in parent_child_dict.keys(): + parent_child_dict[key].sort() + + self.assertDictEqual(parent_child_dict, expected_parent_span_dict) + + def _verify_headers_propagated_from_client_if_any(self, root_span, headers): + """ + Helper function that checks traceparent's ids, passed in clients + headers/metadata was picked up on the server side. + If `headers` are None, checks that `root_span` does not have + `parentSpanId` specified. + + Args: + root_span (List[dict]): a json objects, extracted from the trace and + containing root span info. For this test `traceID` + and `parentSpanId` are required. + expected_parent_span_dict (dict): dictionary, containing expected + parents and children in the dictionary form: + (str) : (List[str]) + """ + parent_span_id = NO_PARENT_SPAN_ID + + if headers != None: + parent_span_id = headers["traceparent"].split("-")[2] + parent_trace_id = headers["traceparent"].split("-")[1] + self.assertEqual( + root_span["traceId"], + parent_trace_id, + "Child and parent trace ids do not match! child's trace id = {} , expected trace id = {}".format( + root_span["traceId"], parent_trace_id + ), + ) - # Initialize the data - inputs[0].set_data_from_numpy(input0_data) - inputs[1].set_data_from_numpy(input1_data) + self.assertEqual( + root_span["parentSpanId"], + parent_span_id, + "Child and parent span ids do not match! child's parentSpanId = {} , expected parentSpanId {}".format( + root_span["parentSpanId"], parent_span_id + ), + ) - return inputs + def _test_trace( + self, + headers, + expected_number_of_spans, + expected_counts, + expected_parent_span_dict, + ): + """ + Helper method that defines the general test scenario for a trace, + described as follows. + + 1. Parse trace log, exported by OTel collector in self.filename. + 2. For each test we re-start OTel collector, so trace log should + have only 1 trace. + 3. Test that reported resource attributes contain manually specified + at `tritonserver` start time. Currently only test 2 attributes, + specified upon tritonserver start: + + --trace-config=opentelemetry,resource=test.key=test.value + and + --trace-config=opentelemetry,resource=service.name=test_triton + 4. Verifies that every collected span, has expected contents + 5. Verifies parent - child span relationships + 6. Verifies that OTel context was propagated from client side + to server side through headers. For cases, when headers for + context propagation were not specified, checks that root_span has + no `parentSpanId` specified. 
+ + Args: + headers (dict | None): dictionary, containing OTel headers, + specifying OTel context. + expected_number_of_spans (int): expected number of collected spans. + expected_counts(dict): dictionary, containing expected spans in the form: + span_name : #expected_number_of_entries + expected_parent_span_dict (dict): dictionary, containing expected + parents and children in the dictionary form: + (str) : (List[str]) + """ + time.sleep(COLLECTOR_TIMEOUT) + traces = self._parse_trace_log(self.filename) + self.assertEqual(len(traces), 1, "Unexpected number of traces collected") + self._test_resource_attributes( + traces[0]["resourceSpans"][0]["resource"]["attributes"] + ) + + parsed_spans = [ + entry["scopeSpans"][0]["spans"] for entry in traces[0]["resourceSpans"] + ] + root_span = [ + entry[0] for entry in parsed_spans if entry[0]["name"] == "InferRequest" + ][0] + self.assertEqual(len(parsed_spans), expected_number_of_spans) + + self._verify_contents(parsed_spans, expected_counts) + self._verify_nesting(parsed_spans, expected_parent_span_dict) + self._verify_headers_propagated_from_client_if_any(root_span, headers) + + def _test_simple_trace(self, headers=None): + """ + Helper function, that specifies expected parameters to evaluate trace, + collected from running 1 inference request for `simple` model. + """ + expected_number_of_spans = 3 + expected_counts = dict( + {"compute": 1, self.simple_model_name: 1, self.root_span: 1} + ) + expected_parent_span_dict = dict( + {"InferRequest": ["simple"], "simple": ["compute"]} + ) + self._test_trace( + headers=headers, + expected_number_of_spans=expected_number_of_spans, + expected_counts=expected_counts, + expected_parent_span_dict=expected_parent_span_dict, + ) + + def _test_bls_trace(self, headers=None): + """ + Helper function, that specifies expected parameters to evaluate trace, + collected from running 1 inference request for `bls_simple` model. + """ + expected_number_of_spans = 6 + expected_counts = dict( + { + "compute": 2, + self.simple_model_name: 1, + self.ensemble_model_name: 1, + self.bls_model_name: 1, + self.root_span: 1, + } + ) + expected_parent_span_dict = dict( + { + "InferRequest": ["bls_simple"], + "bls_simple": ["compute", "ensemble_add_sub_int32_int32_int32"], + "ensemble_add_sub_int32_int32_int32": ["simple"], + "simple": ["compute"], + } + ) + for key in expected_parent_span_dict.keys(): + expected_parent_span_dict[key].sort() + + self._test_trace( + headers=headers, + expected_number_of_spans=expected_number_of_spans, + expected_counts=expected_counts, + expected_parent_span_dict=expected_parent_span_dict, + ) + + def _test_ensemble_trace(self, headers=None): + """ + Helper function, that specifies expected parameters to evaluate trace, + collected from running 1 inference request for an + `ensemble_add_sub_int32_int32_int32` model. 
+ """ + expected_number_of_spans = 4 + expected_counts = dict( + { + "compute": 1, + self.simple_model_name: 1, + self.ensemble_model_name: 1, + self.root_span: 1, + } + ) + expected_parent_span_dict = dict( + { + "InferRequest": ["ensemble_add_sub_int32_int32_int32"], + "ensemble_add_sub_int32_int32_int32": ["simple"], + "simple": ["compute"], + } + ) + for key in expected_parent_span_dict.keys(): + expected_parent_span_dict[key].sort() + + self._test_trace( + headers=headers, + expected_number_of_spans=expected_number_of_spans, + expected_counts=expected_counts, + expected_parent_span_dict=expected_parent_span_dict, + ) + + def test_http_trace_simple_model(self): + """ + Tests trace, collected from executing one inference request + for a `simple` model and HTTP client. + """ + triton_client_http = httpclient.InferenceServerClient( + "localhost:8000", verbose=True + ) + inputs = prepare_data(httpclient) + triton_client_http.infer(self.simple_model_name, inputs) + + self._test_simple_trace() + + def test_http_trace_simple_model_context_propagation(self): + """ + Tests trace, collected from executing one inference request + for a `simple` model, HTTP client and context propagation, + i.e. client specifies OTel headers, defined in `self.client_headers`. + """ + triton_client_http = httpclient.InferenceServerClient( + "localhost:8000", verbose=True + ) + inputs = prepare_data(httpclient) + triton_client_http.infer( + self.simple_model_name, inputs, headers=self.client_headers + ) + self._test_simple_trace(headers=self.client_headers) -def prepare_traces(): - triton_client_http = httpclient.InferenceServerClient( - "localhost:8000", verbose=True - ) - triton_client_grpc = grpcclient.InferenceServerClient( - "localhost:8001", verbose=True - ) - inputs = prepare_data(httpclient) - triton_client_http.infer("simple", inputs) + def test_grpc_trace_simple_model(self): + """ + Tests trace, collected from executing one inference request + for a `simple` model and GRPC client. + """ + triton_client_grpc = grpcclient.InferenceServerClient( + "localhost:8001", verbose=True + ) + inputs = prepare_data(grpcclient) + triton_client_grpc.infer(self.simple_model_name, inputs) + + self._test_simple_trace() + + def test_grpc_trace_simple_model_context_propagation(self): + """ + Tests trace, collected from executing one inference request + for a `simple` model, GRPC client and context propagation, + i.e. client specifies OTel headers, defined in `self.client_headers`. + """ + triton_client_grpc = grpcclient.InferenceServerClient( + "localhost:8001", verbose=True + ) + inputs = prepare_data(grpcclient) + triton_client_grpc.infer( + self.simple_model_name, inputs, headers=self.client_headers + ) - inputs = prepare_data(grpcclient) - triton_client_grpc.infer("simple", inputs) + self._test_simple_trace(headers=self.client_headers) - inputs = prepare_data(httpclient) - triton_client_http.infer("ensemble_add_sub_int32_int32_int32", inputs) + def test_streaming_grpc_trace_simple_model(self): + """ + Tests trace, collected from executing one inference request + for a `simple` model and GRPC streaming client. 
+ """ + triton_client_grpc = grpcclient.InferenceServerClient( + "localhost:8001", verbose=True + ) + user_data = queue.Queue() + triton_client_grpc.start_stream(callback=partial(callback, user_data)) + + inputs = prepare_data(grpcclient) + triton_client_grpc.async_stream_infer(self.simple_model_name, inputs) + result = user_data.get() + self.assertIsNot(result, InferenceServerException) + triton_client_grpc.stop_stream() + + self._test_simple_trace() + + def test_streaming_grpc_trace_simple_model_context_propagation(self): + """ + Tests trace, collected from executing one inference request + for a `simple` model, GRPC streaming client and context propagation, + i.e. client specifies OTel headers, defined in `self.client_headers`. + """ + triton_client_grpc = grpcclient.InferenceServerClient( + "localhost:8001", verbose=True + ) + user_data = queue.Queue() + triton_client_grpc.start_stream( + callback=partial(callback, user_data), + headers=self.client_headers, + ) - send_bls_request(model_name="ensemble_add_sub_int32_int32_int32") + inputs = prepare_data(grpcclient) + triton_client_grpc.async_stream_infer(self.simple_model_name, inputs) + result = user_data.get() + self.assertIsNot(result, InferenceServerException) + triton_client_grpc.stop_stream() + + self._test_simple_trace(headers=self.client_headers) + + def test_http_trace_bls_model(self): + """ + Tests trace, collected from executing one inference request + for a `bls_simple` model and HTTP client. + """ + send_bls_request(model_name=self.ensemble_model_name) + + self._test_bls_trace() + + def test_http_trace_bls_model_context_propagation(self): + """ + Tests trace, collected from executing one inference request + for a `bls_simple` model, HTTP client and context propagation, + i.e. client specifies OTel headers, defined in `self.client_headers`. + """ + send_bls_request( + model_name=self.ensemble_model_name, headers=self.client_headers + ) + self._test_bls_trace(headers=self.client_headers) -def send_bls_request(model_name="simple"): - with httpclient.InferenceServerClient("localhost:8000") as client: + def test_http_trace_ensemble_model(self): + """ + Tests trace, collected from executing one inference request + for a `ensemble_add_sub_int32_int32_int32` model and HTTP client. + """ + triton_client_http = httpclient.InferenceServerClient( + "localhost:8000", verbose=True + ) inputs = prepare_data(httpclient) - inputs.append(httpclient.InferInput("MODEL_NAME", [1], "BYTES")) - inputs[-1].set_data_from_numpy(np.array([model_name], dtype=np.object_)) - client.infer("bls_simple", inputs) + triton_client_http.infer(self.ensemble_model_name, inputs) + + self._test_ensemble_trace() + + def test_http_trace_ensemble_model_context_propagation(self): + """ + Tests trace, collected from executing one inference request + for a `ensemble_add_sub_int32_int32_int32` model, HTTP client + and context propagation, i.e. client specifies OTel headers, + defined in `self.client_headers`. 
+ """ + triton_client_http = httpclient.InferenceServerClient( + "localhost:8000", verbose=True + ) + inputs = prepare_data(httpclient) + triton_client_http.infer( + self.ensemble_model_name, inputs, headers=self.client_headers + ) + + self._test_ensemble_trace(headers=self.client_headers) + + def test_http_trace_triggered(self): + triton_client_http = httpclient.InferenceServerClient("localhost:8000") + triton_client_http.update_trace_settings(settings={"trace_rate": "5"}) + + expected_trace_rate = "5" + simple_model_trace_settings = triton_client_http.get_trace_settings( + model_name=self.simple_model_name + ) + + self.assertEqual( + expected_trace_rate, + simple_model_trace_settings["trace_rate"], + "Unexpected model trace rate settings after its update. Expected {}, but got {}".format( + expected_trace_rate, simple_model_trace_settings["trace_rate"] + ), + ) + + inputs = prepare_data(httpclient) + for _ in range(5): + triton_client_http.infer(self.ensemble_model_name, inputs) + time.sleep(COLLECTOR_TIMEOUT) + + expected_accumulated_traces = 1 + traces = self._parse_trace_log(self.filename) + # Should only be 1 trace collected + self.assertEqual( + len(traces), + expected_accumulated_traces, + "Unexpected number of traces collected", + ) + + for _ in range(5): + triton_client_http.infer( + self.ensemble_model_name, inputs, headers=self.client_headers + ) + expected_accumulated_traces += 1 + time.sleep(COLLECTOR_TIMEOUT) + + traces = self._parse_trace_log(self.filename) + # Should only be 1 trace collected + self.assertEqual( + len(traces), + expected_accumulated_traces, + "Unexpected number of traces collected", + ) + + # Restore trace rate to 1 + triton_client_http.update_trace_settings(settings={"trace_rate": "1"}) + expected_trace_rate = "1" + simple_model_trace_settings = triton_client_http.get_trace_settings( + model_name=self.simple_model_name + ) + + self.assertEqual( + expected_trace_rate, + simple_model_trace_settings["trace_rate"], + "Unexpected model trace rate settings after its update. Expected {}, but got {}".format( + expected_trace_rate, simple_model_trace_settings["trace_rate"] + ), + ) + + def test_sagemaker_invocation_trace_simple_model_context_propagation(self): + """ + Tests trace, collected from executing one inference request + for a `simple` model, SageMaker (invocations) and context propagation, + i.e. client specifies OTel headers, defined in `self.client_headers`. + """ + inputs = prepare_data(httpclient, is_binary=False) + request_body, _ = httpclient.InferenceServerClient.generate_request_body(inputs) + self.client_headers["Content-Type"] = "application/json" + r = requests.post( + "http://localhost:8080/invocations", + data=request_body, + headers=self.client_headers, + ) + r.raise_for_status() + self.assertEqual( + r.status_code, + 200, + "Expected status code 200, received {}".format(r.status_code), + ) + self._test_simple_trace(headers=self.client_headers) + + def test_sagemaker_invoke_trace_simple_model_context_propagation(self): + """ + Tests trace, collected from executing one inference request + for a `simple` model, SageMaker (invoke) and context propagation, + i.e. client specifies OTel headers, defined in `self.client_headers`. 
+ """ + # Loading model for this test + model_url = "/opt/ml/models/123456789abcdefghi/model" + request_body = {"model_name": self.simple_model_name, "url": model_url} + headers = {"Content-Type": "application/json"} + r = requests.post( + "http://localhost:8080/models", + data=json.dumps(request_body), + headers=headers, + ) + time.sleep(5) # wait for model to load + self.assertEqual( + r.status_code, + 200, + "Expected status code 200, received {}".format(r.status_code), + ) + + inputs = prepare_data(httpclient, is_binary=False) + request_body, _ = httpclient.InferenceServerClient.generate_request_body(inputs) + + self.client_headers["Content-Type"] = "application/json" + invoke_url = "{}/{}/invoke".format( + "http://localhost:8080/models", self.simple_model_name + ) + r = requests.post(invoke_url, data=request_body, headers=self.client_headers) + r.raise_for_status() + self.assertEqual( + r.status_code, + 200, + "Expected status code 200, received {}".format(r.status_code), + ) + time.sleep(5) + self._test_simple_trace(headers=self.client_headers) if __name__ == "__main__": diff --git a/qa/L0_trace/test.sh b/qa/L0_trace/test.sh index b4c47034bc..1d553336e9 100755 --- a/qa/L0_trace/test.sh +++ b/qa/L0_trace/test.sh @@ -698,28 +698,40 @@ set +e # Check opentelemetry trace exporter sends proper info. # A helper python script starts listening on $OTLP_PORT, where # OTLP exporter sends traces. -export TRITON_OPENTELEMETRY_TEST='false' OTLP_PORT=10000 -OTEL_COLLECTOR_DIR=./opentelemetry-collector -OTEL_COLLECTOR=./opentelemetry-collector/bin/otelcorecol_* +OTEL_COLLECTOR=./otelcol OTEL_COLLECTOR_LOG="./trace_collector_http_exporter.log" -# Building the latest version of the OpenTelemetry collector. +# Installing OpenTelemetry collector (v0.91.0). # Ref: https://opentelemetry.io/docs/collector/getting-started/#local -if [ -d "$OTEL_COLLECTOR_DIR" ]; then rm -Rf $OTEL_COLLECTOR_DIR; fi -git clone --depth 1 --branch v0.82.0 https://github.com/open-telemetry/opentelemetry-collector.git -cd $OTEL_COLLECTOR_DIR -make install-tools -make otelcorecol -cd .. -$OTEL_COLLECTOR --config ./trace-config.yaml >> $OTEL_COLLECTOR_LOG 2>&1 & COLLECTOR_PID=$! 
- +curl --proto '=https' --tlsv1.2 -fOL https://github.com/open-telemetry/opentelemetry-collector-releases/releases/download/v0.91.0/otelcol_0.91.0_linux_amd64.tar.gz +tar -xvf otelcol_0.91.0_linux_amd64.tar.gz -SERVER_ARGS="--trace-config=level=TIMESTAMPS --trace-config=rate=1 \ - --trace-config=count=100 --trace-config=mode=opentelemetry \ +rm collected_traces.json* +# Unittests then check that produced spans have expected format and events +OPENTELEMETRY_TEST=opentelemetry_unittest.py +OPENTELEMETRY_LOG="opentelemetry_unittest.log" +EXPECTED_NUM_TESTS="13" + +# Set up repo and args for SageMaker +export SAGEMAKER_TRITON_DEFAULT_MODEL_NAME="simple" +MODEL_PATH="/opt/ml/models/123456789abcdefghi/model" +rm -r ${MODEL_PATH} +mkdir -p "${MODEL_PATH}" +cp -r $DATADIR/$MODELBASE/* ${MODEL_PATH} && \ + rm -r ${MODEL_PATH}/2 && rm -r ${MODEL_PATH}/3 && \ + sed -i "s/onnx_int32_int32_int32/simple/" ${MODEL_PATH}/config.pbtxt + + +SERVER_ARGS="--allow-sagemaker=true --model-control-mode=explicit \ + --load-model=simple --load-model=ensemble_add_sub_int32_int32_int32 \ + --load-model=bls_simple --trace-config=level=TIMESTAMPS \ + --trace-config=rate=1 --trace-config=count=-1 --trace-config=mode=opentelemetry \ + --trace-config=opentelemetry,resource=test.key=test.value \ + --trace-config=opentelemetry,resource=service.name=test_triton \ --trace-config=opentelemetry,url=localhost:$OTLP_PORT/v1/traces \ --model-repository=$MODELSDIR" -SERVER_LOG="./inference_server_otel_http_exporter.log" +SERVER_LOG="./inference_server_otel_otelcol_exporter.log" run_server if [ "$SERVER_PID" == "0" ]; then @@ -728,38 +740,97 @@ if [ "$SERVER_PID" == "0" ]; then exit 1 fi -$SIMPLE_HTTP_CLIENT >>$CLIENT_LOG 2>&1 +set +e -set -e +python $OPENTELEMETRY_TEST >>$OPENTELEMETRY_LOG 2>&1 +if [ $? -ne 0 ]; then + cat $OPENTELEMETRY_LOG + RET=1 +else + check_test_results $TEST_RESULT_FILE $EXPECTED_NUM_TESTS + if [ $? -ne 0 ]; then + cat $OPENTELEMETRY_LOG + echo -e "\n***\n*** Test Result Verification Failed\n***" + RET=1 + fi +fi +set -e kill $SERVER_PID wait $SERVER_PID +set +e -kill $COLLECTOR_PID -wait $COLLECTOR_PID +# Testing OTel WAR with trace rate = 0 +rm collected_traces.json + +OTEL_COLLECTOR=./otelcol +OTEL_COLLECTOR_LOG="./trace_collector_exporter.log" +$OTEL_COLLECTOR --config ./trace-config.yaml >> $OTEL_COLLECTOR_LOG 2>&1 & COLLECTOR_PID=$! + +SERVER_ARGS="--trace-config=level=TIMESTAMPS --trace-config=rate=0\ + --trace-config=count=-1 --trace-config=mode=opentelemetry \ + --trace-config=opentelemetry,url=localhost:$OTLP_PORT/v1/traces \ + --model-repository=$MODELSDIR" +SERVER_LOG="./inference_server_otel_WAR.log" + +run_server +if [ "$SERVER_PID" == "0" ]; then + echo -e "\n***\n*** Failed to start $SERVER\n***" + cat $SERVER_LOG + exit 1 +fi + +get_trace_setting "bls_simple" +assert_curl_success "Failed to obtain trace settings for 'simple' model" + +if [ `grep -c "\"trace_level\":\[\"TIMESTAMPS\"\]" ./curl.out` != "1" ]; then + RET=1 +fi +if [ `grep -c "\"trace_rate\":\"0\"" ./curl.out` != "1" ]; then + RET=1 +fi +if [ `grep -c "\"trace_count\":\"-1\"" ./curl.out` != "1" ]; then + RET=1 +fi set +e +# Send bls requests to make sure bls_simple model is NOT traced +for p in {1..10}; do + python -c 'import opentelemetry_unittest; \ + opentelemetry_unittest.send_bls_request(model_name="ensemble_add_sub_int32_int32_int32")' >> client_update.log 2>&1 +done -if ! 
[[ -s $OTEL_COLLECTOR_LOG && `grep -c 'InstrumentationScope triton-server' $OTEL_COLLECTOR_LOG` == 3 ]] ; then - echo -e "\n***\n*** HTTP exporter test failed.\n***" - cat $OTEL_COLLECTOR_LOG +if [ -s collected_traces.json ] ; then + echo -e "\n***\n*** collected_traces.json should be empty, but it is not.\n***" exit 1 fi +# Send 1 bls request with OTel context to make sure it is traced +python -c 'import opentelemetry_unittest; \ + opentelemetry_unittest.send_bls_request(model_name="ensemble_add_sub_int32_int32_int32", \ + headers={"traceparent": "00-0af7651916cd43dd8448eb211c12666c-b7ad6b7169242424-01"} \ + )' >> client_update.log 2>&1 -# Unittests then check that produced spans have expected format and events -OPENTELEMETRY_TEST=opentelemetry_unittest.py -OPENTELEMETRY_LOG="opentelemetry_unittest.log" -EXPECTED_NUM_TESTS="3" +sleep 20 -export TRITON_OPENTELEMETRY_TEST='true' +if ! [ -s collected_traces.json ] ; then + echo -e "\n***\n*** collected_traces.json should contain OTel trace, but it is not. \n***" + exit 1 +fi -SERVER_ARGS="--trace-config=level=TIMESTAMPS --trace-config=rate=1 \ - --trace-config=count=100 --trace-config=mode=opentelemetry \ - --trace-config=opentelemetry,resource=test.key=test.value \ - --trace-config=opentelemetry,resource=service.name=test_triton \ +set -e +kill $COLLECTOR_PID +wait $COLLECTOR_PID +kill $SERVER_PID +wait $SERVER_PID +set +e + +# Test that only traces with OTel Context are collected after count goes to 0 +SERVER_ARGS="--trace-config=level=TIMESTAMPS --trace-config=rate=5\ + --trace-config=count=1 --trace-config=mode=opentelemetry \ + --trace-config=opentelemetry,url=localhost:$OTLP_PORT/v1/traces \ --model-repository=$MODELSDIR" -SERVER_LOG="./inference_server_otel_ostream_exporter.log" +SERVER_LOG="./inference_server_otel_WAR.log" run_server if [ "$SERVER_PID" == "0" ]; then @@ -768,41 +839,62 @@ if [ "$SERVER_PID" == "0" ]; then exit 1 fi -set +e -# Preparing traces for unittest. -# Note: running this separately, so that I could extract spans with `grep` -# from server log later. -python -c 'import opentelemetry_unittest; \ - opentelemetry_unittest.prepare_traces()' >>$CLIENT_LOG 2>&1 -sleep 5 +rm collected_traces.json +$OTEL_COLLECTOR --config ./trace-config.yaml >> $OTEL_COLLECTOR_LOG 2>&1 & COLLECTOR_PID=$! -set -e +get_trace_setting "bls_simple" +assert_curl_success "Failed to obtain trace settings for 'simple' model" -kill $SERVER_PID -wait $SERVER_PID +if [ `grep -c "\"trace_level\":\[\"TIMESTAMPS\"\]" ./curl.out` != "1" ]; then + RET=1 +fi +if [ `grep -c "\"trace_rate\":\"5\"" ./curl.out` != "1" ]; then + RET=1 +fi +if [ `grep -c "\"trace_count\":\"1\"" ./curl.out` != "1" ]; then + RET=1 +fi set +e +# Send bls requests to make sure bls_simple model is NOT traced +for p in {1..20}; do + python -c 'import opentelemetry_unittest; \ + opentelemetry_unittest.send_bls_request(model_name="ensemble_add_sub_int32_int32_int32")' >> client_update.log 2>&1 +done -grep -z -o -P '({\n(?s).*}\n)' $SERVER_LOG >> trace_collector.log +sleep 20 -if ! [ -s trace_collector.log ] ; then - echo -e "\n***\n*** $SERVER_LOG did not contain any OpenTelemetry spans.\n***" +if ! [[ -s collected_traces.json && `grep -c "\"name\":\"InferRequest\"" ./collected_traces.json` == 1 && `grep -c "\"parentSpanId\":\"\"" ./collected_traces.json` == 1 ]] ; then + echo -e "\n***\n*** collected_traces.json should contain only 1 trace.\n***" + cat collected_traces.json exit 1 fi -# Unittest will not start until expected number of spans is collected. 
-python $OPENTELEMETRY_TEST >>$OPENTELEMETRY_LOG 2>&1 -if [ $? -ne 0 ]; then - cat $OPENTELEMETRY_LOG - RET=1 -else - check_test_results $TEST_RESULT_FILE $EXPECTED_NUM_TESTS - if [ $? -ne 0 ]; then - cat $OPENTELEMETRY_LOG - echo -e "\n***\n*** Test Result Verification Failed\n***" - RET=1 - fi +# Send 4 bls request with OTel context and 4 without to make sure it is traced +for p in {1..10}; do + python -c 'import opentelemetry_unittest; \ + opentelemetry_unittest.send_bls_request(model_name="ensemble_add_sub_int32_int32_int32", \ + headers={"traceparent": "00-0af7651916cd43dd8448eb211c12666c-b7ad6b7169242424-01"} \ + )' >> client_update.log 2>&1 + + python -c 'import opentelemetry_unittest; \ + opentelemetry_unittest.send_bls_request(model_name="ensemble_add_sub_int32_int32_int32" \ + )' >> client_update.log 2>&1 + + sleep 10 +done + +if ! [[ -s collected_traces.json && `grep -c "\"parentSpanId\":\"\"" ./collected_traces.json` == 1 && `grep -c "\"parentSpanId\":\"b7ad6b7169242424\"" ./collected_traces.json` == 10 ]] ; then + echo -e "\n***\n*** collected_traces.json should contain 11 OTel trace, but it is not. \n***" + exit 1 fi +set -e +kill $COLLECTOR_PID +wait $COLLECTOR_PID +kill $SERVER_PID +wait $SERVER_PID +set +e + exit $RET diff --git a/qa/L0_trace/trace-config.yaml b/qa/L0_trace/trace-config.yaml index f8fe2424c0..2948058adf 100644 --- a/qa/L0_trace/trace-config.yaml +++ b/qa/L0_trace/trace-config.yaml @@ -34,12 +34,18 @@ receivers: http: endpoint: 0.0.0.0:10000 +processors: + batch: + send_batch_size: 10 + timeout: 10s + exporters: - logging: - verbosity: detailed + file: + path: ./collected_traces.json service: pipelines: traces: receivers: [otlp] - exporters: [logging] + processors: [batch] + exporters: [file] diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 85cd1434fc..3ffe375477 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -507,9 +507,9 @@ if(${TRITON_ENABLE_TRACING}) tracer.cc tracer.h ) - if (NOT WIN32) - target_compile_features(tracing-library PRIVATE cxx_std_${TRITON_MIN_CXX_STANDARD}) + target_compile_features(tracing-library PRIVATE cxx_std_${TRITON_MIN_CXX_STANDARD}) + if (NOT WIN32) target_include_directories( tracing-library PRIVATE ${OPENTELEMETRY_CPP_INCLUDE_DIRS} diff --git a/src/grpc/infer_handler.cc b/src/grpc/infer_handler.cc index 30d93fa4f9..e179f0f34c 100644 --- a/src/grpc/infer_handler.cc +++ b/src/grpc/infer_handler.cc @@ -916,8 +916,10 @@ ModelInferHandler::Execute(InferHandler::State* state) if (err == nullptr) { TRITONSERVER_InferenceTrace* triton_trace = nullptr; #ifdef TRITON_ENABLE_TRACING - state->trace_ = - std::move(trace_manager_->SampleTrace(request.model_name())); + GrpcServerCarrier carrier(state->context_->ctx_.get()); + auto start_options = + trace_manager_->GetTraceStartOptions(carrier, request.model_name()); + state->trace_ = std::move(trace_manager_->SampleTrace(start_options)); if (state->trace_ != nullptr) { triton_trace = state->trace_->trace_; } diff --git a/src/grpc/infer_handler.h b/src/grpc/infer_handler.h index 36783e5912..1c96b0e1fe 100644 --- a/src/grpc/infer_handler.h +++ b/src/grpc/infer_handler.h @@ -1433,4 +1433,33 @@ class ModelInferHandler grpc_compression_level compression_level_; }; +#if !defined(_WIN32) && defined(TRITON_ENABLE_TRACING) +class GrpcServerCarrier : public otel_cntxt::propagation::TextMapCarrier { + public: + GrpcServerCarrier(::grpc::ServerContext* context) : context_(context) {} + GrpcServerCarrier() = default; + virtual opentelemetry::nostd::string_view Get( + 
opentelemetry::nostd::string_view key) const noexcept override + { + auto it = context_->client_metadata().find({key.data(), key.size()}); + if (it != context_->client_metadata().end()) { + return it->second.data(); + } + return ""; + } + + // Not required on server side + virtual void Set( + opentelemetry::nostd::string_view key, + opentelemetry::nostd::string_view value) noexcept override + { + return; + } + + ::grpc::ServerContext* context_; +}; +#else +using GrpcServerCarrier = void*; +#endif // TRITON_ENABLE_TRACING + }}} // namespace triton::server::grpc diff --git a/src/grpc/stream_infer_handler.cc b/src/grpc/stream_infer_handler.cc index 9c162ad644..306925c570 100644 --- a/src/grpc/stream_infer_handler.cc +++ b/src/grpc/stream_infer_handler.cc @@ -309,8 +309,10 @@ ModelStreamInferHandler::Process(InferHandler::State* state, bool rpc_ok) if (err == nullptr) { TRITONSERVER_InferenceTrace* triton_trace = nullptr; #ifdef TRITON_ENABLE_TRACING - state->trace_ = - std::move(trace_manager_->SampleTrace(request.model_name())); + GrpcServerCarrier carrier(state->context_->ctx_.get()); + auto start_options = + trace_manager_->GetTraceStartOptions(carrier, request.model_name()); + state->trace_ = std::move(trace_manager_->SampleTrace(start_options)); if (state->trace_ != nullptr) { triton_trace = state->trace_->trace_; } diff --git a/src/http_server.cc b/src/http_server.cc index 647c6d83de..d1bd9ce641 100644 --- a/src/http_server.cc +++ b/src/http_server.cc @@ -3055,11 +3055,13 @@ HTTPAPIServer::StartTrace( TRITONSERVER_InferenceTrace** triton_trace) { #ifdef TRITON_ENABLE_TRACING + HttpTextMapCarrier carrier(req->headers_in); + auto start_options = + trace_manager_->GetTraceStartOptions(carrier, model_name); std::shared_ptr trace; - trace = std::move(trace_manager_->SampleTrace(model_name)); + trace = std::move(trace_manager_->SampleTrace(start_options)); if (trace != nullptr) { *triton_trace = trace->trace_; - // Timestamps from evhtp are capture in 'req'. We record here // since this is the first place where we have access to trace // manager. diff --git a/src/http_server.h b/src/http_server.h index 10a9ed6388..9c0643db91 100644 --- a/src/http_server.h +++ b/src/http_server.h @@ -142,6 +142,36 @@ class HTTPMetricsServer : public HTTPServer { }; #endif // TRITON_ENABLE_METRICS +#if !defined(_WIN32) && defined(TRITON_ENABLE_TRACING) +class HttpTextMapCarrier : public otel_cntxt::propagation::TextMapCarrier { + public: + HttpTextMapCarrier(evhtp_kvs_t* headers) : headers_(headers) {} + HttpTextMapCarrier() = default; + virtual opentelemetry::nostd::string_view Get( + opentelemetry::nostd::string_view key) const noexcept override + { + std::string key_to_compare = key.data(); + auto it = evhtp_kv_find(headers_, key_to_compare.c_str()); + if (it != NULL) { + return opentelemetry::nostd::string_view(it); + } + return ""; + } + // Not required on server side + virtual void Set( + opentelemetry::nostd::string_view key, + opentelemetry::nostd::string_view value) noexcept override + { + return; + } + + evhtp_kvs_t* headers_; +}; +#else +using HttpTextMapCarrier = void*; +#endif + + // HTTP API server that implements KFServing community standard inference // protocols and extensions used by Triton. 
class HTTPAPIServer : public HTTPServer { diff --git a/src/sagemaker_server.cc b/src/sagemaker_server.cc index 5d6ca80d9e..28c8b688d3 100644 --- a/src/sagemaker_server.cc +++ b/src/sagemaker_server.cc @@ -459,21 +459,8 @@ SagemakerAPIServer::SageMakerMMEHandleInfer( // If tracing is enabled see if this request should be traced. TRITONSERVER_InferenceTrace* triton_trace = nullptr; -#ifdef TRITON_ENABLE_TRACING - std::shared_ptr trace; - if (err == nullptr) { - trace = std::move(trace_manager_->SampleTrace(model_name)); - if (trace != nullptr) { - triton_trace = trace->trace_; - - // Timestamps from evhtp are capture in 'req'. We record here - // since this is the first place where we have access to trace - // manager. - trace->CaptureTimestamp("HTTP_RECV_START", req->recv_start_ns); - trace->CaptureTimestamp("HTTP_RECV_END", req->recv_end_ns); - } - } -#endif // TRITON_ENABLE_TRACING + std::shared_ptr trace = + StartTrace(req, model_name, &triton_trace); // Create the inference request object which provides all information needed // for an inference. diff --git a/src/tracer.cc b/src/tracer.cc index c64d10ee10..8c633ee3ab 100644 --- a/src/tracer.cc +++ b/src/tracer.cc @@ -36,7 +36,6 @@ #include #endif // TRITON_ENABLE_GPU #ifndef _WIN32 -#include "opentelemetry/exporters/ostream/span_exporter_factory.h" #include "opentelemetry/exporters/otlp/otlp_http_exporter_factory.h" #include "opentelemetry/sdk/resource/semantic_conventions.h" namespace otlp = opentelemetry::exporter::otlp; @@ -289,19 +288,69 @@ TraceManager::GetTraceSetting( *filepath = trace_setting->file_->FileName(); } -std::shared_ptr -TraceManager::SampleTrace(const std::string& model_name) +void +TraceManager::GetTraceSetting( + const std::string& model_name, std::shared_ptr& trace_setting) { - std::shared_ptr trace_setting; - { - std::lock_guard r_lk(r_mu_); - auto m_it = model_settings_.find(model_name); - trace_setting = - (m_it == model_settings_.end()) ? global_setting_ : m_it->second; + std::lock_guard r_lk(r_mu_); + auto m_it = model_settings_.find(model_name); + trace_setting = + (m_it == model_settings_.end()) ? 
global_setting_ : m_it->second; +} + +TraceManager::TraceStartOptions +TraceManager::GetTraceStartOptions( + AbstractCarrier& carrier, const std::string& model_name) +{ + TraceManager::TraceStartOptions start_options; + GetTraceSetting(model_name, start_options.trace_setting); + if (start_options.trace_setting->mode_ == TRACE_MODE_OPENTELEMETRY) { +#ifndef _WIN32 + auto prop = + otel_cntxt::propagation::GlobalTextMapPropagator::GetGlobalPropagator(); + auto ctxt = otel_cntxt::Context(); + ctxt = prop->Extract(carrier, ctxt); + otel_trace_api::SpanContext span_context = + otel_trace_api::GetSpan(ctxt)->GetContext(); + if (span_context.IsValid()) { + start_options.propagated_context = ctxt; + start_options.force_sample = true; + } +#else + LOG_ERROR << "Unsupported trace mode: " + << TraceManager::InferenceTraceModeString( + start_options.trace_setting->mode_); +#endif // _WIN32 } - std::shared_ptr ts = trace_setting->SampleTrace(); + return start_options; +} + + +std::shared_ptr +TraceManager::SampleTrace(const TraceStartOptions& start_options) +{ + std::shared_ptr ts = + start_options.trace_setting->SampleTrace(start_options.force_sample); if (ts != nullptr) { - ts->setting_ = trace_setting; + ts->setting_ = start_options.trace_setting; + if (ts->setting_->mode_ == TRACE_MODE_OPENTELEMETRY) { +#ifndef _WIN32 + auto steady_timestamp_ns = + std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()) + .count(); + ts->otel_context_ = start_options.propagated_context; + opentelemetry::nostd::shared_ptr root_span; + root_span = ts->StartSpan( + "InferRequest", steady_timestamp_ns, otel_trace_api::kSpanKey); + // Storing "InferRequest" span as a root span + // to keep it alive for the duration of the request. + ts->otel_context_ = ts->otel_context_.SetValue(kRootSpan, root_span); +#else + LOG_ERROR << "Unsupported trace mode: " + << TraceManager::InferenceTraceModeString(ts->setting_->mode_); +#endif + } } return ts; } @@ -359,11 +408,11 @@ TraceManager::InitTracer(const triton::server::TraceConfigMap& config_map) { switch (global_setting_->mode_) { case TRACE_MODE_OPENTELEMETRY: { -#if !defined(_WIN32) && defined(TRITON_ENABLE_TRACING) +#ifndef _WIN32 otlp::OtlpHttpExporterOptions opts; otel_resource::ResourceAttributes attributes = {}; attributes[otel_resource::SemanticConventions::kServiceName] = - "triton-inference-server"; + std::string("triton-inference-server"); auto mode_key = std::to_string(TRACE_MODE_OPENTELEMETRY); auto otel_options_it = config_map.find(mode_key); if (otel_options_it != config_map.end()) { @@ -381,12 +430,6 @@ TraceManager::InitTracer(const triton::server::TraceConfigMap& config_map) } } auto exporter = otlp::OtlpHttpExporterFactory::Create(opts); - auto test_exporter = triton::server::GetEnvironmentVariableOrDefault( - "TRITON_OPENTELEMETRY_TEST", "false"); - if (test_exporter != "false") { - exporter = opentelemetry::exporter::trace::OStreamSpanExporterFactory:: - Create(); - } auto processor = otel_trace_sdk::SimpleSpanProcessorFactory::Create( std::move(exporter)); auto resource = otel_resource::Resource::Create(attributes); @@ -395,6 +438,10 @@ TraceManager::InitTracer(const triton::server::TraceConfigMap& config_map) std::move(processor), resource); otel_trace_api::Provider::SetTracerProvider(provider); + otel_cntxt::propagation::GlobalTextMapPropagator::SetGlobalPropagator( + opentelemetry::nostd::shared_ptr< + otel_cntxt::propagation::TextMapPropagator>( + new otel_trace_api::propagation::HttpTraceContext())); break; #else LOG_ERROR << 
"Unsupported trace mode: " @@ -413,7 +460,7 @@ TraceManager::CleanupTracer() { switch (global_setting_->mode_) { case TRACE_MODE_OPENTELEMETRY: { -#if !defined(_WIN32) && defined(TRITON_ENABLE_TRACING) +#ifndef _WIN32 std::shared_ptr none; otel_trace_api::Provider::SetTracerProvider(none); break; @@ -1022,21 +1069,36 @@ TraceManager::TraceFile::SaveTraces( } std::shared_ptr -TraceManager::TraceSetting::SampleTrace() +TraceManager::TraceSetting::SampleTrace(bool force_sample) { - bool create_trace = false; + bool count_rate_hit = false; { std::lock_guard lk(mu_); - if (!Valid()) { - return nullptr; - } - create_trace = (((++sample_) % rate_) == 0); - if (create_trace && (count_ > 0)) { - --count_; - ++created_; + // [FIXME: DLIS-6033] + // A current WAR for initiating trace based on propagated context only + // Currently this is implemented through setting trace rate as 0 + if (rate_ != 0) { + // If `count_` hits 0, `Valid()` returns false for this and all + // following requests (unless `count_` is updated by a user). + // At this point we only trace requests for which + // `force_sample` is true. + if (!Valid() && !force_sample) { + return nullptr; + } + // `sample_` counts all requests, coming to server. + count_rate_hit = (((++sample_) % rate_) == 0); + if (count_rate_hit && (count_ > 0)) { + --count_; + ++created_; + } else if (count_rate_hit && (count_ == 0)) { + // This condition is reached, when `force_sample` is true, + // `count_rate_hit` is true, but `count_` is 0. Due to the + // latter, we explicitly set `count_rate_hit` to false. + count_rate_hit = false; + } } } - if (create_trace) { + if (count_rate_hit || force_sample) { std::shared_ptr lts(new Trace()); // Split 'Trace' management to frontend and Triton trace separately // to avoid dependency between frontend request and Triton trace's @@ -1056,22 +1118,6 @@ TraceManager::TraceSetting::SampleTrace() LOG_TRITONSERVER_ERROR( TRITONSERVER_InferenceTraceId(trace, <s->trace_id_), "getting trace id"); - if (mode_ == TRACE_MODE_OPENTELEMETRY) { -#ifndef _WIN32 - auto steady_timestamp_ns = - std::chrono::duration_cast( - std::chrono::steady_clock::now().time_since_epoch()) - .count(); - auto root_span = lts->StartSpan("InferRequest", steady_timestamp_ns); - // Initializing OTel context and storing "InferRequest" span as a root - // span to keep it alive for the duration of the request. 
- lts->otel_context_ = - opentelemetry::context::Context({kRootSpan, root_span}); -#else - LOG_ERROR << "Unsupported trace mode: " - << TraceManager::InferenceTraceModeString(mode_); -#endif - } return lts; } return nullptr; diff --git a/src/tracer.h b/src/tracer.h index baba2c8893..2af7f28956 100644 --- a/src/tracer.h +++ b/src/tracer.h @@ -36,23 +36,30 @@ #include #if !defined(_WIN32) && defined(TRITON_ENABLE_TRACING) +#include "opentelemetry/context/propagation/global_propagator.h" #include "opentelemetry/nostd/shared_ptr.h" #include "opentelemetry/sdk/resource/resource.h" #include "opentelemetry/sdk/trace/processor.h" #include "opentelemetry/sdk/trace/simple_processor_factory.h" #include "opentelemetry/sdk/trace/tracer_provider_factory.h" #include "opentelemetry/trace/context.h" +#include "opentelemetry/trace/propagation/http_trace_context.h" #include "opentelemetry/trace/provider.h" namespace otel_trace_sdk = opentelemetry::sdk::trace; namespace otel_trace_api = opentelemetry::trace; +namespace otel_cntxt = opentelemetry::context; #endif - #include "triton/core/tritonserver.h" namespace triton { namespace server { using TraceConfig = std::vector>; using TraceConfigMap = std::unordered_map; +#if !defined(_WIN32) && defined(TRITON_ENABLE_TRACING) +using AbstractCarrier = otel_cntxt::propagation::TextMapCarrier; +#else +using AbstractCarrier = void*; +#endif // Common OTel span keys to store in OTel context // with the corresponding trace id. @@ -124,9 +131,24 @@ class TraceManager { ~TraceManager() { CleanupTracer(); } + /// Options required at Trace initialization + struct TraceStartOptions { +#if !defined(_WIN32) && defined(TRITON_ENABLE_TRACING) + otel_cntxt::Context propagated_context{otel_cntxt::Context{}}; +#else + void* propagated_context{nullptr}; +#endif + std::shared_ptr trace_setting{nullptr}; + bool force_sample{false}; + }; + + // Returns TraceStartOptions for specified model + TraceStartOptions GetTraceStartOptions( + AbstractCarrier& carriers, const std::string& model_name); + // Return a trace that should be used to collected trace activities // for an inference request. Return nullptr if no tracing should occur. - std::shared_ptr SampleTrace(const std::string& model_name); + std::shared_ptr SampleTrace(const TraceStartOptions& start_options); // Update global setting if 'model_name' is empty, otherwise, model setting is // updated. @@ -138,6 +160,11 @@ class TraceManager { uint32_t* rate, int32_t* count, uint32_t* log_frequency, std::string* filepath); + // Sets provided TraceSetting with correct trace settings for provided model. + void GetTraceSetting( + const std::string& model_name, + std::shared_ptr& trace_setting); + // Return the current timestamp. static uint64_t CaptureTimestamp() { @@ -397,7 +424,11 @@ class TraceManager { const std::unordered_map>& streams); - std::shared_ptr SampleTrace(); + // Pass `force_sample` = true, when trace needs to be initiated + // no matter what `rate` and `count` is. + // For example, in OpenTelemetry tracing mode, we always initiate tracing + // when OpenTelemetry context was propagated from client. + std::shared_ptr SampleTrace(bool force_sample = false); const TRITONSERVER_InferenceTraceLevel level_; const uint32_t rate_;
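For reference, a minimal client-side sketch of the context-propagation flow the new tests exercise. It assumes a server started with the OpenTelemetry trace-config flags used in qa/L0_trace/test.sh and reuses the synthetic traceparent value from setUp(); the gRPC and streaming clients accept the same dictionary through their headers argument.

    import numpy as np
    import tritonclient.http as httpclient

    # W3C trace context header: 00-<trace-id>-<parent-span-id>-<trace-flags>.
    # The server extracts it with the global HttpTraceContext propagator and
    # parents the "InferRequest" root span under span id b7ad6b7169242424.
    headers = {
        "traceparent": "00-0af7651916cd43dd8448eb211c12666c-b7ad6b7169242424-01"
    }

    with httpclient.InferenceServerClient("localhost:8000") as client:
        dim = 16
        data = np.expand_dims(np.arange(dim, dtype=np.int32), axis=0)
        inputs = [
            httpclient.InferInput("INPUT0", [1, dim], "INT32"),
            httpclient.InferInput("INPUT1", [1, dim], "INT32"),
        ]
        inputs[0].set_data_from_numpy(data)
        inputs[1].set_data_from_numpy(data)
        # Traced and linked to the client's trace because of the header above.
        client.infer("simple", inputs, headers=headers)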
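The interaction between rate, count and the new force_sample flag in TraceSetting::SampleTrace() can be summarized with a small Python model. This is only an illustration of the decision logic in src/tracer.cc (it ignores the other Valid() checks, such as trace level), not server code.

    from types import SimpleNamespace

    def sample_decision(setting, force_sample):
        """One sampling decision; force_sample is True when a valid OTel
        context was propagated from the client (see GetTraceStartOptions)."""
        count_rate_hit = False
        if setting.rate != 0:
            # Once count reaches 0, only propagated-context requests are traced.
            if setting.count == 0 and not force_sample:
                return False
            setting.sample += 1
            count_rate_hit = (setting.sample % setting.rate) == 0
            if count_rate_hit and setting.count > 0:
                setting.count -= 1
            elif count_rate_hit and setting.count == 0:
                # Rate hit, but the count budget is already spent.
                count_rate_hit = False
        # rate == 0 is the DLIS-6033 workaround: trace only requests that
        # carry a propagated OTel context.
        return count_rate_hit or force_sample

    # rate=5, count=1, as in the last test.sh scenario: the 5th request is the
    # only rate-based trace; after that only requests with a traceparent header
    # (force_sample=True) produce spans.
    setting = SimpleNamespace(rate=5, count=1, sample=0)
    print([sample_decision(setting, force_sample=False) for _ in range(10)])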