diff --git a/google/api_core/grpc_helpers.py b/google/api_core/grpc_helpers.py index f52e180a..793c884d 100644 --- a/google/api_core/grpc_helpers.py +++ b/google/api_core/grpc_helpers.py @@ -13,6 +13,7 @@ # limitations under the License. """Helpers for :mod:`grpc`.""" +from typing import Generic, TypeVar, Iterator import collections import functools @@ -54,6 +55,9 @@ _LOGGER = logging.getLogger(__name__) +# denotes the proto response type for grpc calls +P = TypeVar("P") + def _patch_callable_name(callable_): """Fix-up gRPC callable attributes. @@ -79,7 +83,7 @@ def error_remapped_callable(*args, **kwargs): return error_remapped_callable -class _StreamingResponseIterator(grpc.Call): +class _StreamingResponseIterator(Generic[P], grpc.Call): def __init__(self, wrapped, prefetch_first_result=True): self._wrapped = wrapped @@ -97,11 +101,11 @@ def __init__(self, wrapped, prefetch_first_result=True): # ignore stop iteration at this time. This should be handled outside of retry. pass - def __iter__(self): + def __iter__(self) -> Iterator[P]: """This iterator is also an iterable that returns itself.""" return self - def __next__(self): + def __next__(self) -> P: """Get the next response from the stream. Returns: @@ -144,6 +148,10 @@ def trailing_metadata(self): return self._wrapped.trailing_metadata() +# public type alias denoting the return type of streaming gapic calls +GrpcStream = _StreamingResponseIterator[P] + + def _wrap_stream_errors(callable_): """Wrap errors for Unary-Stream and Stream-Stream gRPC callables. diff --git a/google/api_core/grpc_helpers_async.py b/google/api_core/grpc_helpers_async.py index d1f69d98..5685e6f8 100644 --- a/google/api_core/grpc_helpers_async.py +++ b/google/api_core/grpc_helpers_async.py @@ -21,11 +21,15 @@ import asyncio import functools +from typing import Generic, Iterator, AsyncGenerator, TypeVar + import grpc from grpc import aio from google.api_core import exceptions, grpc_helpers +# denotes the proto response type for grpc calls +P = TypeVar("P") # NOTE(lidiz) Alternatively, we can hack "__getattribute__" to perform # automatic patching for us. But that means the overhead of creating an @@ -75,8 +79,8 @@ async def wait_for_connection(self): raise exceptions.from_grpc_error(rpc_error) from rpc_error -class _WrappedUnaryResponseMixin(_WrappedCall): - def __await__(self): +class _WrappedUnaryResponseMixin(Generic[P], _WrappedCall): + def __await__(self) -> Iterator[P]: try: response = yield from self._call.__await__() return response @@ -84,17 +88,17 @@ def __await__(self): raise exceptions.from_grpc_error(rpc_error) from rpc_error -class _WrappedStreamResponseMixin(_WrappedCall): +class _WrappedStreamResponseMixin(Generic[P], _WrappedCall): def __init__(self): self._wrapped_async_generator = None - async def read(self): + async def read(self) -> P: try: return await self._call.read() except grpc.RpcError as rpc_error: raise exceptions.from_grpc_error(rpc_error) from rpc_error - async def _wrapped_aiter(self): + async def _wrapped_aiter(self) -> AsyncGenerator[P, None]: try: # NOTE(lidiz) coverage doesn't understand the exception raised from # __anext__ method. It is covered by test case: @@ -104,7 +108,7 @@ async def _wrapped_aiter(self): except grpc.RpcError as rpc_error: raise exceptions.from_grpc_error(rpc_error) from rpc_error - def __aiter__(self): + def __aiter__(self) -> AsyncGenerator[P, None]: if not self._wrapped_async_generator: self._wrapped_async_generator = self._wrapped_aiter() return self._wrapped_async_generator @@ -127,26 +131,32 @@ async def done_writing(self): # NOTE(lidiz) Implementing each individual class separately, so we don't # expose any API that should not be seen. E.g., __aiter__ in unary-unary # RPC, or __await__ in stream-stream RPC. -class _WrappedUnaryUnaryCall(_WrappedUnaryResponseMixin, aio.UnaryUnaryCall): +class _WrappedUnaryUnaryCall(_WrappedUnaryResponseMixin[P], aio.UnaryUnaryCall): """Wrapped UnaryUnaryCall to map exceptions.""" -class _WrappedUnaryStreamCall(_WrappedStreamResponseMixin, aio.UnaryStreamCall): +class _WrappedUnaryStreamCall(_WrappedStreamResponseMixin[P], aio.UnaryStreamCall): """Wrapped UnaryStreamCall to map exceptions.""" class _WrappedStreamUnaryCall( - _WrappedUnaryResponseMixin, _WrappedStreamRequestMixin, aio.StreamUnaryCall + _WrappedUnaryResponseMixin[P], _WrappedStreamRequestMixin, aio.StreamUnaryCall ): """Wrapped StreamUnaryCall to map exceptions.""" class _WrappedStreamStreamCall( - _WrappedStreamRequestMixin, _WrappedStreamResponseMixin, aio.StreamStreamCall + _WrappedStreamRequestMixin, _WrappedStreamResponseMixin[P], aio.StreamStreamCall ): """Wrapped StreamStreamCall to map exceptions.""" +# public type alias denoting the return type of async streaming gapic calls +GrpcAsyncStream = _WrappedStreamResponseMixin[P] +# public type alias denoting the return type of unary gapic calls +AwaitableGrpcCall = _WrappedUnaryResponseMixin[P] + + def _wrap_unary_errors(callable_): """Map errors for Unary-Unary async callables.""" grpc_helpers._patch_callable_name(callable_) diff --git a/tests/asyncio/test_grpc_helpers_async.py b/tests/asyncio/test_grpc_helpers_async.py index 95242f6b..67c9b335 100644 --- a/tests/asyncio/test_grpc_helpers_async.py +++ b/tests/asyncio/test_grpc_helpers_async.py @@ -266,6 +266,28 @@ def test_wrap_errors_non_streaming(wrap_unary_errors): wrap_unary_errors.assert_called_once_with(callable_) +def test_grpc_async_stream(): + """ + GrpcAsyncStream type should be both an AsyncIterator and a grpc.aio.Call. + """ + instance = grpc_helpers_async.GrpcAsyncStream[int]() + assert isinstance(instance, grpc.aio.Call) + # should implement __aiter__ and __anext__ + assert hasattr(instance, "__aiter__") + it = instance.__aiter__() + assert hasattr(it, "__anext__") + + +def test_awaitable_grpc_call(): + """ + AwaitableGrpcCall type should be an Awaitable and a grpc.aio.Call. + """ + instance = grpc_helpers_async.AwaitableGrpcCall[int]() + assert isinstance(instance, grpc.aio.Call) + # should implement __await__ + assert hasattr(instance, "__await__") + + @mock.patch("google.api_core.grpc_helpers_async._wrap_stream_errors") def test_wrap_errors_streaming(wrap_stream_errors): callable_ = mock.create_autospec(aio.UnaryStreamMultiCallable) diff --git a/tests/unit/test_grpc_helpers.py b/tests/unit/test_grpc_helpers.py index 4eccbcaa..58a6a329 100644 --- a/tests/unit/test_grpc_helpers.py +++ b/tests/unit/test_grpc_helpers.py @@ -195,6 +195,23 @@ def test_trailing_metadata(self): wrapped.trailing_metadata.assert_called_once_with() +class TestGrpcStream(Test_StreamingResponseIterator): + @staticmethod + def _make_one(wrapped, **kw): + return grpc_helpers.GrpcStream(wrapped, **kw) + + def test_grpc_stream_attributes(self): + """ + Should be both a grpc.Call and an iterable + """ + call = self._make_one(None) + assert isinstance(call, grpc.Call) + # should implement __iter__ + assert hasattr(call, "__iter__") + it = call.__iter__() + assert hasattr(it, "__next__") + + def test_wrap_stream_okay(): expected_responses = [1, 2, 3] callable_ = mock.Mock(spec=["__call__"], return_value=iter(expected_responses))