Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make propagators conform to spec #488

Merged
merged 5 commits into from
May 31, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
([#504](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/504))
- `opentelemetry-instrumentation-asgi` Fix instrumentation default span name.
([#418](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/418))
- Propagators use the root context as default for `extract` and do not modify
the context if extracting from carrier does not work.
([#488](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/488))

### Added
- `opentelemetry-instrumentation-botocore` now supports
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ def extract(
context: typing.Optional[Context] = None,
getter: Getter = default_getter,
) -> Context:
if context is None:
context = Context()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this PR description you mention this, @mariojonke:

Propagators should return the passed in context (or the current one if none was given) if nothing could be extracted.

Should we use the current context instead of instantiating a new one?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We decided in the SIG meeting to use the root context as default because the current context would continue a possible active trace instead of starting a new one if extracting from the carrier fails (see also open-telemetry/opentelemetry-python#1765 (comment))

See also PR in the core repo: open-telemetry/opentelemetry-python#1811

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, good 👍


trace_id = extract_first_element(
getter.get(carrier, self.TRACE_ID_KEY)
)
Expand All @@ -64,7 +67,7 @@ def extract(
trace_flags = trace.TraceFlags(trace.TraceFlags.SAMPLED)

if trace_id is None or span_id is None:
return set_span_in_context(trace.INVALID_SPAN, context)
return context

trace_state = []
if origin is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from unittest.mock import Mock, patch

from opentelemetry import trace as trace_api
from opentelemetry.context import Context
from opentelemetry.exporter.datadog import constants, propagator
from opentelemetry.sdk import trace
from opentelemetry.sdk.trace.id_generator import RandomIdGenerator
Expand All @@ -36,42 +37,58 @@ def setUpClass(cls):
)
cls.serialized_origin = "origin-service"

def test_malformed_headers(self):
def test_extract_malformed_headers_to_explicit_ctx(self):
"""Test with no Datadog headers"""
orig_ctx = Context({"k1": "v1"})
malformed_trace_id_key = FORMAT.TRACE_ID_KEY + "-x"
malformed_parent_id_key = FORMAT.PARENT_ID_KEY + "-x"
context = get_current_span(
FORMAT.extract(
{
malformed_trace_id_key: self.serialized_trace_id,
malformed_parent_id_key: self.serialized_parent_id,
},
)
).get_span_context()
context = FORMAT.extract(
{
malformed_trace_id_key: self.serialized_trace_id,
malformed_parent_id_key: self.serialized_parent_id,
},
orig_ctx,
)
self.assertDictEqual(orig_ctx, context)

self.assertNotEqual(context.trace_id, int(self.serialized_trace_id))
self.assertNotEqual(context.span_id, int(self.serialized_parent_id))
self.assertFalse(context.is_remote)
def test_extract_malformed_headers_to_implicit_ctx(self):
malformed_trace_id_key = FORMAT.TRACE_ID_KEY + "-x"
malformed_parent_id_key = FORMAT.PARENT_ID_KEY + "-x"
context = FORMAT.extract(
{
malformed_trace_id_key: self.serialized_trace_id,
malformed_parent_id_key: self.serialized_parent_id,
}
)
self.assertDictEqual(Context(), context)

def test_missing_trace_id(self):
def test_extract_missing_trace_id_to_explicit_ctx(self):
"""If a trace id is missing, populate an invalid trace id."""
carrier = {
FORMAT.PARENT_ID_KEY: self.serialized_parent_id,
}
orig_ctx = Context({"k1": "v1"})
carrier = {FORMAT.PARENT_ID_KEY: self.serialized_parent_id}

ctx = FORMAT.extract(carrier, orig_ctx)
self.assertDictEqual(orig_ctx, ctx)

def test_extract_missing_trace_id_to_implicit_ctx(self):
carrier = {FORMAT.PARENT_ID_KEY: self.serialized_parent_id}

ctx = FORMAT.extract(carrier)
span_context = get_current_span(ctx).get_span_context()
self.assertEqual(span_context.trace_id, trace_api.INVALID_TRACE_ID)
self.assertDictEqual(Context(), ctx)

def test_missing_parent_id(self):
def test_extract_missing_parent_id_to_explicit_ctx(self):
"""If a parent id is missing, populate an invalid trace id."""
carrier = {
FORMAT.TRACE_ID_KEY: self.serialized_trace_id,
}
orig_ctx = Context({"k1": "v1"})
carrier = {FORMAT.TRACE_ID_KEY: self.serialized_trace_id}

ctx = FORMAT.extract(carrier, orig_ctx)
self.assertDictEqual(orig_ctx, ctx)

def test_extract_missing_parent_id_to_implicit_ctx(self):
carrier = {FORMAT.TRACE_ID_KEY: self.serialized_trace_id}

ctx = FORMAT.extract(carrier)
span_context = get_current_span(ctx).get_span_context()
self.assertEqual(span_context.span_id, trace_api.INVALID_SPAN_ID)
self.assertDictEqual(Context(), ctx)

def test_context_propagation(self):
"""Test the propagation of Datadog headers."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,19 @@ def extract(
context: Optional[Context] = None,
getter: Getter = default_getter,
) -> Context:
if context is None:
context = Context()

traceid = _extract_first_element(
getter.get(carrier, OT_TRACE_ID_HEADER), INVALID_TRACE_ID
traceid = _extract_identifier(
getter.get(carrier, OT_TRACE_ID_HEADER),
_valid_extract_traceid,
INVALID_TRACE_ID,
)

spanid = _extract_first_element(
getter.get(carrier, OT_SPAN_ID_HEADER), INVALID_SPAN_ID
spanid = _extract_identifier(
getter.get(carrier, OT_SPAN_ID_HEADER),
_valid_extract_spanid,
INVALID_SPAN_ID,
)

sampled = _extract_first_element(
Expand All @@ -73,17 +79,12 @@ def extract(
else:
traceflags = TraceFlags.DEFAULT

if (
traceid != INVALID_TRACE_ID
and _valid_extract_traceid.fullmatch(traceid) is not None
and spanid != INVALID_SPAN_ID
and _valid_extract_spanid.fullmatch(spanid) is not None
):
if traceid != INVALID_TRACE_ID and spanid != INVALID_SPAN_ID:
context = set_span_in_context(
NonRecordingSpan(
SpanContext(
trace_id=int(traceid, 16),
span_id=int(spanid, 16),
trace_id=traceid,
span_id=spanid,
is_remote=True,
trace_flags=TraceFlags(traceflags),
)
Expand Down Expand Up @@ -172,3 +173,16 @@ def _extract_first_element(
if items is None:
return default
return next(iter(items), None)


def _extract_identifier(
items: Iterable[CarrierT], validator_pattern, default: int
) -> int:
header = _extract_first_element(items)
if header is None or validator_pattern.fullmatch(header) is None:
return default

try:
return int(header, 16)
except ValueError:
return default
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from unittest import TestCase

from opentelemetry.baggage import get_all, set_baggage
from opentelemetry.context import Context
from opentelemetry.propagators.ot_trace import (
OT_BAGGAGE_PREFIX,
OT_SAMPLED_HEADER,
Expand All @@ -24,8 +25,6 @@
)
from opentelemetry.sdk.trace import _Span
from opentelemetry.trace import (
INVALID_SPAN_CONTEXT,
INVALID_SPAN_ID,
INVALID_TRACE_ID,
SpanContext,
TraceFlags,
Expand Down Expand Up @@ -275,65 +274,44 @@ def test_extract_trace_id_span_id_sampled_false(self):
get_current_span().get_span_context().trace_flags, TraceFlags
)

def test_extract_malformed_trace_id(self):
"""Test extraction with malformed trace_id"""

span_context = get_current_span(
self.ot_trace_propagator.extract(
{
OT_TRACE_ID_HEADER: "abc123!",
OT_SPAN_ID_HEADER: "e457b5a2e4d86bd1",
OT_SAMPLED_HEADER: "false",
},
)
).get_span_context()

self.assertEqual(span_context, INVALID_SPAN_CONTEXT)

def test_extract_malformed_span_id(self):
"""Test extraction with malformed span_id"""

span_context = get_current_span(
self.ot_trace_propagator.extract(
{
OT_TRACE_ID_HEADER: "64fe8b2a57d3eff7",
OT_SPAN_ID_HEADER: "abc123!",
OT_SAMPLED_HEADER: "false",
},
)
).get_span_context()

self.assertEqual(span_context, INVALID_SPAN_CONTEXT)

def test_extract_invalid_trace_id(self):
"""Test extraction with invalid trace_id"""

span_context = get_current_span(
self.ot_trace_propagator.extract(
{
OT_TRACE_ID_HEADER: INVALID_TRACE_ID,
OT_SPAN_ID_HEADER: "e457b5a2e4d86bd1",
OT_SAMPLED_HEADER: "false",
},
)
).get_span_context()

self.assertEqual(span_context, INVALID_SPAN_CONTEXT)

def test_extract_invalid_span_id(self):
"""Test extraction with invalid span_id"""

span_context = get_current_span(
self.ot_trace_propagator.extract(
{
OT_TRACE_ID_HEADER: "64fe8b2a57d3eff7",
OT_SPAN_ID_HEADER: INVALID_SPAN_ID,
OT_SAMPLED_HEADER: "false",
},
)
).get_span_context()

self.assertEqual(span_context, INVALID_SPAN_CONTEXT)
def test_extract_invalid_trace_header_to_explict_ctx(self):
invalid_headers = [
("abc123!", "e457b5a2e4d86bd1"), # malformed trace id
("64fe8b2a57d3eff7", "abc123!"), # malformed span id
("0" * 32, "e457b5a2e4d86bd1"), # invalid trace id
("64fe8b2a57d3eff7", "0" * 16), # invalid span id
]
for trace_id, span_id in invalid_headers:
with self.subTest(trace_id=trace_id, span_id=span_id):
orig_ctx = Context({"k1": "v1"})

ctx = self.ot_trace_propagator.extract(
{
OT_TRACE_ID_HEADER: trace_id,
OT_SPAN_ID_HEADER: span_id,
OT_SAMPLED_HEADER: "false",
},
orig_ctx,
)
self.assertDictEqual(orig_ctx, ctx)

def test_extract_invalid_trace_header_to_implicit_ctx(self):
invalid_headers = [
("abc123!", "e457b5a2e4d86bd1"), # malformed trace id
("64fe8b2a57d3eff7", "abc123!"), # malformed span id
("0" * 32, "e457b5a2e4d86bd1"), # invalid trace id
("64fe8b2a57d3eff7", "0" * 16), # invalid span id
]
for trace_id, span_id in invalid_headers:
with self.subTest(trace_id=trace_id, span_id=span_id):
ctx = self.ot_trace_propagator.extract(
{
OT_TRACE_ID_HEADER: trace_id,
OT_SPAN_ID_HEADER: span_id,
OT_SAMPLED_HEADER: "false",
}
)
self.assertDictEqual(Context(), ctx)

def test_extract_baggage(self):
"""Test baggage extraction"""
Expand All @@ -359,11 +337,13 @@ def test_extract_baggage(self):
self.assertEqual(baggage["abc"], "abc")
self.assertEqual(baggage["def"], "def")

def test_extract_empty(self):
"Test extraction when no headers are present"
def test_extract_empty_to_explicit_ctx(self):
"""Test extraction when no headers are present"""
orig_ctx = Context({"k1": "v1"})
ctx = self.ot_trace_propagator.extract({}, orig_ctx)

span_context = get_current_span(
self.ot_trace_propagator.extract({})
).get_span_context()
self.assertDictEqual(orig_ctx, ctx)

self.assertEqual(span_context, INVALID_SPAN_CONTEXT)
def test_extract_empty_to_implicit_ctx(self):
ctx = self.ot_trace_propagator.extract({})
self.assertDictEqual(Context(), ctx)
Original file line number Diff line number Diff line change
Expand Up @@ -106,19 +106,18 @@ def extract(
context: typing.Optional[Context] = None,
getter: Getter = default_getter,
) -> Context:
if context is None:
context = Context()

trace_header_list = getter.get(carrier, TRACE_HEADER_KEY)

if not trace_header_list or len(trace_header_list) != 1:
return trace.set_span_in_context(
trace.INVALID_SPAN, context=context
)
return context

trace_header = trace_header_list[0]

if not trace_header:
return trace.set_span_in_context(
trace.INVALID_SPAN, context=context
)
return context

try:
(
Expand All @@ -128,9 +127,7 @@ def extract(
) = AwsXRayFormat._extract_span_properties(trace_header)
except AwsParseTraceHeaderError as err:
_logger.debug(err.message)
return trace.set_span_in_context(
trace.INVALID_SPAN, context=context
)
return context

options = 0
if sampled:
Expand All @@ -148,9 +145,7 @@ def extract(
_logger.debug(
"Invalid Span Extracted. Insertting INVALID span into provided context."
)
return trace.set_span_in_context(
trace.INVALID_SPAN, context=context
)
return context

return trace.set_span_in_context(
trace.NonRecordingSpan(span_context), context=context
Expand Down
Loading