Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add type annotations to wrapped grpc calls #554

Merged
merged 16 commits into from
Nov 17, 2023
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions google/api_core/grpc_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

"""Helpers for :mod:`grpc`."""
from typing import Generic, TypeVar, Iterator

import collections
import functools
Expand Down Expand Up @@ -54,6 +55,9 @@

_LOGGER = logging.getLogger(__name__)

# denotes the type yielded from streaming calls
S = TypeVar("S")


def _patch_callable_name(callable_):
"""Fix-up gRPC callable attributes.
Expand All @@ -79,7 +83,7 @@ def error_remapped_callable(*args, **kwargs):
return error_remapped_callable


class _StreamingResponseIterator(grpc.Call):
class _StreamingResponseIterator(Generic[S], grpc.Call):
def __init__(self, wrapped, prefetch_first_result=True):
self._wrapped = wrapped

Expand All @@ -97,11 +101,11 @@ def __init__(self, wrapped, prefetch_first_result=True):
# ignore stop iteration at this time. This should be handled outside of retry.
pass

def __iter__(self):
def __iter__(self) -> Iterator[S]:
"""This iterator is also an iterable that returns itself."""
return self

def __next__(self):
def __next__(self) -> S:
"""Get the next response from the stream.

Returns:
Expand Down Expand Up @@ -144,6 +148,10 @@ def trailing_metadata(self):
return self._wrapped.trailing_metadata()


# public type alias denoting the return type of streaming gapic calls
GrpcStream = _StreamingResponseIterator[S]


