diff --git a/CHANGELOG.md b/CHANGELOG.md index 0c1311c7a6..891d0fa159 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `opentelemetry-instrumentation-asgi` now explicitly depends on asgiref as it uses the package instead of instrumenting it. ([#765](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/765)) +- `opentelemetry-instrumentation-pika` now propagates context to basic_consume callback + ([#766](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/766)) ## [1.6.2-0.25b2](https://github.com/open-telemetry/opentelemetry-python/releases/tag/v1.6.2-0.25b2) - 2021-10-19 diff --git a/instrumentation/opentelemetry-instrumentation-pika/src/opentelemetry/instrumentation/pika/pika_instrumentor.py b/instrumentation/opentelemetry-instrumentation-pika/src/opentelemetry/instrumentation/pika/pika_instrumentor.py index 05496f53dd..cc088c9ad0 100644 --- a/instrumentation/opentelemetry-instrumentation-pika/src/opentelemetry/instrumentation/pika/pika_instrumentor.py +++ b/instrumentation/opentelemetry-instrumentation-pika/src/opentelemetry/instrumentation/pika/pika_instrumentor.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. from logging import getLogger -from typing import Any, Callable, Collection, Dict, Optional +from typing import Any, Collection, Dict, Optional import wrapt from pika.adapters import BlockingConnection -from pika.channel import Channel +from pika.adapters.blocking_connection import BlockingChannel from opentelemetry import trace from opentelemetry.instrumentation.instrumentor import BaseInstrumentor @@ -35,18 +35,25 @@ class PikaInstrumentor(BaseInstrumentor): # type: ignore # pylint: disable=attribute-defined-outside-init @staticmethod - def _instrument_consumers( - consumers_dict: Dict[str, Callable[..., Any]], tracer: Tracer + def _instrument_blocking_channel_consumers( + channel: BlockingChannel, tracer: Tracer ) -> Any: - for key, callback in consumers_dict.items(): + for consumer_tag, consumer_info in channel._consumer_infos.items(): decorated_callback = utils._decorate_callback( - callback, tracer, key + consumer_info.on_message_callback, tracer, consumer_tag ) - setattr(decorated_callback, "_original_callback", callback) - consumers_dict[key] = decorated_callback + + setattr( + decorated_callback, + "_original_callback", + consumer_info.on_message_callback, + ) + consumer_info.on_message_callback = decorated_callback @staticmethod - def _instrument_basic_publish(channel: Channel, tracer: Tracer) -> None: + def _instrument_basic_publish( + channel: BlockingChannel, tracer: Tracer + ) -> None: original_function = getattr(channel, "basic_publish") decorated_function = utils._decorate_basic_publish( original_function, channel, tracer @@ -57,13 +64,13 @@ def _instrument_basic_publish(channel: Channel, tracer: Tracer) -> None: @staticmethod def _instrument_channel_functions( - channel: Channel, tracer: Tracer + channel: BlockingChannel, tracer: Tracer ) -> None: if hasattr(channel, "basic_publish"): PikaInstrumentor._instrument_basic_publish(channel, tracer) @staticmethod - def _uninstrument_channel_functions(channel: Channel) -> None: + def _uninstrument_channel_functions(channel: BlockingChannel) -> None: for function_name in _FUNCTIONS_TO_UNINSTRUMENT: if not hasattr(channel, function_name): continue @@ -73,8 +80,10 @@ def _uninstrument_channel_functions(channel: Channel) -> None: unwrap(channel, "basic_consume") @staticmethod + # Make sure that the spans are created inside hash them set as parent and not as brothers def instrument_channel( - channel: Channel, tracer_provider: Optional[TracerProvider] = None, + channel: BlockingChannel, + tracer_provider: Optional[TracerProvider] = None, ) -> None: if not hasattr(channel, "_is_instrumented_by_opentelemetry"): channel._is_instrumented_by_opentelemetry = False @@ -84,18 +93,14 @@ def instrument_channel( ) return tracer = trace.get_tracer(__name__, __version__, tracer_provider) - if not hasattr(channel, "_impl"): - _LOG.error("Could not find implementation for provided channel!") - return - if channel._impl._consumers: - PikaInstrumentor._instrument_consumers( - channel._impl._consumers, tracer - ) + PikaInstrumentor._instrument_blocking_channel_consumers( + channel, tracer + ) PikaInstrumentor._decorate_basic_consume(channel, tracer) PikaInstrumentor._instrument_channel_functions(channel, tracer) @staticmethod - def uninstrument_channel(channel: Channel) -> None: + def uninstrument_channel(channel: BlockingChannel) -> None: if ( not hasattr(channel, "_is_instrumented_by_opentelemetry") or not channel._is_instrumented_by_opentelemetry @@ -104,12 +109,12 @@ def uninstrument_channel(channel: Channel) -> None: "Attempting to uninstrument Pika channel while already uninstrumented!" ) return - if not hasattr(channel, "_impl"): - _LOG.error("Could not find implementation for provided channel!") - return - for key, callback in channel._impl._consumers.items(): - if hasattr(callback, "_original_callback"): - channel._impl._consumers[key] = callback._original_callback + + for consumers_tag, client_info in channel._consumer_infos.items(): + if hasattr(client_info.on_message_callback, "_original_callback"): + channel._consumer_infos[ + consumers_tag + ] = client_info.on_message_callback._original_callback PikaInstrumentor._uninstrument_channel_functions(channel) def _decorate_channel_function( @@ -123,28 +128,15 @@ def wrapper(wrapped, instance, args, kwargs): wrapt.wrap_function_wrapper(BlockingConnection, "channel", wrapper) @staticmethod - def _decorate_basic_consume(channel, tracer: Optional[Tracer]) -> None: + def _decorate_basic_consume( + channel: BlockingChannel, tracer: Optional[Tracer] + ) -> None: def wrapper(wrapped, instance, args, kwargs): - if not hasattr(channel, "_impl"): - _LOG.error( - "Could not find implementation for provided channel!" - ) - return wrapped(*args, **kwargs) - current_keys = set(channel._impl._consumers.keys()) return_value = wrapped(*args, **kwargs) - new_key_list = list( - set(channel._impl._consumers.keys()) - current_keys - ) - if not new_key_list: - _LOG.error("Could not find added callback") - return return_value - new_key = new_key_list[0] - callback = channel._impl._consumers[new_key] - decorated_callback = utils._decorate_callback( - callback, tracer, new_key + + PikaInstrumentor._instrument_blocking_channel_consumers( + channel, tracer ) - setattr(decorated_callback, "_original_callback", callback) - channel._impl._consumers[new_key] = decorated_callback return return_value wrapt.wrap_function_wrapper(channel, "basic_consume", wrapper) diff --git a/instrumentation/opentelemetry-instrumentation-pika/src/opentelemetry/instrumentation/pika/utils.py b/instrumentation/opentelemetry-instrumentation-pika/src/opentelemetry/instrumentation/pika/utils.py index 12161d2334..5cd6ca795f 100644 --- a/instrumentation/opentelemetry-instrumentation-pika/src/opentelemetry/instrumentation/pika/utils.py +++ b/instrumentation/opentelemetry-instrumentation-pika/src/opentelemetry/instrumentation/pika/utils.py @@ -46,17 +46,23 @@ def decorated_callback( ctx = propagate.extract(properties.headers, getter=_pika_getter) if not ctx: ctx = context.get_current() + token = context.attach(ctx) span = _get_span( tracer, channel, properties, + destination=method.exchange + if method.exchange + else method.routing_key, span_kind=SpanKind.CONSUMER, task_name=task_name, - ctx=ctx, operation=MessagingOperationValues.RECEIVE, ) - with trace.use_span(span, end_on_exit=True): - retval = callback(channel, method, properties, body) + try: + with trace.use_span(span, end_on_exit=True): + retval = callback(channel, method, properties, body) + finally: + context.detach(token) return retval return decorated_callback @@ -78,14 +84,13 @@ def decorated_function( properties = BasicProperties(headers={}) if properties.headers is None: properties.headers = {} - ctx = context.get_current() span = _get_span( tracer, channel, properties, + destination=exchange if exchange else routing_key, span_kind=SpanKind.PRODUCER, task_name="(temporary)", - ctx=ctx, operation=None, ) if not span: @@ -108,8 +113,8 @@ def _get_span( channel: Channel, properties: BasicProperties, task_name: str, + destination: str, span_kind: SpanKind, - ctx: context.Context, operation: Optional[MessagingOperationValues] = None, ) -> Optional[Span]: if context.get_value("suppress_instrumentation") or context.get_value( @@ -118,9 +123,7 @@ def _get_span( return None task_name = properties.type if properties.type else task_name span = tracer.start_span( - context=ctx, - name=_generate_span_name(task_name, operation), - kind=span_kind, + name=_generate_span_name(destination, operation), kind=span_kind, ) if span.is_recording(): _enrich_span(span, channel, properties, task_name, operation) diff --git a/instrumentation/opentelemetry-instrumentation-pika/tests/test_pika_instrumentation.py b/instrumentation/opentelemetry-instrumentation-pika/tests/test_pika_instrumentation.py index da2a940b5b..711377a17e 100644 --- a/instrumentation/opentelemetry-instrumentation-pika/tests/test_pika_instrumentation.py +++ b/instrumentation/opentelemetry-instrumentation-pika/tests/test_pika_instrumentation.py @@ -13,7 +13,7 @@ # limitations under the License. from unittest import TestCase, mock -from pika.adapters import BaseConnection, BlockingConnection +from pika.adapters import BlockingConnection from pika.channel import Channel from wrapt import BoundFunctionWrapper @@ -24,9 +24,10 @@ class TestPika(TestCase): def setUp(self) -> None: self.channel = mock.MagicMock(spec=Channel) - self.channel._impl = mock.MagicMock(spec=BaseConnection) + consumer_info = mock.MagicMock() + consumer_info.on_message_callback = mock.MagicMock() + self.channel._consumer_infos = {"consumer-tag": consumer_info} self.mock_callback = mock.MagicMock() - self.channel._impl._consumers = {"mock_key": self.mock_callback} def test_instrument_api(self) -> None: instrumentation = PikaInstrumentor() @@ -49,11 +50,11 @@ def test_instrument_api(self) -> None: "opentelemetry.instrumentation.pika.PikaInstrumentor._decorate_basic_consume" ) @mock.patch( - "opentelemetry.instrumentation.pika.PikaInstrumentor._instrument_consumers" + "opentelemetry.instrumentation.pika.PikaInstrumentor._instrument_blocking_channel_consumers" ) def test_instrument( self, - instrument_consumers: mock.MagicMock, + instrument_blocking_channel_consumers: mock.MagicMock, instrument_basic_consume: mock.MagicMock, instrument_channel_functions: mock.MagicMock, ): @@ -61,7 +62,7 @@ def test_instrument( assert hasattr( self.channel, "_is_instrumented_by_opentelemetry" ), "channel is not marked as instrumented!" - instrument_consumers.assert_called_once() + instrument_blocking_channel_consumers.assert_called_once() instrument_basic_consume.assert_called_once() instrument_channel_functions.assert_called_once() @@ -71,18 +72,18 @@ def test_instrument_consumers( ) -> None: tracer = mock.MagicMock(spec=Tracer) expected_decoration_calls = [ - mock.call(value, tracer, key) - for key, value in self.channel._impl._consumers.items() + mock.call(value.on_message_callback, tracer, key) + for key, value in self.channel._consumer_infos.items() ] - PikaInstrumentor._instrument_consumers( - self.channel._impl._consumers, tracer + PikaInstrumentor._instrument_blocking_channel_consumers( + self.channel, tracer ) decorate_callback.assert_has_calls( calls=expected_decoration_calls, any_order=True ) assert all( hasattr(callback, "_original_callback") - for callback in self.channel._impl._consumers.values() + for callback in self.channel._consumer_infos.values() ) @mock.patch( diff --git a/instrumentation/opentelemetry-instrumentation-pika/tests/test_utils.py b/instrumentation/opentelemetry-instrumentation-pika/tests/test_utils.py index 6a163c675f..2cc75b8cbf 100644 --- a/instrumentation/opentelemetry-instrumentation-pika/tests/test_utils.py +++ b/instrumentation/opentelemetry-instrumentation-pika/tests/test_utils.py @@ -38,15 +38,15 @@ def test_get_span( channel = mock.MagicMock() properties = mock.MagicMock() task_name = "test.test" + destination = "myqueue" span_kind = mock.MagicMock(spec=SpanKind) get_value.return_value = None - ctx = mock.MagicMock() _ = utils._get_span( - tracer, channel, properties, task_name, span_kind, ctx + tracer, channel, properties, task_name, destination, span_kind ) generate_span_name.assert_called_once() tracer.start_span.assert_called_once_with( - context=ctx, name=generate_span_name.return_value, kind=span_kind + name=generate_span_name.return_value, kind=span_kind ) enrich_span.assert_called_once() @@ -185,6 +185,7 @@ def test_decorate_callback( tracer = mock.MagicMock() channel = mock.MagicMock(spec=Channel) method = mock.MagicMock(spec=Basic.Deliver) + method.exchange = "test_exchange" properties = mock.MagicMock() mock_body = b"mock_body" decorated_callback = utils._decorate_callback( @@ -198,9 +199,9 @@ def test_decorate_callback( tracer, channel, properties, + destination=method.exchange, span_kind=SpanKind.CONSUMER, task_name=mock_task_name, - ctx=extract.return_value, operation=MessagingOperationValues.RECEIVE, ) use_span.assert_called_once_with( @@ -213,35 +214,33 @@ def test_decorate_callback( @mock.patch("opentelemetry.instrumentation.pika.utils._get_span") @mock.patch("opentelemetry.propagate.inject") - @mock.patch("opentelemetry.context.get_current") @mock.patch("opentelemetry.trace.use_span") def test_decorate_basic_publish( self, use_span: mock.MagicMock, - get_current: mock.MagicMock, inject: mock.MagicMock, get_span: mock.MagicMock, ) -> None: callback = mock.MagicMock() tracer = mock.MagicMock() channel = mock.MagicMock(spec=Channel) - method = mock.MagicMock(spec=Basic.Deliver) + exchange_name = "test-exchange" + routing_key = "test-routing-key" properties = mock.MagicMock() mock_body = b"mock_body" decorated_basic_publish = utils._decorate_basic_publish( callback, channel, tracer ) retval = decorated_basic_publish( - channel, method, mock_body, properties + exchange_name, routing_key, mock_body, properties ) - get_current.assert_called_once() get_span.assert_called_once_with( tracer, channel, properties, + destination=exchange_name, span_kind=SpanKind.PRODUCER, task_name="(temporary)", - ctx=get_current.return_value, operation=None, ) use_span.assert_called_once_with( @@ -250,20 +249,18 @@ def test_decorate_basic_publish( get_span.return_value.is_recording.assert_called_once() inject.assert_called_once_with(properties.headers) callback.assert_called_once_with( - channel, method, mock_body, properties, False + exchange_name, routing_key, mock_body, properties, False ) self.assertEqual(retval, callback.return_value) @mock.patch("opentelemetry.instrumentation.pika.utils._get_span") @mock.patch("opentelemetry.propagate.inject") - @mock.patch("opentelemetry.context.get_current") @mock.patch("opentelemetry.trace.use_span") @mock.patch("pika.spec.BasicProperties.__new__") def test_decorate_basic_publish_no_properties( self, basic_properties: mock.MagicMock, use_span: mock.MagicMock, - get_current: mock.MagicMock, inject: mock.MagicMock, get_span: mock.MagicMock, ) -> None: @@ -277,10 +274,39 @@ def test_decorate_basic_publish_no_properties( ) retval = decorated_basic_publish(channel, method, body=mock_body) basic_properties.assert_called_once_with(BasicProperties, headers={}) - get_current.assert_called_once() use_span.assert_called_once_with( get_span.return_value, end_on_exit=True ) get_span.return_value.is_recording.assert_called_once() inject.assert_called_once_with(basic_properties.return_value.headers) self.assertEqual(retval, callback.return_value) + + @staticmethod + @mock.patch("opentelemetry.instrumentation.pika.utils._get_span") + def test_decorate_basic_publish_published_message_to_queue( + get_span: mock.MagicMock, + ) -> None: + callback = mock.MagicMock() + tracer = mock.MagicMock() + channel = mock.MagicMock(spec=Channel) + exchange_name = "" + routing_key = "test-routing-key" + properties = mock.MagicMock() + mock_body = b"mock_body" + + decorated_basic_publish = utils._decorate_basic_publish( + callback, channel, tracer + ) + decorated_basic_publish( + exchange_name, routing_key, mock_body, properties + ) + + get_span.assert_called_once_with( + tracer, + channel, + properties, + destination=routing_key, + span_kind=SpanKind.PRODUCER, + task_name="(temporary)", + operation=None, + )