Skip to content

Commit

Permalink
fix(opentelemetry-instrumentation-celery): detach context after task …
Browse files Browse the repository at this point in the history
…is run
  • Loading branch information
malcolmrebughini authored and ocelotl committed Aug 1, 2024
1 parent eef8b9d commit 9126473
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ def add(x, y):
from billiard.einfo import ExceptionInfo
from celery import signals # pylint: disable=no-name-in-module

from opentelemetry import context, trace
from opentelemetry import context as context_api
from opentelemetry import trace
from opentelemetry.instrumentation.celery import utils
from opentelemetry.instrumentation.celery.package import _instruments
from opentelemetry.instrumentation.celery.version import __version__
Expand Down Expand Up @@ -142,12 +143,9 @@ def _instrument(self, **kwargs):

signals.task_prerun.connect(self._trace_prerun, weak=False)
signals.task_postrun.connect(self._trace_postrun, weak=False)
signals.before_task_publish.connect(
self._trace_before_publish, weak=False
)
signals.after_task_publish.connect(
self._trace_after_publish, weak=False
)
signals.before_task_publish.connect(self._trace_before_publish, weak=False)
signals.before_task_publish.connect(self._trace_before_publish, weak=False)
signals.after_task_publish.connect(self._trace_after_publish, weak=False)
signals.task_failure.connect(self._trace_failure, weak=False)
signals.task_retry.connect(self._trace_retry, weak=False)

Expand All @@ -169,9 +167,7 @@ def _trace_prerun(self, *args, **kwargs):
self.update_task_duration_time(task_id)
request = task.request
tracectx = extract(request, getter=celery_getter) or None

if tracectx is not None:
context.attach(tracectx)
token = context_api.attach(tracectx)

logger.debug("prerun signal start task_id=%s", task_id)

Expand All @@ -182,7 +178,7 @@ def _trace_prerun(self, *args, **kwargs):

activation = trace.use_span(span, end_on_exit=True)
activation.__enter__() # pylint: disable=E1101
utils.attach_span(task, task_id, (span, activation))
utils.attach_context(task, task_id, span, activation, token)

def _trace_postrun(self, *args, **kwargs):
task = utils.retrieve_task(kwargs)
Expand All @@ -194,11 +190,14 @@ def _trace_postrun(self, *args, **kwargs):
logger.debug("postrun signal task_id=%s", task_id)

# retrieve and finish the Span
span, activation = utils.retrieve_span(task, task_id)
if span is None:
ctx = utils.retrieve_context(task, task_id)

if ctx is None:
logger.warning("no existing span found for task_id=%s", task_id)
return

span, activation, token = ctx

# request context tags
if span.is_recording():
span.set_attribute(_TASK_TAG_KEY, _TASK_RUN)
Expand All @@ -207,10 +206,11 @@ def _trace_postrun(self, *args, **kwargs):
span.set_attribute(_TASK_NAME_KEY, task.name)

activation.__exit__(None, None, None)
utils.detach_span(task, task_id)
utils.detach_context(task, task_id)
self.update_task_duration_time(task_id)
labels = {"task": task.name, "worker": task.request.hostname}
self._record_histograms(task_id, labels)
context_api.detach(token)

def _trace_before_publish(self, *args, **kwargs):
task = utils.retrieve_task_from_sender(kwargs)
Expand All @@ -227,9 +227,7 @@ def _trace_before_publish(self, *args, **kwargs):
else:
task_name = task.name
operation_name = f"{_TASK_APPLY_ASYNC}/{task_name}"
span = self._tracer.start_span(
operation_name, kind=trace.SpanKind.PRODUCER
)
span = self._tracer.start_span(operation_name, kind=trace.SpanKind.PRODUCER)

# apply some attributes here because most of the data is not available
if span.is_recording():
Expand All @@ -241,7 +239,7 @@ def _trace_before_publish(self, *args, **kwargs):
activation = trace.use_span(span, end_on_exit=True)
activation.__enter__() # pylint: disable=E1101

utils.attach_span(task, task_id, (span, activation), is_publish=True)
utils.attach_context(task, task_id, span, activation, None, is_publish=True)

headers = kwargs.get("headers")
if headers:
Expand All @@ -256,13 +254,16 @@ def _trace_after_publish(*args, **kwargs):
return

# retrieve and finish the Span
_, activation = utils.retrieve_span(task, task_id, is_publish=True)
if activation is None:
ctx = utils.retrieve_context(task, task_id, is_publish=True)

if ctx is None:
logger.warning("no existing span found for task_id=%s", task_id)
return

_, activation, _ = ctx

activation.__exit__(None, None, None) # pylint: disable=E1101
utils.detach_span(task, task_id, is_publish=True)
utils.detach_context(task, task_id, is_publish=True)

@staticmethod
def _trace_failure(*args, **kwargs):
Expand All @@ -272,9 +273,14 @@ def _trace_failure(*args, **kwargs):
if task is None or task_id is None:
return