def _wrap_stream_errors(callable_):
"""Wrap errors for Unary-Stream and Stream-Stream gRPC callables.

Expand Down
33 changes: 23 additions & 10 deletions google/api_core/grpc_helpers_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import asyncio
import functools

from typing import Generic, Iterator, AsyncGenerator, TypeVar

import grpc
from grpc import aio

Expand All @@ -31,6 +33,11 @@
# automatic patching for us. But that means the overhead of creating an
# extra Python function spreads to every single send and receive.

# denotes the type returned from unary calls
U = TypeVar("U")
# denotes the type yielded from streaming calls
S = TypeVar("S")


class _WrappedCall(aio.Call):
def __init__(self):
Expand Down Expand Up @@ -75,26 +82,26 @@ async def wait_for_connection(self):
raise exceptions.from_grpc_error(rpc_error) from rpc_error


class _WrappedUnaryResponseMixin(_WrappedCall):
def __await__(self):
class _WrappedUnaryResponseMixin(Generic[U], _WrappedCall):
def __await__(self) -> Iterator[U]:
try:
response = yield from self._call.__await__()
return response
except grpc.RpcError as rpc_error:
raise exceptions.from_grpc_error(rpc_error) from rpc_error


class _WrappedStreamResponseMixin(_WrappedCall):
class _WrappedStreamResponseMixin(Generic[S], _WrappedCall):
def __init__(self):
self._wrapped_async_generator = None

async def read(self):
async def read(self) -> S:
try:
return await self._call.read()
except grpc.RpcError as rpc_error:
raise exceptions.from_grpc_error(rpc_error) from rpc_error

async def _wrapped_aiter(self):
async def _wrapped_aiter(self) -> AsyncGenerator[S, None]:
try:
# NOTE(lidiz) coverage doesn't understand the exception raised from
# __anext__ method. It is covered by test case:
Expand All @@ -104,7 +111,7 @@ async def _wrapped_aiter(self):
except grpc.RpcError as rpc_error:
raise exceptions.from_grpc_error(rpc_error) from rpc_error

def __aiter__(self):
def __aiter__(self) -> AsyncGenerator[S, None]:
if not self._wrapped_async_generator:
self._wrapped_async_generator = self._wrapped_aiter()
return self._wrapped_async_generator
Expand All @@ -127,26 +134,32 @@ async def done_writing(self):
# NOTE(lidiz) Implementing each individual class separately, so we don't
# expose any API that should not be seen. E.g., __aiter__ in unary-unary
# RPC, or __await__ in stream-stream RPC.
class _WrappedUnaryUnaryCall(_WrappedUnaryResponseMixin, aio.UnaryUnaryCall):
class _WrappedUnaryUnaryCall(_WrappedUnaryResponseMixin[U], aio.UnaryUnaryCall):
"""Wrapped UnaryUnaryCall to map exceptions."""


class _WrappedUnaryStreamCall(_WrappedStreamResponseMixin, aio.UnaryStreamCall):
class _WrappedUnaryStreamCall(_WrappedStreamResponseMixin[S], aio.UnaryStreamCall):
"""Wrapped UnaryStreamCall to map exceptions."""


class _WrappedStreamUnaryCall(
_WrappedUnaryResponseMixin, _WrappedStreamRequestMixin, aio.StreamUnaryCall
_WrappedUnaryResponseMixin[S], _WrappedStreamRequestMixin, aio.StreamUnaryCall
):
"""Wrapped StreamUnaryCall to map exceptions."""


class _WrappedStreamStreamCall(
_WrappedStreamRequestMixin, _WrappedStreamResponseMixin, aio.StreamStreamCall
_WrappedStreamRequestMixin, _WrappedStreamResponseMixin[S], aio.StreamStreamCall
):
"""Wrapped StreamStreamCall to map exceptions."""


# public type alias denoting the return type of async streaming gapic calls
GrpcAsyncStream = _WrappedStreamResponseMixin[S]
# public type alias denoting the return type of unary gapic calls
AwaitableGrpcCall = _WrappedUnaryResponseMixin[U]
Copy link
Contributor Author

@daniel-sanche daniel-sanche Nov 17, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me know if you have other naming suggestions for these (AsyncGrpcCall? GrpcAsyncIterable?)

I liked Awaitable because it's clear how to interact with it, and Stream instead of Iterable because it can do more than just iterate. But names are hard and I'm open to alternatives



def _wrap_unary_errors(callable_):
"""Map errors for Unary-Unary async callables."""
grpc_helpers._patch_callable_name(callable_)
Expand Down
22 changes: 22 additions & 0 deletions tests/asyncio/test_grpc_helpers_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,28 @@ def test_wrap_errors_non_streaming(wrap_unary_errors):
wrap_unary_errors.assert_called_once_with(callable_)


def test_grpc_async_stream():
"""
GrpcAsyncStream type should be both an AsyncIterator and a grpc.aio.Call.
"""
instance = grpc_helpers_async.GrpcAsyncStream[int]()
assert isinstance(instance, grpc.aio.Call)
# should implement __aiter__ and __anext__
assert hasattr(instance, "__aiter__")
it = instance.__aiter__()
assert hasattr(it, "__anext__")


def test_awaitable_grpc_call():
"""
AwaitableGrpcCall type should be an Awaitable and a grpc.aio.Call.
"""
instance = grpc_helpers_async.AwaitableGrpcCall[int]()
assert isinstance(instance, grpc.aio.Call)
# should implement __await__
assert hasattr(instance, "__await__")


@mock.patch("google.api_core.grpc_helpers_async._wrap_stream_errors")
def test_wrap_errors_streaming(wrap_stream_errors):
callable_ = mock.create_autospec(aio.UnaryStreamMultiCallable)
Expand Down
17 changes: 17 additions & 0 deletions tests/unit/test_grpc_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,23 @@ def test_trailing_metadata(self):
wrapped.trailing_metadata.assert_called_once_with()


class TestGrpcStream(Test_StreamingResponseIterator):
@staticmethod
def _make_one(wrapped, **kw):
return grpc_helpers.GrpcStream(wrapped, **kw)

def test_grpc_stream_attributes(self):
"""
Should be both a grpc.Call and an iterable
"""
call = self._make_one(None)
assert isinstance(call, grpc.Call)
# should implement __iter__
assert hasattr(call, "__iter__")
it = call.__iter__()
assert hasattr(it, "__next__")


def test_wrap_stream_okay():
expected_responses = [1, 2, 3]
callable_ = mock.Mock(spec=["__call__"], return_value=iter(expected_responses))
Expand Down