Skip to content

Commit

Permalink
pymongo instrumentation hooks (#793)
Browse files Browse the repository at this point in the history
* pymongo instrumentation hooks

* update PR number
  • Loading branch information
ItayGibel-helios authored Nov 9, 2021
1 parent 5993329 commit 760673f
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 8 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
([#781](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/781))
- `opentelemetry-instrumentation-aws-lambda` Add instrumentation for AWS Lambda Service - Implementation (Part 2/2)
([#777](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/777))
- `opentelemetry-instrumentation-pymongo` Add `request_hook`, `response_hook` and `failed_hook` callbacks passed as arguments to the instrument method
([#793](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/793))
- `opentelemetry-instrumentation-pymysql` Add support for PyMySQL 1.x series
([#792](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/792))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from pymongo import MongoClient
from opentelemetry.instrumentation.pymongo import PymongoInstrumentor
PymongoInstrumentor().instrument()
client = MongoClient()
db = client["MongoDB_Database"]
Expand All @@ -35,9 +34,47 @@
API
---
"""
The `instrument` method accepts the following keyword args:
tracer_provider (TracerProvider) - an optional tracer provider
request_hook (Callable) -
a function with extra user-defined logic to be performed before querying mongodb
this function signature is: def request_hook(span: Span, event: CommandStartedEvent) -> None
response_hook (Callable) -
a function with extra user-defined logic to be performed after the query returns with a successful response
this function signature is: def response_hook(span: Span, event: CommandSucceededEvent) -> None
failed_hook (Callable) -
a function with extra user-defined logic to be performed after the query returns with a failed response
this function signature is: def failed_hook(span: Span, event: CommandFailedEvent) -> None
for example:
.. code: python
from opentelemetry.instrumentation.pymongo import PymongoInstrumentor
from pymongo import MongoClient
def request_hook(span, event):
# request hook logic
from typing import Collection
def response_hook(span, event):
# response hook logic
def failed_hook(span, event):
# failed hook logic
# Instrument pymongo with hooks
PymongoInstrumentor().instrument(request_hook=request_hook, response_hooks=response_hook, failed_hook=failed_hook)
# This will create a span with pymongo specific attributes, including custom attributes added from the hooks
client = MongoClient()
db = client["MongoDB_Database"]
collection = db["MongoDB_Collection"]
collection.find_one()
"""
from logging import getLogger
from typing import Callable, Collection

from pymongo import monitoring

Expand All @@ -48,14 +85,34 @@
from opentelemetry.instrumentation.utils import _SUPPRESS_INSTRUMENTATION_KEY
from opentelemetry.semconv.trace import DbSystemValues, SpanAttributes
from opentelemetry.trace import SpanKind, get_tracer
from opentelemetry.trace.span import Span
from opentelemetry.trace.status import Status, StatusCode

_LOG = getLogger(__name__)

RequestHookT = Callable[[Span, monitoring.CommandStartedEvent], None]
ResponseHookT = Callable[[Span, monitoring.CommandSucceededEvent], None]
FailedHookT = Callable[[Span, monitoring.CommandFailedEvent], None]


def dummy_callback(span, event):
...


class CommandTracer(monitoring.CommandListener):
def __init__(self, tracer):
def __init__(
self,
tracer,
request_hook: RequestHookT = dummy_callback,
response_hook: ResponseHookT = dummy_callback,
failed_hook: FailedHookT = dummy_callback,
):
self._tracer = tracer
self._span_dict = {}
self.is_enabled = True
self.start_hook = request_hook
self.success_hook = response_hook
self.failed_hook = failed_hook

def started(self, event: monitoring.CommandStartedEvent):
""" Method to handle a pymongo CommandStartedEvent """
Expand Down Expand Up @@ -85,6 +142,10 @@ def started(self, event: monitoring.CommandStartedEvent):
span.set_attribute(
SpanAttributes.NET_PEER_PORT, event.connection_id[1]
)
try:
self.start_hook(span, event)
except Exception as hook_exception: # noqa pylint: disable=broad-except
_LOG.exception(hook_exception)

# Add Span to dictionary
self._span_dict[_get_span_dict_key(event)] = span
Expand All @@ -103,6 +164,11 @@ def succeeded(self, event: monitoring.CommandSucceededEvent):
span = self._pop_span(event)
if span is None:
return
if span.is_recording():
try:
self.success_hook(span, event)
except Exception as hook_exception: # noqa pylint: disable=broad-except
_LOG.exception(hook_exception)
span.end()

def failed(self, event: monitoring.CommandFailedEvent):
Expand All @@ -116,6 +182,10 @@ def failed(self, event: monitoring.CommandFailedEvent):
return
if span.is_recording():
span.set_status(Status(StatusCode.ERROR, event.failure))
try:
self.failed_hook(span, event)
except Exception as hook_exception: # noqa pylint: disable=broad-except
_LOG.exception(hook_exception)
span.end()

def _pop_span(self, event):
Expand Down Expand Up @@ -150,12 +220,20 @@ def _instrument(self, **kwargs):
"""

tracer_provider = kwargs.get("tracer_provider")
request_hook = kwargs.get("request_hook", dummy_callback)
response_hook = kwargs.get("response_hook", dummy_callback)
failed_hook = kwargs.get("failed_hook", dummy_callback)

# Create and register a CommandTracer only the first time
if self._commandtracer_instance is None:
tracer = get_tracer(__name__, __version__, tracer_provider)

self._commandtracer_instance = CommandTracer(tracer)
self._commandtracer_instance = CommandTracer(
tracer,
request_hook=request_hook,
response_hook=response_hook,
failed_hook=failed_hook,
)
monitoring.register(self._commandtracer_instance)

# If already created, just enable it
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ class TestPymongo(TestBase):
def setUp(self):
super().setUp()
self.tracer = self.tracer_provider.get_tracer(__name__)
self.start_callback = mock.MagicMock()
self.success_callback = mock.MagicMock()
self.failed_callback = mock.MagicMock()

def test_pymongo_instrumentor(self):
mock_register = mock.Mock()
Expand All @@ -44,7 +47,9 @@ def test_started(self):
command_attrs = {
"command_name": "find",
}
command_tracer = CommandTracer(self.tracer)
command_tracer = CommandTracer(
self.tracer, request_hook=self.start_callback
)
mock_event = MockEvent(
command_attrs, ("test.com", "1234"), "test_request_id"
)
Expand All @@ -66,17 +71,24 @@ def test_started(self):
span.attributes[SpanAttributes.NET_PEER_NAME], "test.com"
)
self.assertEqual(span.attributes[SpanAttributes.NET_PEER_PORT], "1234")
self.start_callback.assert_called_once_with(span, mock_event)

def test_succeeded(self):
mock_event = MockEvent({})
command_tracer = CommandTracer(self.tracer)
command_tracer = CommandTracer(
self.tracer,
request_hook=self.start_callback,
response_hook=self.success_callback,
)
command_tracer.started(event=mock_event)
command_tracer.succeeded(event=mock_event)
spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)
span = spans_list[0]
self.assertIs(span.status.status_code, trace_api.StatusCode.UNSET)
self.assertIsNotNone(span.end_time)
self.start_callback.assert_called_once()
self.success_callback.assert_called_once()

def test_not_recording(self):
mock_tracer = mock.Mock()
Expand Down Expand Up @@ -119,7 +131,11 @@ def test_suppression_key(self):

def test_failed(self):
mock_event = MockEvent({})
command_tracer = CommandTracer(self.tracer)
command_tracer = CommandTracer(
self.tracer,
request_hook=self.start_callback,
failed_hook=self.failed_callback,
)
command_tracer.started(event=mock_event)
command_tracer.failed(event=mock_event)

Expand All @@ -132,6 +148,8 @@ def test_failed(self):
)
self.assertEqual(span.status.description, "failure")
self.assertIsNotNone(span.end_time)
self.start_callback.assert_called_once()
self.failed_callback.assert_called_once()

def test_multiple_commands(self):
first_mock_event = MockEvent({}, ("firstUrl", "123"), "first")
Expand Down

0 comments on commit 760673f

Please sign in to comment.