diff --git a/CHANGELOG.md b/CHANGELOG.md index ea4843f843..82fd3a23fe 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Fix async redis clients not being traced correctly ([#1830](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1830)) - Make Flask request span attributes available for `start_span`. ([#1784](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1784)) - Fix falcon instrumentation's usage of Span Status to only set the description if the status code is ERROR. diff --git a/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py b/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py index 3c8acdef31..ba4b8d529e 100644 --- a/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py @@ -136,6 +136,43 @@ def _set_connection_attributes(span, conn): span.set_attribute(key, value) +def _build_span_name(instance, cmd_args): + if len(cmd_args) > 0 and cmd_args[0]: + name = cmd_args[0] + else: + name = instance.connection_pool.connection_kwargs.get("db", 0) + return name + + +def _build_span_meta_data_for_pipeline(instance): + try: + command_stack = ( + instance.command_stack + if hasattr(instance, "command_stack") + else instance._command_stack + ) + + cmds = [ + _format_command_args(c.args if hasattr(c, "args") else c[0]) + for c in command_stack + ] + resource = "\n".join(cmds) + + span_name = " ".join( + [ + (c.args[0] if hasattr(c, "args") else c[0][0]) + for c in command_stack + ] + ) + except (AttributeError, IndexError): + command_stack = [] + resource = "" + span_name = "" + + return command_stack, resource, span_name + + +# pylint: disable=R0915 def _instrument( tracer, request_hook: _RequestHookT = None, @@ -143,11 +180,8 @@ def _instrument( ): def _traced_execute_command(func, instance, args, kwargs): query = _format_command_args(args) + name = _build_span_name(instance, args) - if len(args) > 0 and args[0]: - name = args[0] - else: - name = instance.connection_pool.connection_kwargs.get("db", 0) with tracer.start_as_current_span( name, kind=trace.SpanKind.CLIENT ) as span: @@ -163,31 +197,11 @@ def _traced_execute_command(func, instance, args, kwargs): return response def _traced_execute_pipeline(func, instance, args, kwargs): - try: - command_stack = ( - instance.command_stack - if hasattr(instance, "command_stack") - else instance._command_stack - ) - - cmds = [ - _format_command_args( - c.args if hasattr(c, "args") else c[0], - ) - for c in command_stack - ] - resource = "\n".join(cmds) - - span_name = " ".join( - [ - (c.args[0] if hasattr(c, "args") else c[0][0]) - for c in command_stack - ] - ) - except (AttributeError, IndexError): - command_stack = [] - resource = "" - span_name = "" + ( + command_stack, + resource, + span_name, + ) = _build_span_meta_data_for_pipeline(instance) with tracer.start_as_current_span( span_name, kind=trace.SpanKind.CLIENT @@ -232,32 +246,72 @@ def _traced_execute_pipeline(func, instance, args, kwargs): "ClusterPipeline.execute", _traced_execute_pipeline, ) + + async def _async_traced_execute_command(func, instance, args, kwargs): + query = _format_command_args(args) + name = _build_span_name(instance, args) + + with tracer.start_as_current_span( + name, kind=trace.SpanKind.CLIENT + ) as span: + if span.is_recording(): + span.set_attribute(SpanAttributes.DB_STATEMENT, query) + _set_connection_attributes(span, instance) + span.set_attribute("db.redis.args_length", len(args)) + if callable(request_hook): + request_hook(span, instance, args, kwargs) + response = await func(*args, **kwargs) + if callable(response_hook): + response_hook(span, instance, response) + return response + + async def _async_traced_execute_pipeline(func, instance, args, kwargs): + ( + command_stack, + resource, + span_name, + ) = _build_span_meta_data_for_pipeline(instance) + + with tracer.start_as_current_span( + span_name, kind=trace.SpanKind.CLIENT + ) as span: + if span.is_recording(): + span.set_attribute(SpanAttributes.DB_STATEMENT, resource) + _set_connection_attributes(span, instance) + span.set_attribute( + "db.redis.pipeline_length", len(command_stack) + ) + response = await func(*args, **kwargs) + if callable(response_hook): + response_hook(span, instance, response) + return response + if redis.VERSION >= _REDIS_ASYNCIO_VERSION: wrap_function_wrapper( "redis.asyncio", f"{redis_class}.execute_command", - _traced_execute_command, + _async_traced_execute_command, ) wrap_function_wrapper( "redis.asyncio.client", f"{pipeline_class}.execute", - _traced_execute_pipeline, + _async_traced_execute_pipeline, ) wrap_function_wrapper( "redis.asyncio.client", f"{pipeline_class}.immediate_execute_command", - _traced_execute_command, + _async_traced_execute_command, ) if redis.VERSION >= _REDIS_ASYNCIO_CLUSTER_VERSION: wrap_function_wrapper( "redis.asyncio.cluster", "RedisCluster.execute_command", - _traced_execute_command, + _async_traced_execute_command, ) wrap_function_wrapper( "redis.asyncio.cluster", "ClusterPipeline.execute", - _traced_execute_pipeline, + _async_traced_execute_pipeline, ) diff --git a/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py b/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py index cc6e7de75a..11e56ad953 100644 --- a/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py +++ b/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py @@ -11,9 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import asyncio from unittest import mock import redis +import redis.asyncio from opentelemetry import trace from opentelemetry.instrumentation.redis import RedisInstrumentor @@ -21,6 +23,24 @@ from opentelemetry.trace import SpanKind +class AsyncMock: + """A sufficient async mock implementation. + + Python 3.7 doesn't have an inbuilt async mock class, so this is used. + """ + + def __init__(self): + self.mock = mock.Mock() + + async def __call__(self, *args, **kwargs): + future = asyncio.Future() + future.set_result("random") + return future + + def __getattr__(self, item): + return AsyncMock() + + class TestRedis(TestBase): def setUp(self): super().setUp() @@ -87,6 +107,35 @@ def test_instrument_uninstrument(self): spans = self.memory_exporter.get_finished_spans() self.assertEqual(len(spans), 1) + def test_instrument_uninstrument_async_client_command(self): + redis_client = redis.asyncio.Redis() + + with mock.patch.object(redis_client, "connection", AsyncMock()): + asyncio.run(redis_client.get("key")) + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + self.memory_exporter.clear() + + # Test uninstrument + RedisInstrumentor().uninstrument() + + with mock.patch.object(redis_client, "connection", AsyncMock()): + asyncio.run(redis_client.get("key")) + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 0) + self.memory_exporter.clear() + + # Test instrument again + RedisInstrumentor().instrument() + + with mock.patch.object(redis_client, "connection", AsyncMock()): + asyncio.run(redis_client.get("key")) + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + def test_response_hook(self): redis_client = redis.Redis() connection = redis.connection.Connection() diff --git a/tests/opentelemetry-docker-tests/tests/redis/test_redis_functional.py b/tests/opentelemetry-docker-tests/tests/redis/test_redis_functional.py index dc9cf8b1dc..481b8d21c8 100644 --- a/tests/opentelemetry-docker-tests/tests/redis/test_redis_functional.py +++ b/tests/opentelemetry-docker-tests/tests/redis/test_redis_functional.py @@ -13,6 +13,7 @@ # limitations under the License. import asyncio +from time import time_ns import redis import redis.asyncio @@ -318,6 +319,29 @@ def test_basics(self): ) self.assertEqual(span.attributes.get("db.redis.args_length"), 2) + def test_execute_command_traced_full_time(self): + """Command should be traced for coroutine execution time, not creation time.""" + coro_created_time = None + finish_time = None + + async def pipeline_simple(): + nonlocal coro_created_time + nonlocal finish_time + + # delay coroutine creation from coroutine execution + coro = self.redis_client.get("foo") + coro_created_time = time_ns() + await coro + finish_time = time_ns() + + async_call(pipeline_simple()) + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + span = spans[0] + self.assertTrue(span.start_time > coro_created_time) + self.assertTrue(span.end_time < finish_time) + def test_pipeline_traced(self): async def pipeline_simple(): async with self.redis_client.pipeline( @@ -340,6 +364,35 @@ async def pipeline_simple(): ) self.assertEqual(span.attributes.get("db.redis.pipeline_length"), 3) + def test_pipeline_traced_full_time(self): + """Command should be traced for coroutine execution time, not creation time.""" + coro_created_time = None + finish_time = None + + async def pipeline_simple(): + async with self.redis_client.pipeline( + transaction=False + ) as pipeline: + nonlocal coro_created_time + nonlocal finish_time + pipeline.set("blah", 32) + pipeline.rpush("foo", "éé") + pipeline.hgetall("xxx") + + # delay coroutine creation from coroutine execution + coro = pipeline.execute() + coro_created_time = time_ns() + await coro + finish_time = time_ns() + + async_call(pipeline_simple()) + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + span = spans[0] + self.assertTrue(span.start_time > coro_created_time) + self.assertTrue(span.end_time < finish_time) + def test_pipeline_immediate(self): async def pipeline_immediate(): async with self.redis_client.pipeline() as pipeline: @@ -359,6 +412,33 @@ async def pipeline_immediate(): span.attributes.get(SpanAttributes.DB_STATEMENT), "SET ? ?" ) + def test_pipeline_immediate_traced_full_time(self): + """Command should be traced for coroutine execution time, not creation time.""" + coro_created_time = None + finish_time = None + + async def pipeline_simple(): + async with self.redis_client.pipeline( + transaction=False + ) as pipeline: + nonlocal coro_created_time + nonlocal finish_time + pipeline.set("a", 1) + + # delay coroutine creation from coroutine execution + coro = pipeline.immediate_execute_command("SET", "b", 2) + coro_created_time = time_ns() + await coro + finish_time = time_ns() + + async_call(pipeline_simple()) + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + span = spans[0] + self.assertTrue(span.start_time > coro_created_time) + self.assertTrue(span.end_time < finish_time) + def test_parent(self): """Ensure OpenTelemetry works with redis.""" ot_tracer = trace.get_tracer("redis_svc") @@ -408,6 +488,29 @@ def test_basics(self): ) self.assertEqual(span.attributes.get("db.redis.args_length"), 2) + def test_execute_command_traced_full_time(self): + """Command should be traced for coroutine execution time, not creation time.""" + coro_created_time = None + finish_time = None + + async def pipeline_simple(): + nonlocal coro_created_time + nonlocal finish_time + + # delay coroutine creation from coroutine execution + coro = self.redis_client.get("foo") + coro_created_time = time_ns() + await coro + finish_time = time_ns() + + async_call(pipeline_simple()) + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + span = spans[0] + self.assertTrue(span.start_time > coro_created_time) + self.assertTrue(span.end_time < finish_time) + def test_pipeline_traced(self): async def pipeline_simple(): async with self.redis_client.pipeline( @@ -430,6 +533,35 @@ async def pipeline_simple(): ) self.assertEqual(span.attributes.get("db.redis.pipeline_length"), 3) + def test_pipeline_traced_full_time(self): + """Command should be traced for coroutine execution time, not creation time.""" + coro_created_time = None + finish_time = None + + async def pipeline_simple(): + async with self.redis_client.pipeline( + transaction=False + ) as pipeline: + nonlocal coro_created_time + nonlocal finish_time + pipeline.set("blah", 32) + pipeline.rpush("foo", "éé") + pipeline.hgetall("xxx") + + # delay coroutine creation from coroutine execution + coro = pipeline.execute() + coro_created_time = time_ns() + await coro + finish_time = time_ns() + + async_call(pipeline_simple()) + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + span = spans[0] + self.assertTrue(span.start_time > coro_created_time) + self.assertTrue(span.end_time < finish_time) + def test_parent(self): """Ensure OpenTelemetry works with redis.""" ot_tracer = trace.get_tracer("redis_svc")