diff --git a/CHANGELOG.md b/CHANGELOG.md index 1d042f845b..dc00349400 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -61,7 +61,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ([#436](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/436)) - `opentelemetry-instrumenation-flask` now supports trace response headers. ([#436](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/436)) - +- `opentelemetry-instrumentation-grpc` Keep client interceptor in sync with grpc client interceptors. + ([#442](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/442)) + ### Removed - Remove `http.status_text` from span attributes ([#406](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/406)) diff --git a/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/grpcext/_interceptor.py b/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/grpcext/_interceptor.py index 89889aceeb..3e3916fd41 100644 --- a/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/grpcext/_interceptor.py +++ b/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/grpcext/_interceptor.py @@ -13,9 +13,7 @@ # limitations under the License. # pylint:disable=relative-beyond-top-level -# pylint:disable=arguments-differ # pylint:disable=no-member -# pylint:disable=signature-differs """Implementation of gRPC Python interceptors.""" @@ -41,6 +39,11 @@ class _StreamClientInfo( ): pass +def _get_metadata_timeout(**kwargs): + metadata = kwargs.get("metadata") + timeout = kwargs.get("timeout") + return metadata, timeout + class _InterceptorUnaryUnaryMultiCallable(grpc.UnaryUnaryMultiCallable): def __init__(self, method, base_callable, interceptor): @@ -48,34 +51,55 @@ def __init__(self, method, base_callable, interceptor): self._base_callable = base_callable self._interceptor = interceptor - def __call__(self, request, timeout=None, metadata=None, credentials=None): + def __call__( + self, + request, + **kwargs, + ): def invoker(request, metadata): - return self._base_callable(request, timeout, metadata, credentials) + kwargs["metadata"] = metadata + return self._base_callable( + request, + **kwargs, + ) + metadata, timeout = _get_metadata_timeout(**kwargs) client_info = _UnaryClientInfo(self._method, timeout) return self._interceptor.intercept_unary( request, metadata, client_info, invoker ) def with_call( - self, request, timeout=None, metadata=None, credentials=None + self, + request, + **kwargs, ): def invoker(request, metadata): + kwargs["metadata"] = metadata return self._base_callable.with_call( - request, timeout, metadata, credentials + request, + **kwargs, ) + metadata, timeout = _get_metadata_timeout(**kwargs) client_info = _UnaryClientInfo(self._method, timeout) return self._interceptor.intercept_unary( request, metadata, client_info, invoker ) - def future(self, request, timeout=None, metadata=None, credentials=None): + def future( + self, + request, + **kwargs, + ): def invoker(request, metadata): + kwargs["metadata"] = metadata return self._base_callable.future( - request, timeout, metadata, credentials + request, + **kwargs, ) + metadata, timeout = _get_metadata_timeout(**kwargs) client_info = _UnaryClientInfo(self._method, timeout) return self._interceptor.intercept_unary( request, metadata, client_info, invoker @@ -88,10 +112,19 @@ def __init__(self, method, base_callable, interceptor): self._base_callable = base_callable self._interceptor = interceptor - def __call__(self, request, timeout=None, metadata=None, credentials=None): + def __call__( + self, + request, + **kwargs, + ): def invoker(request, metadata): - return self._base_callable(request, timeout, metadata, credentials) + kwargs["metadata"] = metadata + return self._base_callable( + request, + **kwargs, + ) + metadata, timeout = _get_metadata_timeout(**kwargs) client_info = _StreamClientInfo(self._method, False, True, timeout) return self._interceptor.intercept_stream( request, metadata, client_info, invoker @@ -105,39 +138,54 @@ def __init__(self, method, base_callable, interceptor): self._interceptor = interceptor def __call__( - self, request_iterator, timeout=None, metadata=None, credentials=None + self, + request_iterator, + **kwargs, ): def invoker(request_iterator, metadata): + kwargs["metadata"] = metadata return self._base_callable( - request_iterator, timeout, metadata, credentials + request_iterator, + **kwargs, ) + metadata, timeout = _get_metadata_timeout(**kwargs) client_info = _StreamClientInfo(self._method, True, False, timeout) return self._interceptor.intercept_stream( request_iterator, metadata, client_info, invoker ) def with_call( - self, request_iterator, timeout=None, metadata=None, credentials=None + self, + request_iterator, + **kwargs, ): def invoker(request_iterator, metadata): + kwargs["metadata"] = metadata return self._base_callable.with_call( - request_iterator, timeout, metadata, credentials + request_iterator, + **kwargs ) + metadata, timeout = _get_metadata_timeout(**kwargs) client_info = _StreamClientInfo(self._method, True, False, timeout) return self._interceptor.intercept_stream( request_iterator, metadata, client_info, invoker ) def future( - self, request_iterator, timeout=None, metadata=None, credentials=None + self, + request_iterator, + **kwargs ): def invoker(request_iterator, metadata): + kwargs["metadata"] = metadata return self._base_callable.future( - request_iterator, timeout, metadata, credentials + request_iterator, + **kwargs, ) + metadata, timeout = _get_metadata_timeout(**kwargs) client_info = _StreamClientInfo(self._method, True, False, timeout) return self._interceptor.intercept_stream( request_iterator, metadata, client_info, invoker @@ -151,13 +199,18 @@ def __init__(self, method, base_callable, interceptor): self._interceptor = interceptor def __call__( - self, request_iterator, timeout=None, metadata=None, credentials=None + self, + request_iterator, + **kwargs, ): def invoker(request_iterator, metadata): + kwargs["metadata"] = metadata return self._base_callable( - request_iterator, timeout, metadata, credentials + request_iterator, + **kwargs, ) + metadata, timeout = _get_metadata_timeout(**kwargs) client_info = _StreamClientInfo(self._method, True, True, timeout) return self._interceptor.intercept_stream( request_iterator, metadata, client_info, invoker diff --git a/instrumentation/opentelemetry-instrumentation-grpc/tests/test_client_interceptor.py b/instrumentation/opentelemetry-instrumentation-grpc/tests/test_client_interceptor.py index f088f5cf8c..2b06fb909d 100644 --- a/instrumentation/opentelemetry-instrumentation-grpc/tests/test_client_interceptor.py +++ b/instrumentation/opentelemetry-instrumentation-grpc/tests/test_client_interceptor.py @@ -40,6 +40,46 @@ from ._server import create_test_server from .protobuf.test_server_pb2 import Request +# User defined interceptor. Is used in the tests along with the opentelemetry client interceptor. +class Interceptor( + grpc.UnaryUnaryClientInterceptor, + grpc.UnaryStreamClientInterceptor, + grpc.StreamUnaryClientInterceptor, + grpc.StreamStreamClientInterceptor, +): + def __init__(self): + pass + + def intercept_unary_unary( + self, continuation, client_call_details, request + ): + return self._intercept_call(continuation, client_call_details, request) + + def intercept_unary_stream( + self, continuation, client_call_details, request + ): + return self._intercept_call(continuation, client_call_details, request) + + def intercept_stream_unary( + self, continuation, client_call_details, request_iterator + ): + return self._intercept_call( + continuation, client_call_details, request_iterator + ) + + def intercept_stream_stream( + self, continuation, client_call_details, request_iterator + ): + return self._intercept_call( + continuation, client_call_details, request_iterator + ) + + @staticmethod + def _intercept_call( + continuation, client_call_details, request_or_iterator + ): + return continuation(client_call_details, request_or_iterator) + class TestClientProto(TestBase): def setUp(self): @@ -47,7 +87,10 @@ def setUp(self): GrpcInstrumentorClient().instrument() self.server = create_test_server(25565) self.server.start() + # use a user defined interceptor along with the opentelemetry client interceptor + interceptors = [Interceptor()] self.channel = grpc.insecure_channel("localhost:25565") + self.channel = grpc.intercept_channel(self.channel, *interceptors) self._stub = test_server_pb2_grpc.GRPCTestServerStub(self.channel) def tearDown(self):