diff --git a/CHANGELOG.md b/CHANGELOG.md index 768b0a7e0a..4996f9b9f3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,6 +35,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ([#1413](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1413)) - `opentelemetry-instrumentation-pyramid` Add support for regular expression matching and sanitization of HTTP headers. ([#1414](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1414)) +- `opentelemetry-instrumentation-grpc` Add support for grpc.aio Clients and Servers + ([#1245](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1245)) - Add metric exporter for Prometheus Remote Write ([#1359](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1359)) diff --git a/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/__init__.py b/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/__init__.py index 9b4b0c61fd..25010e147b 100644 --- a/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/__init__.py @@ -108,7 +108,7 @@ def serve(): logging.basicConfig() serve() -You can also add the instrumentor manually, rather than using +You can also add the interceptor manually, rather than using :py:class:`~opentelemetry.instrumentation.grpc.GrpcInstrumentorServer`: .. code-block:: python @@ -118,6 +118,117 @@ def serve(): server = grpc.server(futures.ThreadPoolExecutor(), interceptors = [server_interceptor()]) +Usage Aio Client +---------------- +.. code-block:: python + + import logging + import asyncio + + import grpc + + from opentelemetry import trace + from opentelemetry.instrumentation.grpc import GrpcAioInstrumentorClient + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.export import ( + ConsoleSpanExporter, + SimpleSpanProcessor, + ) + + try: + from .gen import helloworld_pb2, helloworld_pb2_grpc + except ImportError: + from gen import helloworld_pb2, helloworld_pb2_grpc + + trace.set_tracer_provider(TracerProvider()) + trace.get_tracer_provider().add_span_processor( + SimpleSpanProcessor(ConsoleSpanExporter()) + ) + + grpc_client_instrumentor = GrpcAioInstrumentorClient() + grpc_client_instrumentor.instrument() + + async def run(): + with grpc.aio.insecure_channel("localhost:50051") as channel: + + stub = helloworld_pb2_grpc.GreeterStub(channel) + response = await stub.SayHello(helloworld_pb2.HelloRequest(name="YOU")) + + print("Greeter client received: " + response.message) + + + if __name__ == "__main__": + logging.basicConfig() + asyncio.run(run()) + +You can also add the interceptor manually, rather than using +:py:class:`~opentelemetry.instrumentation.grpc.GrpcAioInstrumentorClient`: + +.. code-block:: python + + from opentelemetry.instrumentation.grpc import aio_client_interceptors + + channel = grpc.aio.insecure_channel("localhost:12345", interceptors=aio_client_interceptors()) + + +Usage Aio Server +---------------- +.. code-block:: python + + import logging + import asyncio + + import grpc + + from opentelemetry import trace + from opentelemetry.instrumentation.grpc import GrpcAioInstrumentorServer + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.export import ( + ConsoleSpanExporter, + SimpleSpanProcessor, + ) + + try: + from .gen import helloworld_pb2, helloworld_pb2_grpc + except ImportError: + from gen import helloworld_pb2, helloworld_pb2_grpc + + trace.set_tracer_provider(TracerProvider()) + trace.get_tracer_provider().add_span_processor( + SimpleSpanProcessor(ConsoleSpanExporter()) + ) + + grpc_server_instrumentor = GrpcAioInstrumentorServer() + grpc_server_instrumentor.instrument() + + class Greeter(helloworld_pb2_grpc.GreeterServicer): + async def SayHello(self, request, context): + return helloworld_pb2.HelloReply(message="Hello, %s!" % request.name) + + + async def serve(): + + server = grpc.aio.server() + + helloworld_pb2_grpc.add_GreeterServicer_to_server(Greeter(), server) + server.add_insecure_port("[::]:50051") + await server.start() + await server.wait_for_termination() + + + if __name__ == "__main__": + logging.basicConfig() + asyncio.run(serve()) + +You can also add the interceptor manually, rather than using +:py:class:`~opentelemetry.instrumentation.grpc.GrpcAioInstrumentorServer`: + +.. code-block:: python + + from opentelemetry.instrumentation.grpc import aio_server_interceptor + + server = grpc.aio.server(interceptors = [aio_server_interceptor()]) + Filters ------- @@ -244,6 +355,58 @@ def _uninstrument(self, **kwargs): grpc.server = self._original_func +class GrpcAioInstrumentorServer(BaseInstrumentor): + """ + Globally instrument the grpc.aio server. + + Usage:: + + grpc_aio_server_instrumentor = GrpcAioInstrumentorServer() + grpc_aio_server_instrumentor.instrument() + + """ + + # pylint:disable=attribute-defined-outside-init, redefined-outer-name + + def __init__(self, filter_=None): + excluded_service_filter = _excluded_service_filter() + if excluded_service_filter is not None: + if filter_ is None: + filter_ = excluded_service_filter + else: + filter_ = any_of(filter_, excluded_service_filter) + self._filter = filter_ + + def instrumentation_dependencies(self) -> Collection[str]: + return _instruments + + def _instrument(self, **kwargs): + self._original_func = grpc.aio.server + tracer_provider = kwargs.get("tracer_provider") + + def server(*args, **kwargs): + if "interceptors" in kwargs: + # add our interceptor as the first + kwargs["interceptors"].insert( + 0, + aio_server_interceptor( + tracer_provider=tracer_provider, filter_=self._filter + ), + ) + else: + kwargs["interceptors"] = [ + aio_server_interceptor( + tracer_provider=tracer_provider, filter_=self._filter + ) + ] + return self._original_func(*args, **kwargs) + + grpc.aio.server = server + + def _uninstrument(self, **kwargs): + grpc.aio.server = self._original_func + + class GrpcInstrumentorClient(BaseInstrumentor): """ Globally instrument the grpc client @@ -315,6 +478,69 @@ def wrapper_fn(self, original_func, instance, args, kwargs): ) +class GrpcAioInstrumentorClient(BaseInstrumentor): + """ + Globally instrument the grpc.aio client. + + Usage:: + + grpc_aio_client_instrumentor = GrpcAioInstrumentorClient() + grpc_aio_client_instrumentor.instrument() + + """ + + # pylint:disable=attribute-defined-outside-init, redefined-outer-name + + def __init__(self, filter_=None): + excluded_service_filter = _excluded_service_filter() + if excluded_service_filter is not None: + if filter_ is None: + filter_ = excluded_service_filter + else: + filter_ = any_of(filter_, excluded_service_filter) + self._filter = filter_ + + def instrumentation_dependencies(self) -> Collection[str]: + return _instruments + + def _add_interceptors(self, tracer_provider, kwargs): + if "interceptors" in kwargs and kwargs["interceptors"]: + kwargs["interceptors"] = ( + aio_client_interceptors( + tracer_provider=tracer_provider, filter_=self._filter + ) + + kwargs["interceptors"] + ) + else: + kwargs["interceptors"] = aio_client_interceptors( + tracer_provider=tracer_provider, filter_=self._filter + ) + + return kwargs + + def _instrument(self, **kwargs): + self._original_insecure = grpc.aio.insecure_channel + self._original_secure = grpc.aio.secure_channel + tracer_provider = kwargs.get("tracer_provider") + + def insecure(*args, **kwargs): + kwargs = self._add_interceptors(tracer_provider, kwargs) + + return self._original_insecure(*args, **kwargs) + + def secure(*args, **kwargs): + kwargs = self._add_interceptors(tracer_provider, kwargs) + + return self._original_secure(*args, **kwargs) + + grpc.aio.insecure_channel = insecure + grpc.aio.secure_channel = secure + + def _uninstrument(self, **kwargs): + grpc.aio.insecure_channel = self._original_insecure + grpc.aio.secure_channel = self._original_secure + + def client_interceptor(tracer_provider=None, filter_=None): """Create a gRPC client channel interceptor. @@ -355,6 +581,45 @@ def server_interceptor(tracer_provider=None, filter_=None): return _server.OpenTelemetryServerInterceptor(tracer, filter_=filter_) +def aio_client_interceptors(tracer_provider=None, filter_=None): + """Create a gRPC client channel interceptor. + + Args: + tracer: The tracer to use to create client-side spans. + + Returns: + An invocation-side interceptor object. + """ + from . import _aio_client + + tracer = trace.get_tracer(__name__, __version__, tracer_provider) + + return [ + _aio_client.UnaryUnaryAioClientInterceptor(tracer, filter_=filter_), + _aio_client.UnaryStreamAioClientInterceptor(tracer, filter_=filter_), + _aio_client.StreamUnaryAioClientInterceptor(tracer, filter_=filter_), + _aio_client.StreamStreamAioClientInterceptor(tracer, filter_=filter_), + ] + + +def aio_server_interceptor(tracer_provider=None, filter_=None): + """Create a gRPC aio server interceptor. + + Args: + tracer: The tracer to use to create server-side spans. + + Returns: + A service-side interceptor object. + """ + from . import _aio_server + + tracer = trace.get_tracer(__name__, __version__, tracer_provider) + + return _aio_server.OpenTelemetryAioServerInterceptor( + tracer, filter_=filter_ + ) + + def _excluded_service_filter() -> Union[Callable[[object], bool], None]: services = _parse_services( os.environ.get("OTEL_PYTHON_GRPC_EXCLUDED_SERVICES", "") diff --git a/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_aio_client.py b/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_aio_client.py new file mode 100644 index 0000000000..c7630bfe9f --- /dev/null +++ b/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_aio_client.py @@ -0,0 +1,222 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +from collections import OrderedDict + +import grpc +from grpc.aio import ClientCallDetails + +from opentelemetry import context +from opentelemetry.instrumentation.grpc._client import ( + OpenTelemetryClientInterceptor, + _carrier_setter, +) +from opentelemetry.instrumentation.utils import _SUPPRESS_INSTRUMENTATION_KEY +from opentelemetry.propagate import inject +from opentelemetry.semconv.trace import SpanAttributes +from opentelemetry.trace.status import Status, StatusCode + + +def _unary_done_callback(span, code, details): + def callback(call): + try: + span.set_attribute( + SpanAttributes.RPC_GRPC_STATUS_CODE, + code.value[0], + ) + if code != grpc.StatusCode.OK: + span.set_status( + Status( + status_code=StatusCode.ERROR, + description=details, + ) + ) + finally: + span.end() + + return callback + + +class _BaseAioClientInterceptor(OpenTelemetryClientInterceptor): + @staticmethod + def propagate_trace_in_details(client_call_details): + metadata = client_call_details.metadata + if not metadata: + mutable_metadata = OrderedDict() + else: + mutable_metadata = OrderedDict(metadata) + + inject(mutable_metadata, setter=_carrier_setter) + metadata = tuple(mutable_metadata.items()) + + return ClientCallDetails( + client_call_details.method, + client_call_details.timeout, + metadata, + client_call_details.credentials, + client_call_details.wait_for_ready, + ) + + @staticmethod + def add_error_details_to_span(span, exc): + if isinstance(exc, grpc.RpcError): + span.set_attribute( + SpanAttributes.RPC_GRPC_STATUS_CODE, + exc.code().value[0], + ) + span.set_status( + Status( + status_code=StatusCode.ERROR, + description=f"{type(exc).__name__}: {exc}", + ) + ) + span.record_exception(exc) + + def _start_interceptor_span(self, method): + # method _should_ be a string here but due to a bug in grpc, it is + # populated with a bytes object. Handle both cases such that we + # are forward-compatible with a fixed version of grpc + # More info: https://github.com/grpc/grpc/issues/31092 + if isinstance(method, bytes): + method = method.decode() + + return self._start_span( + method, + end_on_exit=False, + record_exception=False, + set_status_on_exception=False, + ) + + async def _wrap_unary_response(self, continuation, span): + try: + call = await continuation() + + # code and details are both coroutines that need to be await-ed, + # the callbacks added with add_done_callback do not allow async + # code so we need to get the code and details here then pass them + # to the callback. + code = await call.code() + details = await call.details() + + call.add_done_callback(_unary_done_callback(span, code, details)) + + return call + except grpc.aio.AioRpcError as exc: + self.add_error_details_to_span(span, exc) + raise exc + + async def _wrap_stream_response(self, span, call): + try: + async for response in call: + yield response + except Exception as exc: + self.add_error_details_to_span(span, exc) + raise exc + finally: + span.end() + + def tracing_skipped(self, client_call_details): + return context.get_value( + _SUPPRESS_INSTRUMENTATION_KEY + ) or not self.rpc_matches_filters(client_call_details) + + def rpc_matches_filters(self, client_call_details): + return self._filter is None or self._filter(client_call_details) + + +class UnaryUnaryAioClientInterceptor( + grpc.aio.UnaryUnaryClientInterceptor, + _BaseAioClientInterceptor, +): + async def intercept_unary_unary( + self, continuation, client_call_details, request + ): + if self.tracing_skipped(client_call_details): + return await continuation(client_call_details, request) + + with self._start_interceptor_span( + client_call_details.method, + ) as span: + new_details = self.propagate_trace_in_details(client_call_details) + + continuation_with_args = functools.partial( + continuation, new_details, request + ) + return await self._wrap_unary_response( + continuation_with_args, span + ) + + +class UnaryStreamAioClientInterceptor( + grpc.aio.UnaryStreamClientInterceptor, + _BaseAioClientInterceptor, +): + async def intercept_unary_stream( + self, continuation, client_call_details, request + ): + if self.tracing_skipped(client_call_details): + return await continuation(client_call_details, request) + + with self._start_interceptor_span( + client_call_details.method, + ) as span: + new_details = self.propagate_trace_in_details(client_call_details) + + resp = await continuation(new_details, request) + + return self._wrap_stream_response(span, resp) + + +class StreamUnaryAioClientInterceptor( + grpc.aio.StreamUnaryClientInterceptor, + _BaseAioClientInterceptor, +): + async def intercept_stream_unary( + self, continuation, client_call_details, request_iterator + ): + if self.tracing_skipped(client_call_details): + return await continuation(client_call_details, request_iterator) + + with self._start_interceptor_span( + client_call_details.method, + ) as span: + new_details = self.propagate_trace_in_details(client_call_details) + + continuation_with_args = functools.partial( + continuation, new_details, request_iterator + ) + return await self._wrap_unary_response( + continuation_with_args, span + ) + + +class StreamStreamAioClientInterceptor( + grpc.aio.StreamStreamClientInterceptor, + _BaseAioClientInterceptor, +): + async def intercept_stream_stream( + self, continuation, client_call_details, request_iterator + ): + if self.tracing_skipped(client_call_details): + return await continuation(client_call_details, request_iterator) + + with self._start_interceptor_span( + client_call_details.method, + ) as span: + new_details = self.propagate_trace_in_details(client_call_details) + + resp = await continuation(new_details, request_iterator) + + return self._wrap_stream_response(span, resp) diff --git a/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_aio_server.py b/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_aio_server.py new file mode 100644 index 0000000000..d64dcf000b --- /dev/null +++ b/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_aio_server.py @@ -0,0 +1,108 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import grpc.aio + +from ._server import ( + OpenTelemetryServerInterceptor, + _OpenTelemetryServicerContext, + _wrap_rpc_behavior, +) + + +class OpenTelemetryAioServerInterceptor( + grpc.aio.ServerInterceptor, OpenTelemetryServerInterceptor +): + """ + An AsyncIO gRPC server interceptor, to add OpenTelemetry. + Usage:: + tracer = some OpenTelemetry tracer + interceptors = [ + AsyncOpenTelemetryServerInterceptor(tracer), + ] + server = aio.server( + futures.ThreadPoolExecutor(max_workers=concurrency), + interceptors = (interceptors,)) + """ + + async def intercept_service(self, continuation, handler_call_details): + if self._filter is not None and not self._filter(handler_call_details): + return await continuation(handler_call_details) + + def telemetry_wrapper(behavior, request_streaming, response_streaming): + # handle streaming responses specially + if response_streaming: + return self._intercept_aio_server_stream( + behavior, + handler_call_details, + ) + + return self._intercept_aio_server_unary( + behavior, + handler_call_details, + ) + + next_handler = await continuation(handler_call_details) + + return _wrap_rpc_behavior(next_handler, telemetry_wrapper) + + def _intercept_aio_server_unary(self, behavior, handler_call_details): + async def _unary_interceptor(request_or_iterator, context): + with self._set_remote_context(context): + with self._start_span( + handler_call_details, + context, + set_status_on_exception=False, + ) as span: + # wrap the context + context = _OpenTelemetryServicerContext(context, span) + + # And now we run the actual RPC. + try: + return await behavior(request_or_iterator, context) + + except Exception as error: + # Bare exceptions are likely to be gRPC aborts, which + # we handle in our context wrapper. + # Here, we're interested in uncaught exceptions. + # pylint:disable=unidiomatic-typecheck + if type(error) != Exception: + span.record_exception(error) + raise error + + return _unary_interceptor + + def _intercept_aio_server_stream(self, behavior, handler_call_details): + async def _stream_interceptor(request_or_iterator, context): + with self._set_remote_context(context): + with self._start_span( + handler_call_details, + context, + set_status_on_exception=False, + ) as span: + context = _OpenTelemetryServicerContext(context, span) + + try: + async for response in behavior( + request_or_iterator, context + ): + yield response + + except Exception as error: + # pylint:disable=unidiomatic-typecheck + if type(error) != Exception: + span.record_exception(error) + raise error + + return _stream_interceptor diff --git a/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/filters/__init__.py b/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/filters/__init__.py index 905bb8d696..8100a2d17f 100644 --- a/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/filters/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/filters/__init__.py @@ -18,7 +18,10 @@ import grpc TCallDetails = TypeVar( - "TCallDetails", grpc.HandlerCallDetails, grpc.ClientCallDetails + "TCallDetails", + grpc.HandlerCallDetails, + grpc.ClientCallDetails, + grpc.aio.ClientCallDetails, ) Condition = Callable[[TCallDetails], bool] @@ -27,10 +30,25 @@ def _full_method(metadata): name = "" if isinstance(metadata, grpc.HandlerCallDetails): name = metadata.method + elif isinstance(metadata, grpc.aio.ClientCallDetails): + name = metadata.method + # name _should_ be a string here but due to a bug in grpc, it is + # populated with a bytes object. Handle both cases such that we + # are forward-compatible with a fixed version of grpc + # More info: https://github.com/grpc/grpc/issues/31092 + if isinstance(name, bytes): + name = name.decode() # NOTE: replace here if there's better way to match cases to handle # grpcext._interceptor._UnaryClientInfo/_StreamClientInfo elif hasattr(metadata, "full_method"): name = metadata.full_method + # NOTE: this is to handle the grpc.aio Server case. The type interface + # indicates that metadata should be a grpc.HandlerCallDetails and be + # matched prior to this but it is in fact an internal C-extension level + # object. + elif hasattr(metadata, "method"): + name = metadata.method + return name diff --git a/instrumentation/opentelemetry-instrumentation-grpc/tests/_aio_client.py b/instrumentation/opentelemetry-instrumentation-grpc/tests/_aio_client.py new file mode 100644 index 0000000000..9658df1587 --- /dev/null +++ b/instrumentation/opentelemetry-instrumentation-grpc/tests/_aio_client.py @@ -0,0 +1,56 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .protobuf.test_server_pb2 import Request + +CLIENT_ID = 1 + + +async def simple_method(stub, error=False): + request = Request( + client_id=CLIENT_ID, request_data="error" if error else "data" + ) + return await stub.SimpleMethod(request) + + +async def client_streaming_method(stub, error=False): + # create a generator + def request_messages(): + for _ in range(5): + request = Request( + client_id=CLIENT_ID, request_data="error" if error else "data" + ) + yield request + + return await stub.ClientStreamingMethod(request_messages()) + + +def server_streaming_method(stub, error=False): + request = Request( + client_id=CLIENT_ID, request_data="error" if error else "data" + ) + + return stub.ServerStreamingMethod(request) + + +def bidirectional_streaming_method(stub, error=False): + # create a generator + def request_messages(): + for _ in range(5): + request = Request( + client_id=CLIENT_ID, request_data="error" if error else "data" + ) + yield request + + return stub.BidirectionalStreamingMethod(request_messages()) diff --git a/instrumentation/opentelemetry-instrumentation-grpc/tests/test_aio_client_interceptor.py b/instrumentation/opentelemetry-instrumentation-grpc/tests/test_aio_client_interceptor.py new file mode 100644 index 0000000000..6ca5ce92d5 --- /dev/null +++ b/instrumentation/opentelemetry-instrumentation-grpc/tests/test_aio_client_interceptor.py @@ -0,0 +1,366 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +try: + from unittest import IsolatedAsyncioTestCase +except ImportError: + # unittest.IsolatedAsyncioTestCase was introduced in Python 3.8. It's use + # simplifies the following tests. Without it, the amount of test code + # increases significantly, with most of the additional code handling + # the asyncio set up. + from unittest import TestCase + + class IsolatedAsyncioTestCase(TestCase): + def run(self, result=None): + self.skipTest( + "This test requires Python 3.8 for unittest.IsolatedAsyncioTestCase" + ) + + +import grpc +import pytest + +import opentelemetry.instrumentation.grpc +from opentelemetry import context, trace +from opentelemetry.instrumentation.grpc import ( + GrpcAioInstrumentorClient, + aio_client_interceptors, +) +from opentelemetry.instrumentation.grpc._aio_client import ( + UnaryUnaryAioClientInterceptor, +) +from opentelemetry.instrumentation.utils import _SUPPRESS_INSTRUMENTATION_KEY +from opentelemetry.propagate import get_global_textmap, set_global_textmap +from opentelemetry.semconv.trace import SpanAttributes +from opentelemetry.test.mock_textmap import MockTextMapPropagator +from opentelemetry.test.test_base import TestBase + +from ._aio_client import ( + bidirectional_streaming_method, + client_streaming_method, + server_streaming_method, + simple_method, +) +from ._server import create_test_server +from .protobuf import test_server_pb2_grpc # pylint: disable=no-name-in-module + + +class RecordingInterceptor(grpc.aio.UnaryUnaryClientInterceptor): + recorded_details = None + + async def intercept_unary_unary( + self, continuation, client_call_details, request + ): + self.recorded_details = client_call_details + return await continuation(client_call_details, request) + + +@pytest.mark.asyncio +class TestAioClientInterceptor(TestBase, IsolatedAsyncioTestCase): + def setUp(self): + super().setUp() + self.server = create_test_server(25565) + self.server.start() + + interceptors = aio_client_interceptors() + self._channel = grpc.aio.insecure_channel( + "localhost:25565", interceptors=interceptors + ) + + self._stub = test_server_pb2_grpc.GRPCTestServerStub(self._channel) + + def tearDown(self): + super().tearDown() + self.server.stop(1000) + + async def asyncTearDown(self): + await self._channel.close() + + async def test_instrument(self): + instrumentor = GrpcAioInstrumentorClient() + + try: + instrumentor.instrument() + + channel = grpc.aio.insecure_channel("localhost:25565") + stub = test_server_pb2_grpc.GRPCTestServerStub(channel) + + response = await simple_method(stub) + assert response.response_data == "data" + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + finally: + instrumentor.uninstrument() + + async def test_uninstrument(self): + instrumentor = GrpcAioInstrumentorClient() + + instrumentor.instrument() + instrumentor.uninstrument() + + channel = grpc.aio.insecure_channel("localhost:25565") + stub = test_server_pb2_grpc.GRPCTestServerStub(channel) + + response = await simple_method(stub) + assert response.response_data == "data" + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 0) + + async def test_unary_unary(self): + response = await simple_method(self._stub) + assert response.response_data == "data" + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + span = spans[0] + + self.assertEqual(span.name, "/GRPCTestServer/SimpleMethod") + self.assertIs(span.kind, trace.SpanKind.CLIENT) + + # Check version and name in span's instrumentation info + self.assertEqualSpanInstrumentationInfo( + span, opentelemetry.instrumentation.grpc + ) + + self.assertSpanHasAttributes( + span, + { + SpanAttributes.RPC_METHOD: "SimpleMethod", + SpanAttributes.RPC_SERVICE: "GRPCTestServer", + SpanAttributes.RPC_SYSTEM: "grpc", + SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.OK.value[ + 0 + ], + }, + ) + + async def test_unary_stream(self): + async for response in server_streaming_method(self._stub): + self.assertEqual(response.response_data, "data") + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + span = spans[0] + + self.assertEqual(span.name, "/GRPCTestServer/ServerStreamingMethod") + self.assertIs(span.kind, trace.SpanKind.CLIENT) + + # Check version and name in span's instrumentation info + self.assertEqualSpanInstrumentationInfo( + span, opentelemetry.instrumentation.grpc + ) + + self.assertSpanHasAttributes( + span, + { + SpanAttributes.RPC_METHOD: "ServerStreamingMethod", + SpanAttributes.RPC_SERVICE: "GRPCTestServer", + SpanAttributes.RPC_SYSTEM: "grpc", + SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.OK.value[ + 0 + ], + }, + ) + + async def test_stream_unary(self): + response = await client_streaming_method(self._stub) + assert response.response_data == "data" + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + span = spans[0] + + self.assertEqual(span.name, "/GRPCTestServer/ClientStreamingMethod") + self.assertIs(span.kind, trace.SpanKind.CLIENT) + + # Check version and name in span's instrumentation info + self.assertEqualSpanInstrumentationInfo( + span, opentelemetry.instrumentation.grpc + ) + + self.assertSpanHasAttributes( + span, + { + SpanAttributes.RPC_METHOD: "ClientStreamingMethod", + SpanAttributes.RPC_SERVICE: "GRPCTestServer", + SpanAttributes.RPC_SYSTEM: "grpc", + SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.OK.value[ + 0 + ], + }, + ) + + async def test_stream_stream(self): + async for response in bidirectional_streaming_method(self._stub): + self.assertEqual(response.response_data, "data") + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + span = spans[0] + + self.assertEqual( + span.name, "/GRPCTestServer/BidirectionalStreamingMethod" + ) + self.assertIs(span.kind, trace.SpanKind.CLIENT) + + # Check version and name in span's instrumentation info + self.assertEqualSpanInstrumentationInfo( + span, opentelemetry.instrumentation.grpc + ) + + self.assertSpanHasAttributes( + span, + { + SpanAttributes.RPC_METHOD: "BidirectionalStreamingMethod", + SpanAttributes.RPC_SERVICE: "GRPCTestServer", + SpanAttributes.RPC_SYSTEM: "grpc", + SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.OK.value[ + 0 + ], + }, + ) + + async def test_error_simple(self): + with self.assertRaises(grpc.RpcError): + await simple_method(self._stub, error=True) + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + span = spans[0] + self.assertIs( + span.status.status_code, + trace.StatusCode.ERROR, + ) + + async def test_error_unary_stream(self): + with self.assertRaises(grpc.RpcError): + async for _ in server_streaming_method(self._stub, error=True): + pass + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + span = spans[0] + self.assertIs( + span.status.status_code, + trace.StatusCode.ERROR, + ) + + async def test_error_stream_unary(self): + with self.assertRaises(grpc.RpcError): + await client_streaming_method(self._stub, error=True) + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + span = spans[0] + self.assertIs( + span.status.status_code, + trace.StatusCode.ERROR, + ) + + async def test_error_stream_stream(self): + with self.assertRaises(grpc.RpcError): + async for _ in bidirectional_streaming_method( + self._stub, error=True + ): + pass + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + span = spans[0] + self.assertIs( + span.status.status_code, + trace.StatusCode.ERROR, + ) + + # pylint:disable=no-self-use + async def test_client_interceptor_trace_context_propagation(self): + """ensure that client interceptor correctly inject trace context into all outgoing requests.""" + + previous_propagator = get_global_textmap() + + try: + set_global_textmap(MockTextMapPropagator()) + + interceptor = UnaryUnaryAioClientInterceptor(trace.NoOpTracer()) + recording_interceptor = RecordingInterceptor() + interceptors = [interceptor, recording_interceptor] + + channel = grpc.aio.insecure_channel( + "localhost:25565", interceptors=interceptors + ) + + stub = test_server_pb2_grpc.GRPCTestServerStub(channel) + await simple_method(stub) + + metadata = recording_interceptor.recorded_details.metadata + assert len(metadata) == 2 + assert metadata[0][0] == "mock-traceid" + assert metadata[0][1] == "0" + assert metadata[1][0] == "mock-spanid" + assert metadata[1][1] == "0" + finally: + set_global_textmap(previous_propagator) + + async def test_unary_unary_with_suppress_key(self): + token = context.attach( + context.set_value(_SUPPRESS_INSTRUMENTATION_KEY, True) + ) + try: + response = await simple_method(self._stub) + assert response.response_data == "data" + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 0) + finally: + context.detach(token) + + async def test_unary_stream_with_suppress_key(self): + token = context.attach( + context.set_value(_SUPPRESS_INSTRUMENTATION_KEY, True) + ) + try: + async for response in server_streaming_method(self._stub): + self.assertEqual(response.response_data, "data") + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 0) + finally: + context.detach(token) + + async def test_stream_unary_with_suppress_key(self): + token = context.attach( + context.set_value(_SUPPRESS_INSTRUMENTATION_KEY, True) + ) + try: + response = await client_streaming_method(self._stub) + assert response.response_data == "data" + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 0) + finally: + context.detach(token) + + async def test_stream_stream_with_suppress_key(self): + token = context.attach( + context.set_value(_SUPPRESS_INSTRUMENTATION_KEY, True) + ) + try: + async for response in bidirectional_streaming_method(self._stub): + self.assertEqual(response.response_data, "data") + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 0) + finally: + context.detach(token) diff --git a/instrumentation/opentelemetry-instrumentation-grpc/tests/test_aio_client_interceptor_filter.py b/instrumentation/opentelemetry-instrumentation-grpc/tests/test_aio_client_interceptor_filter.py new file mode 100644 index 0000000000..b8c408c6cf --- /dev/null +++ b/instrumentation/opentelemetry-instrumentation-grpc/tests/test_aio_client_interceptor_filter.py @@ -0,0 +1,167 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +try: + from unittest import IsolatedAsyncioTestCase +except ImportError: + # unittest.IsolatedAsyncioTestCase was introduced in Python 3.8. It's use + # simplifies the following tests. Without it, the amount of test code + # increases significantly, with most of the additional code handling + # the asyncio set up. + from unittest import TestCase + + class IsolatedAsyncioTestCase(TestCase): + def run(self, result=None): + self.skipTest( + "This test requires Python 3.8 for unittest.IsolatedAsyncioTestCase" + ) + + +import os +from unittest import mock + +import grpc +import pytest + +from opentelemetry.instrumentation.grpc import ( + GrpcAioInstrumentorClient, + aio_client_interceptors, + filters, +) +from opentelemetry.test.test_base import TestBase + +from ._aio_client import ( + bidirectional_streaming_method, + client_streaming_method, + server_streaming_method, + simple_method, +) +from ._server import create_test_server +from .protobuf import test_server_pb2_grpc # pylint: disable=no-name-in-module + + +@pytest.mark.asyncio +class TestAioClientInterceptorFiltered(TestBase, IsolatedAsyncioTestCase): + def setUp(self): + super().setUp() + self.server = create_test_server(25565) + self.server.start() + + interceptors = aio_client_interceptors( + filter_=filters.method_name("NotSimpleMethod") + ) + self._channel = grpc.aio.insecure_channel( + "localhost:25565", interceptors=interceptors + ) + + self._stub = test_server_pb2_grpc.GRPCTestServerStub(self._channel) + + def tearDown(self): + super().tearDown() + self.server.stop(1000) + + async def asyncTearDown(self): + await self._channel.close() + + async def test_instrument_filtered(self): + instrumentor = GrpcAioInstrumentorClient( + filter_=filters.method_name("NotSimpleMethod") + ) + + try: + instrumentor.instrument() + + channel = grpc.aio.insecure_channel("localhost:25565") + stub = test_server_pb2_grpc.GRPCTestServerStub(channel) + + response = await simple_method(stub) + assert response.response_data == "data" + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 0) + finally: + instrumentor.uninstrument() + + async def test_instrument_filtered_env(self): + with mock.patch.dict( + os.environ, + { + "OTEL_PYTHON_GRPC_EXCLUDED_SERVICES": "GRPCMockServer,GRPCTestServer" + }, + ): + instrumentor = GrpcAioInstrumentorClient() + + try: + instrumentor.instrument() + + channel = grpc.aio.insecure_channel("localhost:25565") + stub = test_server_pb2_grpc.GRPCTestServerStub(channel) + + response = await simple_method(stub) + assert response.response_data == "data" + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 0) + finally: + instrumentor.uninstrument() + + async def test_instrument_filtered_env_and_option(self): + with mock.patch.dict( + os.environ, + {"OTEL_PYTHON_GRPC_EXCLUDED_SERVICES": "GRPCMockServer"}, + ): + instrumentor = GrpcAioInstrumentorClient( + filter_=filters.service_prefix("GRPCTestServer") + ) + + try: + instrumentor.instrument() + + channel = grpc.aio.insecure_channel("localhost:25565") + stub = test_server_pb2_grpc.GRPCTestServerStub(channel) + + response = await simple_method(stub) + assert response.response_data == "data" + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + finally: + instrumentor.uninstrument() + + async def test_unary_unary_filtered(self): + response = await simple_method(self._stub) + assert response.response_data == "data" + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 0) + + async def test_unary_stream_filtered(self): + async for response in server_streaming_method(self._stub): + self.assertEqual(response.response_data, "data") + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 0) + + async def test_stream_unary_filtered(self): + response = await client_streaming_method(self._stub) + assert response.response_data == "data" + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 0) + + async def test_stream_stream_filtered(self): + async for response in bidirectional_streaming_method(self._stub): + self.assertEqual(response.response_data, "data") + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 0) diff --git a/instrumentation/opentelemetry-instrumentation-grpc/tests/test_aio_server_interceptor.py b/instrumentation/opentelemetry-instrumentation-grpc/tests/test_aio_server_interceptor.py new file mode 100644 index 0000000000..a4075fe727 --- /dev/null +++ b/instrumentation/opentelemetry-instrumentation-grpc/tests/test_aio_server_interceptor.py @@ -0,0 +1,574 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio + +try: + from unittest import IsolatedAsyncioTestCase +except ImportError: + # unittest.IsolatedAsyncioTestCase was introduced in Python 3.8. It's use + # simplifies the following tests. Without it, the amount of test code + # increases significantly, with most of the additional code handling + # the asyncio set up. + from unittest import TestCase + + class IsolatedAsyncioTestCase(TestCase): + def run(self, result=None): + self.skipTest( + "This test requires Python 3.8 for unittest.IsolatedAsyncioTestCase" + ) + + +import grpc +import grpc.aio +import pytest + +import opentelemetry.instrumentation.grpc +from opentelemetry import trace +from opentelemetry.instrumentation.grpc import ( + GrpcAioInstrumentorServer, + aio_server_interceptor, +) +from opentelemetry.sdk import trace as trace_sdk +from opentelemetry.semconv.trace import SpanAttributes +from opentelemetry.test.test_base import TestBase +from opentelemetry.trace import StatusCode + +from .protobuf.test_server_pb2 import Request, Response +from .protobuf.test_server_pb2_grpc import ( + GRPCTestServerServicer, + add_GRPCTestServerServicer_to_server, +) + +# pylint:disable=unused-argument +# pylint:disable=no-self-use + + +class Servicer(GRPCTestServerServicer): + """Our test servicer""" + + # pylint:disable=C0103 + async def SimpleMethod(self, request, context): + return Response( + server_id=request.client_id, + response_data=request.request_data, + ) + + # pylint:disable=C0103 + async def ServerStreamingMethod(self, request, context): + for data in ("one", "two", "three"): + yield Response( + server_id=request.client_id, + response_data=data, + ) + + +async def run_with_test_server( + runnable, servicer=Servicer(), add_interceptor=True +): + if add_interceptor: + interceptors = [aio_server_interceptor()] + server = grpc.aio.server(interceptors=interceptors) + else: + server = grpc.aio.server() + + add_GRPCTestServerServicer_to_server(servicer, server) + + port = server.add_insecure_port("[::]:0") + channel = grpc.aio.insecure_channel(f"localhost:{port:d}") + + await server.start() + resp = await runnable(channel) + await server.stop(1000) + + return resp + + +@pytest.mark.asyncio +class TestOpenTelemetryAioServerInterceptor(TestBase, IsolatedAsyncioTestCase): + async def test_instrumentor(self): + """Check that automatic instrumentation configures the interceptor""" + rpc_call = "/GRPCTestServer/SimpleMethod" + + grpc_aio_server_instrumentor = GrpcAioInstrumentorServer() + try: + grpc_aio_server_instrumentor.instrument() + + async def request(channel): + request = Request(client_id=1, request_data="test") + msg = request.SerializeToString() + return await channel.unary_unary(rpc_call)(msg) + + await run_with_test_server(request, add_interceptor=False) + + spans_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans_list), 1) + span = spans_list[0] + + self.assertEqual(span.name, rpc_call) + self.assertIs(span.kind, trace.SpanKind.SERVER) + + # Check version and name in span's instrumentation info + self.assertEqualSpanInstrumentationInfo( + span, opentelemetry.instrumentation.grpc + ) + + # Check attributes + self.assertSpanHasAttributes( + span, + { + SpanAttributes.NET_PEER_IP: "[::1]", + SpanAttributes.NET_PEER_NAME: "localhost", + SpanAttributes.RPC_METHOD: "SimpleMethod", + SpanAttributes.RPC_SERVICE: "GRPCTestServer", + SpanAttributes.RPC_SYSTEM: "grpc", + SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.OK.value[ + 0 + ], + }, + ) + + finally: + grpc_aio_server_instrumentor.uninstrument() + + async def test_uninstrument(self): + """Check that uninstrument removes the interceptor""" + rpc_call = "/GRPCTestServer/SimpleMethod" + + grpc_aio_server_instrumentor = GrpcAioInstrumentorServer() + grpc_aio_server_instrumentor.instrument() + grpc_aio_server_instrumentor.uninstrument() + + async def request(channel): + request = Request(client_id=1, request_data="test") + msg = request.SerializeToString() + return await channel.unary_unary(rpc_call)(msg) + + await run_with_test_server(request, add_interceptor=False) + + spans_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans_list), 0) + + async def test_create_span(self): + """Check that the interceptor wraps calls with spans server-side.""" + rpc_call = "/GRPCTestServer/SimpleMethod" + + async def request(channel): + request = Request(client_id=1, request_data="test") + msg = request.SerializeToString() + return await channel.unary_unary(rpc_call)(msg) + + await run_with_test_server(request) + + spans_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans_list), 1) + span = spans_list[0] + + self.assertEqual(span.name, rpc_call) + self.assertIs(span.kind, trace.SpanKind.SERVER) + + # Check version and name in span's instrumentation info + self.assertEqualSpanInstrumentationInfo( + span, opentelemetry.instrumentation.grpc + ) + + # Check attributes + self.assertSpanHasAttributes( + span, + { + SpanAttributes.NET_PEER_IP: "[::1]", + SpanAttributes.NET_PEER_NAME: "localhost", + SpanAttributes.RPC_METHOD: "SimpleMethod", + SpanAttributes.RPC_SERVICE: "GRPCTestServer", + SpanAttributes.RPC_SYSTEM: "grpc", + SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.OK.value[ + 0 + ], + }, + ) + + async def test_create_two_spans(self): + """Verify that the interceptor captures sub spans within the given + trace""" + rpc_call = "/GRPCTestServer/SimpleMethod" + + class TwoSpanServicer(GRPCTestServerServicer): + # pylint:disable=C0103 + async def SimpleMethod(self, request, context): + + # create another span + tracer = trace.get_tracer(__name__) + with tracer.start_as_current_span("child") as child: + child.add_event("child event") + + return Response( + server_id=request.client_id, + response_data=request.request_data, + ) + + async def request(channel): + request = Request(client_id=1, request_data="test") + msg = request.SerializeToString() + return await channel.unary_unary(rpc_call)(msg) + + await run_with_test_server(request, servicer=TwoSpanServicer()) + + spans_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans_list), 2) + child_span = spans_list[0] + parent_span = spans_list[1] + + self.assertEqual(parent_span.name, rpc_call) + self.assertIs(parent_span.kind, trace.SpanKind.SERVER) + + # Check version and name in span's instrumentation info + self.assertEqualSpanInstrumentationInfo( + parent_span, opentelemetry.instrumentation.grpc + ) + + # Check attributes + self.assertSpanHasAttributes( + parent_span, + { + SpanAttributes.NET_PEER_IP: "[::1]", + SpanAttributes.NET_PEER_NAME: "localhost", + SpanAttributes.RPC_METHOD: "SimpleMethod", + SpanAttributes.RPC_SERVICE: "GRPCTestServer", + SpanAttributes.RPC_SYSTEM: "grpc", + SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.OK.value[ + 0 + ], + }, + ) + + # Check the child span + self.assertEqual(child_span.name, "child") + self.assertEqual( + parent_span.context.trace_id, child_span.context.trace_id + ) + + async def test_create_span_streaming(self): + """Check that the interceptor wraps calls with spans server-side, on a + streaming call.""" + rpc_call = "/GRPCTestServer/ServerStreamingMethod" + + async def request(channel): + request = Request(client_id=1, request_data="test") + msg = request.SerializeToString() + async for response in channel.unary_stream(rpc_call)(msg): + print(response) + + await run_with_test_server(request) + + spans_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans_list), 1) + span = spans_list[0] + + self.assertEqual(span.name, rpc_call) + self.assertIs(span.kind, trace.SpanKind.SERVER) + + # Check version and name in span's instrumentation info + self.assertEqualSpanInstrumentationInfo( + span, opentelemetry.instrumentation.grpc + ) + + # Check attributes + self.assertSpanHasAttributes( + span, + { + SpanAttributes.NET_PEER_IP: "[::1]", + SpanAttributes.NET_PEER_NAME: "localhost", + SpanAttributes.RPC_METHOD: "ServerStreamingMethod", + SpanAttributes.RPC_SERVICE: "GRPCTestServer", + SpanAttributes.RPC_SYSTEM: "grpc", + SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.OK.value[ + 0 + ], + }, + ) + + async def test_create_two_spans_streaming(self): + """Verify that the interceptor captures sub spans within the given + trace""" + rpc_call = "/GRPCTestServer/ServerStreamingMethod" + + class TwoSpanServicer(GRPCTestServerServicer): + # pylint:disable=C0103 + async def ServerStreamingMethod(self, request, context): + # create another span + tracer = trace.get_tracer(__name__) + with tracer.start_as_current_span("child") as child: + child.add_event("child event") + + for data in ("one", "two", "three"): + yield Response( + server_id=request.client_id, + response_data=data, + ) + + async def request(channel): + request = Request(client_id=1, request_data="test") + msg = request.SerializeToString() + async for response in channel.unary_stream(rpc_call)(msg): + print(response) + + await run_with_test_server(request, servicer=TwoSpanServicer()) + + spans_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans_list), 2) + child_span = spans_list[0] + parent_span = spans_list[1] + + self.assertEqual(parent_span.name, rpc_call) + self.assertIs(parent_span.kind, trace.SpanKind.SERVER) + + # Check version and name in span's instrumentation info + self.assertEqualSpanInstrumentationInfo( + parent_span, opentelemetry.instrumentation.grpc + ) + + # Check attributes + self.assertSpanHasAttributes( + parent_span, + { + SpanAttributes.NET_PEER_IP: "[::1]", + SpanAttributes.NET_PEER_NAME: "localhost", + SpanAttributes.RPC_METHOD: "ServerStreamingMethod", + SpanAttributes.RPC_SERVICE: "GRPCTestServer", + SpanAttributes.RPC_SYSTEM: "grpc", + SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.OK.value[ + 0 + ], + }, + ) + + # Check the child span + self.assertEqual(child_span.name, "child") + self.assertEqual( + parent_span.context.trace_id, child_span.context.trace_id + ) + + async def test_span_lifetime(self): + """Verify that the interceptor captures sub spans within the given + trace""" + rpc_call = "/GRPCTestServer/SimpleMethod" + + class SpanLifetimeServicer(GRPCTestServerServicer): + # pylint:disable=C0103 + async def SimpleMethod(self, request, context): + # pylint:disable=attribute-defined-outside-init + self.span = trace.get_current_span() + + return Response( + server_id=request.client_id, + response_data=request.request_data, + ) + + async def request(channel): + request = Request(client_id=1, request_data="test") + msg = request.SerializeToString() + return await channel.unary_unary(rpc_call)(msg) + + lifetime_servicer = SpanLifetimeServicer() + active_span_before_call = trace.get_current_span() + + await run_with_test_server(request, servicer=lifetime_servicer) + + active_span_in_handler = lifetime_servicer.span + active_span_after_call = trace.get_current_span() + + self.assertEqual(active_span_before_call, trace.INVALID_SPAN) + self.assertEqual(active_span_after_call, trace.INVALID_SPAN) + self.assertIsInstance(active_span_in_handler, trace_sdk.Span) + self.assertIsNone(active_span_in_handler.parent) + + async def test_sequential_server_spans(self): + """Check that sequential RPCs get separate server spans.""" + rpc_call = "/GRPCTestServer/SimpleMethod" + + async def request(channel): + request = Request(client_id=1, request_data="test") + msg = request.SerializeToString() + return await channel.unary_unary(rpc_call)(msg) + + async def sequential_requests(channel): + await request(channel) + await request(channel) + + await run_with_test_server(sequential_requests) + + spans_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans_list), 2) + + span1 = spans_list[0] + span2 = spans_list[1] + + # Spans should belong to separate traces + self.assertNotEqual(span1.context.span_id, span2.context.span_id) + self.assertNotEqual(span1.context.trace_id, span2.context.trace_id) + + for span in (span1, span2): + # each should be a root span + self.assertIsNone(span2.parent) + + # check attributes + self.assertSpanHasAttributes( + span, + { + SpanAttributes.NET_PEER_IP: "[::1]", + SpanAttributes.NET_PEER_NAME: "localhost", + SpanAttributes.RPC_METHOD: "SimpleMethod", + SpanAttributes.RPC_SERVICE: "GRPCTestServer", + SpanAttributes.RPC_SYSTEM: "grpc", + SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.OK.value[ + 0 + ], + }, + ) + + async def test_concurrent_server_spans(self): + """Check that concurrent RPC calls don't interfere with each other. + + This is the same check as test_sequential_server_spans except that the + RPCs are concurrent. Two handlers are invoked at the same time on two + separate threads. Each one should see a different active span and + context. + """ + rpc_call = "/GRPCTestServer/SimpleMethod" + latch = get_latch(2) + + class LatchedServicer(GRPCTestServerServicer): + # pylint:disable=C0103 + async def SimpleMethod(self, request, context): + await latch() + return Response( + server_id=request.client_id, + response_data=request.request_data, + ) + + async def request(channel): + request = Request(client_id=1, request_data="test") + msg = request.SerializeToString() + return await channel.unary_unary(rpc_call)(msg) + + async def concurrent_requests(channel): + await asyncio.gather(request(channel), request(channel)) + + await run_with_test_server( + concurrent_requests, servicer=LatchedServicer() + ) + + spans_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans_list), 2) + + span1 = spans_list[0] + span2 = spans_list[1] + + # Spans should belong to separate traces + self.assertNotEqual(span1.context.span_id, span2.context.span_id) + self.assertNotEqual(span1.context.trace_id, span2.context.trace_id) + + for span in (span1, span2): + # each should be a root span + self.assertIsNone(span2.parent) + + # check attributes + self.assertSpanHasAttributes( + span, + { + SpanAttributes.NET_PEER_IP: "[::1]", + SpanAttributes.NET_PEER_NAME: "localhost", + SpanAttributes.RPC_METHOD: "SimpleMethod", + SpanAttributes.RPC_SERVICE: "GRPCTestServer", + SpanAttributes.RPC_SYSTEM: "grpc", + SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.OK.value[ + 0 + ], + }, + ) + + async def test_abort(self): + """Check that we can catch an abort properly""" + rpc_call = "/GRPCTestServer/SimpleMethod" + failure_message = "failure message" + + class AbortServicer(GRPCTestServerServicer): + # pylint:disable=C0103 + async def SimpleMethod(self, request, context): + await context.abort( + grpc.StatusCode.FAILED_PRECONDITION, failure_message + ) + + testcase = self + + async def request(channel): + request = Request(client_id=1, request_data=failure_message) + msg = request.SerializeToString() + + with testcase.assertRaises(Exception): + await channel.unary_unary(rpc_call)(msg) + + await run_with_test_server(request, servicer=AbortServicer()) + + spans_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans_list), 1) + span = spans_list[0] + + self.assertEqual(span.name, rpc_call) + self.assertIs(span.kind, trace.SpanKind.SERVER) + + # Check version and name in span's instrumentation info + self.assertEqualSpanInstrumentationInfo( + span, opentelemetry.instrumentation.grpc + ) + + # make sure this span errored, with the right status and detail + self.assertEqual(span.status.status_code, StatusCode.ERROR) + self.assertEqual( + span.status.description, + f"{grpc.StatusCode.FAILED_PRECONDITION}:{failure_message}", + ) + + # Check attributes + self.assertSpanHasAttributes( + span, + { + SpanAttributes.NET_PEER_IP: "[::1]", + SpanAttributes.NET_PEER_NAME: "localhost", + SpanAttributes.RPC_METHOD: "SimpleMethod", + SpanAttributes.RPC_SERVICE: "GRPCTestServer", + SpanAttributes.RPC_SYSTEM: "grpc", + SpanAttributes.RPC_GRPC_STATUS_CODE: grpc.StatusCode.FAILED_PRECONDITION.value[ + 0 + ], + }, + ) + + +def get_latch(num): + """Get a countdown latch function for use in n threads.""" + cv = asyncio.Condition() + count = 0 + + async def countdown_latch(): + """Block until n-1 other threads have called.""" + nonlocal count + async with cv: + count += 1 + cv.notify() + + async with cv: + while count < num: + await cv.wait() + + return countdown_latch diff --git a/instrumentation/opentelemetry-instrumentation-grpc/tests/test_aio_server_interceptor_filter.py b/instrumentation/opentelemetry-instrumentation-grpc/tests/test_aio_server_interceptor_filter.py new file mode 100644 index 0000000000..837d9c7618 --- /dev/null +++ b/instrumentation/opentelemetry-instrumentation-grpc/tests/test_aio_server_interceptor_filter.py @@ -0,0 +1,135 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +try: + from unittest import IsolatedAsyncioTestCase +except ImportError: + # unittest.IsolatedAsyncioTestCase was introduced in Python 3.8. It's use + # simplifies the following tests. Without it, the amount of test code + # increases significantly, with most of the additional code handling + # the asyncio set up. + from unittest import TestCase + + class IsolatedAsyncioTestCase(TestCase): + def run(self, result=None): + self.skipTest( + "This test requires Python 3.8 for unittest.IsolatedAsyncioTestCase" + ) + + +import grpc +import grpc.aio +import pytest + +from opentelemetry import trace +from opentelemetry.instrumentation.grpc import ( + GrpcAioInstrumentorServer, + aio_server_interceptor, + filters, +) +from opentelemetry.test.test_base import TestBase + +from .protobuf.test_server_pb2 import Request +from .protobuf.test_server_pb2_grpc import add_GRPCTestServerServicer_to_server +from .test_aio_server_interceptor import Servicer + +# pylint:disable=unused-argument +# pylint:disable=no-self-use + + +async def run_with_test_server( + runnable, filter_=None, servicer=Servicer(), add_interceptor=True +): + if add_interceptor: + interceptors = [aio_server_interceptor(filter_=filter_)] + server = grpc.aio.server(interceptors=interceptors) + else: + server = grpc.aio.server() + + add_GRPCTestServerServicer_to_server(servicer, server) + + port = server.add_insecure_port("[::]:0") + channel = grpc.aio.insecure_channel(f"localhost:{port:d}") + + await server.start() + resp = await runnable(channel) + await server.stop(1000) + + return resp + + +@pytest.mark.asyncio +class TestOpenTelemetryAioServerInterceptor(TestBase, IsolatedAsyncioTestCase): + async def test_instrumentor(self): + """Check that automatic instrumentation configures the interceptor""" + rpc_call = "/GRPCTestServer/SimpleMethod" + + grpc_aio_server_instrumentor = GrpcAioInstrumentorServer( + filter_=filters.method_name("NotSimpleMethod") + ) + try: + grpc_aio_server_instrumentor.instrument() + + async def request(channel): + request = Request(client_id=1, request_data="test") + msg = request.SerializeToString() + return await channel.unary_unary(rpc_call)(msg) + + await run_with_test_server(request, add_interceptor=False) + + spans_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans_list), 0) + + finally: + grpc_aio_server_instrumentor.uninstrument() + + async def test_create_span(self): + """ + Check that the interceptor wraps calls with spans server-side when filter + passed and RPC matches the filter. + """ + rpc_call = "/GRPCTestServer/SimpleMethod" + + async def request(channel): + request = Request(client_id=1, request_data="test") + msg = request.SerializeToString() + return await channel.unary_unary(rpc_call)(msg) + + await run_with_test_server( + request, + filter_=filters.method_name("SimpleMethod"), + ) + + spans_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans_list), 1) + span = spans_list[0] + + self.assertEqual(span.name, rpc_call) + self.assertIs(span.kind, trace.SpanKind.SERVER) + + async def test_create_span_filtered(self): + """Check that the interceptor wraps calls with spans server-side.""" + rpc_call = "/GRPCTestServer/SimpleMethod" + + async def request(channel): + request = Request(client_id=1, request_data="test") + msg = request.SerializeToString() + return await channel.unary_unary(rpc_call)(msg) + + await run_with_test_server( + request, + filter_=filters.method_name("NotSimpleMethod"), + ) + + spans_list = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans_list), 0) diff --git a/instrumentation/opentelemetry-instrumentation-grpc/tests/test_filters.py b/instrumentation/opentelemetry-instrumentation-grpc/tests/test_filters.py index f7d69074ac..81cc689edd 100644 --- a/instrumentation/opentelemetry-instrumentation-grpc/tests/test_filters.py +++ b/instrumentation/opentelemetry-instrumentation-grpc/tests/test_filters.py @@ -59,6 +59,39 @@ class _StreamClientInfo( invocation_metadata=[("tracer", "foo"), ("caller", "bar")], ), ), + ( + True, + "SimpleMethod", + grpc.aio.ClientCallDetails( + method="SimpleMethod", + timeout=3000, + metadata=None, + credentials=None, + wait_for_ready=None, + ), + ), + ( + True, + "SimpleMethod", + grpc.aio.ClientCallDetails( + method=b"SimpleMethod", + timeout=3000, + metadata=None, + credentials=None, + wait_for_ready=None, + ), + ), + ( + False, + "SimpleMethod", + grpc.aio.ClientCallDetails( + method="NotSimpleMethod", + timeout=3000, + metadata=None, + credentials=None, + wait_for_ready=None, + ), + ), ( False, "SimpleMethod", diff --git a/tox.ini b/tox.ini index 6b39f20728..d1a7da6f8e 100644 --- a/tox.ini +++ b/tox.ini @@ -233,6 +233,7 @@ deps = falcon1: falcon ==1.4.1 falcon2: falcon >=2.0.0,<3.0.0 falcon3: falcon >=3.0.0,<4.0.0 + grpc: pytest-asyncio sqlalchemy11: sqlalchemy>=1.1,<1.2 sqlalchemy14: aiosqlite sqlalchemy14: sqlalchemy~=1.4