Skip to content

Commit

Permalink
Fix async redis clients tracing (#1830)
Browse files Browse the repository at this point in the history
* Fix async redis clients tracing

* Update changelog

* Add functional integration tests and fix linting issues

---------

Co-authored-by: Shalev Roda <[email protected]>
  • Loading branch information
Vivanov98 and shalevr authored Jun 25, 2023
1 parent e70437a commit cd6b024
Show file tree
Hide file tree
Showing 4 changed files with 270 additions and 34 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Fix async redis clients not being traced correctly ([#1830](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1830))
- Make Flask request span attributes available for `start_span`.
([#1784](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1784))
- Fix falcon instrumentation's usage of Span Status to only set the description if the status code is ERROR.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,18 +136,52 @@ def _set_connection_attributes(span, conn):
span.set_attribute(key, value)


def _build_span_name(instance, cmd_args):
if len(cmd_args) > 0 and cmd_args[0]:
name = cmd_args[0]
else:
name = instance.connection_pool.connection_kwargs.get("db", 0)
return name


def _build_span_meta_data_for_pipeline(instance):
try:
command_stack = (
instance.command_stack
if hasattr(instance, "command_stack")
else instance._command_stack
)

cmds = [
_format_command_args(c.args if hasattr(c, "args") else c[0])
for c in command_stack
]
resource = "\n".join(cmds)

span_name = " ".join(
[
(c.args[0] if hasattr(c, "args") else c[0][0])
for c in command_stack
]
)
except (AttributeError, IndexError):
command_stack = []
resource = ""
span_name = ""

return command_stack, resource, span_name


# pylint: disable=R0915
def _instrument(
tracer,
request_hook: _RequestHookT = None,
response_hook: _ResponseHookT = None,
):
def _traced_execute_command(func, instance, args, kwargs):
query = _format_command_args(args)
name = _build_span_name(instance, args)

if len(args) > 0 and args[0]:
name = args[0]
else:
name = instance.connection_pool.connection_kwargs.get("db", 0)
with tracer.start_as_current_span(
name, kind=trace.SpanKind.CLIENT
) as span:
Expand All @@ -163,31 +197,11 @@ def _traced_execute_command(func, instance, args, kwargs):
return response

def _traced_execute_pipeline(func, instance, args, kwargs):
try:
command_stack = (
instance.command_stack
if hasattr(instance, "command_stack")
else instance._command_stack
)

cmds = [
_format_command_args(
c.args if hasattr(c, "args") else c[0],
)
for c in command_stack
]
resource = "\n".join(cmds)

span_name = " ".join(
[
(c.args[0] if hasattr(c, "args") else c[0][0])
for c in command_stack
]
)
except (AttributeError, IndexError):
command_stack = []
resource = ""
span_name = ""
(
command_stack,
resource,
span_name,
) = _build_span_meta_data_for_pipeline(instance)

with tracer.start_as_current_span(
span_name, kind=trace.SpanKind.CLIENT
Expand Down Expand Up @@ -232,32 +246,72 @@ def _traced_execute_pipeline(func, instance, args, kwargs):
"ClusterPipeline.execute",
_traced_execute_pipeline,
)

async def _async_traced_execute_command(func, instance, args, kwargs):
query = _format_command_args(args)
name = _build_span_name(instance, args)

with tracer.start_as_current_span(
name, kind=trace.SpanKind.CLIENT
) as span:
if span.is_recording():
span.set_attribute(SpanAttributes.DB_STATEMENT, query)
_set_connection_attributes(span, instance)
span.set_attribute("db.redis.args_length", len(args))
if callable(request_hook):
request_hook(span, instance, args, kwargs)
response = await func(*args, **kwargs)
if callable(response_hook):
response_hook(span, instance, response)
return response

async def _async_traced_execute_pipeline(func, instance, args, kwargs):
(
command_stack,
resource,
span_name,
) = _build_span_meta_data_for_pipeline(instance)

with tracer.start_as_current_span(
span_name, kind=trace.SpanKind.CLIENT
) as span:
if span.is_recording():
span.set_attribute(SpanAttributes.DB_STATEMENT, resource)
_set_connection_attributes(span, instance)
span.set_attribute(
"db.redis.pipeline_length", len(command_stack)
)
response = await func(*args, **kwargs)
if callable(response_hook):
response_hook(span, instance, response)
return response

if redis.VERSION >= _REDIS_ASYNCIO_VERSION:
wrap_function_wrapper(
"redis.asyncio",
f"{redis_class}.execute_command",
_traced_execute_command,
_async_traced_execute_command,
)
wrap_function_wrapper(
"redis.asyncio.client",
f"{pipeline_class}.execute",
_traced_execute_pipeline,
_async_traced_execute_pipeline,
)
wrap_function_wrapper(
"redis.asyncio.client",
f"{pipeline_class}.immediate_execute_command",
_traced_execute_command,
_async_traced_execute_command,
)
if redis.VERSION >= _REDIS_ASYNCIO_CLUSTER_VERSION:
wrap_function_wrapper(
"redis.asyncio.cluster",
"RedisCluster.execute_command",
_traced_execute_command,
_async_traced_execute_command,
)
wrap_function_wrapper(
"redis.asyncio.cluster",
"ClusterPipeline.execute",
_traced_execute_pipeline,
_async_traced_execute_pipeline,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,36 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from unittest import mock

import redis
import redis.asyncio

from opentelemetry import trace
from opentelemetry.instrumentation.redis import RedisInstrumentor
from opentelemetry.test.test_base import TestBase
from opentelemetry.trace import SpanKind


class AsyncMock:
"""A sufficient async mock implementation.
Python 3.7 doesn't have an inbuilt async mock class, so this is used.
"""

def __init__(self):
self.mock = mock.Mock()

async def __call__(self, *args, **kwargs):
future = asyncio.Future()
future.set_result("random")
return future

def __getattr__(self, item):
return AsyncMock()


class TestRedis(TestBase):
def setUp(self):
super().setUp()
Expand Down Expand Up @@ -87,6 +107,35 @@ def test_instrument_uninstrument(self):
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)

def test_instrument_uninstrument_async_client_command(self):
redis_client = redis.asyncio.Redis()

with mock.patch.object(redis_client, "connection", AsyncMock()):
asyncio.run(redis_client.get("key"))

spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)
self.memory_exporter.clear()

# Test uninstrument
RedisInstrumentor().uninstrument()

with mock.patch.object(redis_client, "connection", AsyncMock()):
asyncio.run(redis_client.get("key"))

spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 0)
self.memory_exporter.clear()

# Test instrument again
RedisInstrumentor().instrument()

with mock.patch.object(redis_client, "connection", AsyncMock()):
asyncio.run(redis_client.get("key"))

spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)

def test_response_hook(self):
redis_client = redis.Redis()
connection = redis.connection.Connection()
Expand Down
Loading

0 comments on commit cd6b024

Please sign in to comment.