Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix asynchonous unary call traces #536

Merged
merged 15 commits into from
Jul 12, 2021
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
([#560](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/560))
- `opentelemetry-instrumentation-django` Migrated Django middleware to new-style.
([#533](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/533))
- `opentelemetry-instrumentation-grpc` Fixed asynchonous unary call traces
([#536](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/536))

### Added
- `opentelemetry-instrumentation-httpx` Add `httpx` instrumentation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,27 +33,6 @@
from opentelemetry.trace.status import Status, StatusCode


class _GuardedSpan:
def __init__(self, span):
self.span = span
self.generated_span = None
self._engaged = True

def __enter__(self):
self.generated_span = self.span.__enter__()
return self

def __exit__(self, *args, **kwargs):
if self._engaged:
self.generated_span = None
return self.span.__exit__(*args, **kwargs)
return False

def release(self):
self._engaged = False
return self.span


class _CarrierSetter(Setter):
"""We use a custom setter in order to be able to lower case
keys as is required by grpc.
Expand All @@ -68,7 +47,7 @@ def set(self, carrier: MutableMapping[str, str], key: str, value: str):

def _make_future_done_callback(span, rpc_info):
def callback(response_future):
with span:
with trace.use_span(span, end_on_exit=True):
code = response_future.code()
if code != grpc.StatusCode.OK:
rpc_info.error = code
Expand All @@ -85,7 +64,7 @@ class OpenTelemetryClientInterceptor(
def __init__(self, tracer):
self._tracer = tracer

def _start_span(self, method):
def _start_span(self, method, **kwargs):
service, meth = method.lstrip("/").split("/", 1)
attributes = {
SpanAttributes.RPC_SYSTEM: "grpc",
Expand All @@ -95,16 +74,19 @@ def _start_span(self, method):
}

return self._tracer.start_as_current_span(
name=method, kind=trace.SpanKind.CLIENT, attributes=attributes
name=method,
kind=trace.SpanKind.CLIENT,
attributes=attributes,
**kwargs,
)

# pylint:disable=no-self-use
def _trace_result(self, guarded_span, rpc_info, result):
# If the RPC is called asynchronously, release the guard and add a
# callback so that the span can be finished once the future is done.
def _trace_result(self, span, rpc_info, result):
# If the RPC is called asynchronously, add a callback to end the span
# when the future is done, else end the span immediately
if isinstance(result, grpc.Future):
result.add_done_callback(
_make_future_done_callback(guarded_span.release(), rpc_info)
_make_future_done_callback(span, rpc_info)
)
return result
response = result
Expand All @@ -115,41 +97,54 @@ def _trace_result(self, guarded_span, rpc_info, result):
if isinstance(result, tuple):
response = result[0]
rpc_info.response = response

span.end()
return result

def _start_guarded_span(self, *args, **kwargs):
return _GuardedSpan(self._start_span(*args, **kwargs))

def intercept_unary(self, request, metadata, client_info, invoker):
def _intercept(self, request, metadata, client_info, invoker):
if not metadata:
mutable_metadata = OrderedDict()
else:
mutable_metadata = OrderedDict(metadata)

with self._start_guarded_span(client_info.full_method) as guarded_span:
inject(mutable_metadata, setter=_carrier_setter)
metadata = tuple(mutable_metadata.items())

rpc_info = RpcInfo(
full_method=client_info.full_method,
metadata=metadata,
timeout=client_info.timeout,
request=request,
)

with self._start_span(
client_info.full_method,
end_on_exit=False,
record_exception=False,
set_status_on_exception=False,
) as span:
result = None
try:
result = invoker(request, metadata)
except grpc.RpcError as err:
guarded_span.generated_span.set_status(
Status(StatusCode.ERROR)
inject(mutable_metadata, setter=_carrier_setter)
metadata = tuple(mutable_metadata.items())

rpc_info = RpcInfo(
full_method=client_info.full_method,
metadata=metadata,
timeout=client_info.timeout,
request=request,
)
guarded_span.generated_span.set_attribute(
SpanAttributes.RPC_GRPC_STATUS_CODE, err.code().value[0]

result = invoker(request, metadata)
except Exception as exc:
if isinstance(exc, grpc.RpcError):
span.set_attribute(
SpanAttributes.RPC_GRPC_STATUS_CODE,
exc.code().value[0],
)
span.set_status(
Status(
status_code=StatusCode.ERROR,
description="{}: {}".format(type(exc).__name__, exc),
)
)
raise err
span.record_exception(exc)
raise exc
finally:
if not result:
sengjea marked this conversation as resolved.
Show resolved Hide resolved
span.end()
return self._trace_result(span, rpc_info, result)
sengjea marked this conversation as resolved.
Show resolved Hide resolved

return self._trace_result(guarded_span, rpc_info, result)
def intercept_unary(self, request, metadata, client_info, invoker):
return self._intercept(request, metadata, client_info, invoker)

# For RPCs that stream responses, the result can be a generator. To record
# the span across the generated responses and detect any errors, we wrap
Expand Down Expand Up @@ -194,32 +189,6 @@ def intercept_stream(
request_or_iterator, metadata, client_info, invoker
)

if not metadata:
mutable_metadata = OrderedDict()
else:
mutable_metadata = OrderedDict(metadata)

with self._start_guarded_span(client_info.full_method) as guarded_span:
inject(mutable_metadata, setter=_carrier_setter)
metadata = tuple(mutable_metadata.items())
rpc_info = RpcInfo(
full_method=client_info.full_method,
metadata=metadata,
timeout=client_info.timeout,
request=request_or_iterator,
)

rpc_info.request = request_or_iterator

try:
result = invoker(request_or_iterator, metadata)
except grpc.RpcError as err:
guarded_span.generated_span.set_status(
Status(StatusCode.ERROR)
)
guarded_span.generated_span.set_attribute(
SpanAttributes.RPC_GRPC_STATUS_CODE, err.code().value[0],
)
raise err

return self._trace_result(guarded_span, rpc_info, result)
return self._intercept(
request_or_iterator, metadata, client_info, invoker
)
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@ def simple_method(stub, error=False):
stub.SimpleMethod(request)


def simple_method_future(stub, error=False):
request = Request(
client_id=CLIENT_ID, request_data="error" if error else "data"
)
return stub.SimpleMethod.future(request)


def client_streaming_method(stub, error=False):
# create a generator
def request_messages():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
client_streaming_method,
server_streaming_method,
simple_method,
simple_method_future,
)
from ._server import create_test_server
from .protobuf.test_server_pb2 import Request
Expand Down Expand Up @@ -100,6 +101,20 @@ def tearDown(self):
self.server.stop(None)
self.channel.close()

def test_unary_unary_future(self):
simple_method_future(self._stub).result()
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)
span = spans[0]

self.assertEqual(span.name, "/GRPCTestServer/SimpleMethod")
self.assertIs(span.kind, trace.SpanKind.CLIENT)

# Check version and name in span's instrumentation info
self.check_span_instrumentation_info(
span, opentelemetry.instrumentation.grpc
)

def test_unary_unary(self):
simple_method(self._stub)
spans = self.memory_exporter.get_finished_spans()
Expand Down