# retrieve and pass exception info to activation
span, _ = utils.retrieve_span(task, task_id)
if span is None or not span.is_recording():
ctx = utils.retrieve_context(task, task_id)

if ctx is None:
return

span, _, _ = ctx

if not span.is_recording():
return

status_kwargs = {"status_code": StatusCode.ERROR}
Expand Down Expand Up @@ -314,8 +320,14 @@ def _trace_retry(*args, **kwargs):
if task is None or task_id is None or reason is None:
return

span, _ = utils.retrieve_span(task, task_id)
if span is None or not span.is_recording():
ctx = utils.retrieve_context(task, task_id)

if ctx is None:
return

span, _, _ = ctx

if not span.is_recording():
return

# Add retry reason metadata to span
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@
# limitations under the License.

import logging
from typing import ContextManager, Optional, Tuple

from celery import registry # pylint: disable=no-name-in-module
from celery.app.task import Task

from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.trace import Span

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -82,9 +85,7 @@ def set_attributes_from_context(span, context):
# Get also destination from this
routing_key = value.get("routing_key")
if routing_key is not None:
span.set_attribute(
SpanAttributes.MESSAGING_DESTINATION, routing_key
)
span.set_attribute(SpanAttributes.MESSAGING_DESTINATION, routing_key)
value = str(value)

elif key == "id":
Expand Down Expand Up @@ -114,6 +115,70 @@ def set_attributes_from_context(span, context):
span.set_attribute(attribute_name, value)


def attach_context(
    task: Optional[Task],
    task_id: str,
    span: Span,
    activation: ContextManager[Span],
    token: Optional[object],
    is_publish: bool = False,
) -> None:
    """Remember a span, its activation and its context token on a task.

    The ``(span, activation, token)`` triple is stored in a ``dict`` kept on
    the ``Task`` instance under the ``CTX_KEY`` attribute, so that data
    created in one Celery signal handler can be picked up in another.

    Entries are indexed by ``(task_id, is_publish)`` rather than by
    ``task_id`` alone so that publishing a task from within another task
    does not cause conflicts. This mostly happens when a task fails with a
    retry policy in place, or when a task is manually retried (e.g.
    ``task.retry()``): a task is then published with the same id as the one
    currently running. Keying on ``task_id`` only would let the publish
    entry overwrite the running task's ``celery.run`` span, which would be
    forgotten and never finished.

    NOTE: This cannot be tested well yet, because no Celery worker is run
    and ``task.apply_async()`` cannot be exercised.

    Args:
        task: Task to store the entry on; nothing is stored when ``None``.
        task_id: Id of the task the entry belongs to.
        span: Span to remember.
        activation: Context manager associated with ``span``.
        token: Context token stored alongside the span; may be ``None``.
        is_publish: Whether the entry belongs to the publish side.
    """
    if task is not None:
        registry_for_task = getattr(task, CTX_KEY, None)
        if registry_for_task is None:
            # First entry for this task object: create the backing dict.
            registry_for_task = {}
            setattr(task, CTX_KEY, registry_for_task)

        entry_key = (task_id, is_publish)
        registry_for_task[entry_key] = (span, activation, token)


def detach_context(task, task_id, is_publish=False) -> None:
    """Forget the span, activation and context token stored for a task.

    Removes the entry indexed by ``(task_id, is_publish)`` from the ``dict``
    kept on ``task`` under the ``CTX_KEY`` attribute. Tasks that carry no
    such ``dict``, or no entry for that key, are left untouched.
    """
    stored_entries = getattr(task, CTX_KEY, None)
    if stored_entries is not None:
        # Same (task_id, is_publish) key layout as used by `attach_context`;
        # the default of None makes a missing key a no-op.
        entry_key = (task_id, is_publish)
        stored_entries.pop(entry_key, None)


def retrieve_context(
    task, task_id, is_publish=False
) -> Optional[Tuple[Span, ContextManager[Span], Optional[object]]]:
    """Look up the span, activation and context token stored for a task.

    Reads the ``dict`` kept on ``task`` under the ``CTX_KEY`` attribute and
    returns the entry indexed by ``(task_id, is_publish)``.

    The return annotation uses ``typing.Tuple`` instead of the builtin
    ``tuple[...]``: annotations are evaluated when the function is defined,
    and subscripting the builtin raises ``TypeError`` on Python versions
    older than 3.9.

    Args:
        task: Task instance the entry was stored on.
        task_id: Id of the task the entry belongs to.
        is_publish: Whether to look up the publish-side entry.

    Returns:
        The stored ``(span, activation, token)`` tuple, or ``None`` when the
        task carries no context ``dict`` or has no entry for that key.
    """
    ctx_dict = getattr(task, CTX_KEY, None)
    if ctx_dict is None:
        return None

    # Same (task_id, is_publish) key layout as used when the entry is stored.
    return ctx_dict.get((task_id, is_publish))


def attach_span(task, task_id, span, is_publish=False):
"""Helper to propagate a `Span` for the given `Task` instance. This
function uses a `dict` that stores the Span using the
Expand Down

0 comments on commit 9126473

Please sign in to comment.