diff --git a/datastore/google/cloud/datastore/connection.py b/datastore/google/cloud/datastore/connection.py index 74070b142355..aac5c85e0a88 100644 --- a/datastore/google/cloud/datastore/connection.py +++ b/datastore/google/cloud/datastore/connection.py @@ -14,6 +14,7 @@ """Connections to Google Cloud Datastore API servers.""" +import contextlib import os from google.rpc import status_pb2 @@ -23,19 +24,35 @@ from google.cloud import connection as connection_module from google.cloud.environment_vars import DISABLE_GRPC from google.cloud.environment_vars import GCD_HOST -from google.cloud.exceptions import BadRequest -from google.cloud.exceptions import Conflict -from google.cloud.exceptions import GrpcRendezvous -from google.cloud.exceptions import make_exception +from google.cloud import exceptions from google.cloud.datastore._generated import datastore_pb2 as _datastore_pb2 try: from grpc import StatusCode from google.cloud.datastore._generated import datastore_grpc_pb2 except ImportError: # pragma: NO COVER + _GRPC_ERROR_MAPPING = {} _HAVE_GRPC = False datastore_grpc_pb2 = None StatusCode = None else: + # NOTE: We don't include OK -> 200 or CANCELLED -> 499 + _GRPC_ERROR_MAPPING = { + StatusCode.UNKNOWN: exceptions.InternalServerError, + StatusCode.INVALID_ARGUMENT: exceptions.BadRequest, + StatusCode.DEADLINE_EXCEEDED: exceptions.GatewayTimeout, + StatusCode.NOT_FOUND: exceptions.NotFound, + StatusCode.ALREADY_EXISTS: exceptions.Conflict, + StatusCode.PERMISSION_DENIED: exceptions.Forbidden, + StatusCode.UNAUTHENTICATED: exceptions.Unauthorized, + StatusCode.RESOURCE_EXHAUSTED: exceptions.TooManyRequests, + StatusCode.FAILED_PRECONDITION: exceptions.PreconditionFailed, + StatusCode.ABORTED: exceptions.Conflict, + StatusCode.OUT_OF_RANGE: exceptions.BadRequest, + StatusCode.UNIMPLEMENTED: exceptions.MethodNotImplemented, + StatusCode.INTERNAL: exceptions.InternalServerError, + StatusCode.UNAVAILABLE: exceptions.ServiceUnavailable, + StatusCode.DATA_LOSS: exceptions.InternalServerError, + } _HAVE_GRPC = True @@ -93,7 +110,8 @@ def _request(self, project, method, data): status = headers['status'] if status != '200': error_status = status_pb2.Status.FromString(content) - raise make_exception(headers, error_status.message, use_json=False) + raise exceptions.make_exception( + headers, error_status.message, use_json=False) return content @@ -220,6 +238,28 @@ def allocate_ids(self, project, request_pb): _datastore_pb2.AllocateIdsResponse) +@contextlib.contextmanager +def _grpc_catch_rendezvous(): + """Re-map gRPC exceptions that happen in context. + + .. _code.proto: https://github.com/googleapis/googleapis/blob/\ + master/google/rpc/code.proto + + Remaps gRPC exceptions to the classes defined in + :mod:`~google.cloud.exceptions` (according to the description + in `code.proto`_). + """ + try: + yield + except exceptions.GrpcRendezvous as exc: + error_code = exc.code() + error_class = _GRPC_ERROR_MAPPING.get(error_code) + if error_class is None: + raise + else: + raise error_class(exc.details()) + + class _DatastoreAPIOverGRPC(object): """Helper mapping datastore API methods. @@ -276,13 +316,8 @@ def run_query(self, project, request_pb): :returns: The returned protobuf response object. """ request_pb.project_id = project - try: + with _grpc_catch_rendezvous(): return self._stub.RunQuery(request_pb) - except GrpcRendezvous as exc: - error_code = exc.code() - if error_code == StatusCode.INVALID_ARGUMENT: - raise BadRequest(exc.details()) - raise def begin_transaction(self, project, request_pb): """Perform a ``beginTransaction`` request. @@ -299,7 +334,8 @@ def begin_transaction(self, project, request_pb): :returns: The returned protobuf response object. """ request_pb.project_id = project - return self._stub.BeginTransaction(request_pb) + with _grpc_catch_rendezvous(): + return self._stub.BeginTransaction(request_pb) def commit(self, project, request_pb): """Perform a ``commit`` request. @@ -315,15 +351,8 @@ def commit(self, project, request_pb): :returns: The returned protobuf response object. """ request_pb.project_id = project - try: + with _grpc_catch_rendezvous(): return self._stub.Commit(request_pb) - except GrpcRendezvous as exc: - error_code = exc.code() - if error_code == StatusCode.ABORTED: - raise Conflict(exc.details()) - if error_code == StatusCode.INVALID_ARGUMENT: - raise BadRequest(exc.details()) - raise def rollback(self, project, request_pb): """Perform a ``rollback`` request. @@ -339,7 +368,8 @@ def rollback(self, project, request_pb): :returns: The returned protobuf response object. """ request_pb.project_id = project - return self._stub.Rollback(request_pb) + with _grpc_catch_rendezvous(): + return self._stub.Rollback(request_pb) def allocate_ids(self, project, request_pb): """Perform an ``allocateIds`` request. @@ -355,7 +385,8 @@ def allocate_ids(self, project, request_pb): :returns: The returned protobuf response object. """ request_pb.project_id = project - return self._stub.AllocateIds(request_pb) + with _grpc_catch_rendezvous(): + return self._stub.AllocateIds(request_pb) class Connection(connection_module.Connection): diff --git a/datastore/unit_tests/test_connection.py b/datastore/unit_tests/test_connection.py index cbafc72ac3f0..973a3241506e 100644 --- a/datastore/unit_tests/test_connection.py +++ b/datastore/unit_tests/test_connection.py @@ -106,6 +106,72 @@ def test__request_not_200(self): [{'method': METHOD, 'project': PROJECT}]) +@unittest.skipUnless(_HAVE_GRPC, 'No gRPC') +class Test__grpc_catch_rendezvous(unittest.TestCase): + + def _callFUT(self): + from google.cloud.datastore.connection import _grpc_catch_rendezvous + return _grpc_catch_rendezvous() + + @staticmethod + def _fake_method(exc, result=None): + if exc is None: + return result + else: + raise exc + + def test_success(self): + expected = object() + with self._callFUT(): + result = self._fake_method(None, expected) + self.assertIs(result, expected) + + def test_failure_aborted(self): + from grpc import StatusCode + from grpc._channel import _RPCState + from google.cloud.exceptions import Conflict + from google.cloud.exceptions import GrpcRendezvous + + details = 'Bad things.' + exc_state = _RPCState((), None, None, StatusCode.ABORTED, details) + exc = GrpcRendezvous(exc_state, None, None, None) + with self.assertRaises(Conflict): + with self._callFUT(): + self._fake_method(exc) + + def test_failure_invalid_argument(self): + from grpc import StatusCode + from grpc._channel import _RPCState + from google.cloud.exceptions import BadRequest + from google.cloud.exceptions import GrpcRendezvous + + details = ('Cannot have inequality filters on multiple ' + 'properties: [created, priority]') + exc_state = _RPCState((), None, None, + StatusCode.INVALID_ARGUMENT, details) + exc = GrpcRendezvous(exc_state, None, None, None) + with self.assertRaises(BadRequest): + with self._callFUT(): + self._fake_method(exc) + + def test_failure_cancelled(self): + from grpc import StatusCode + from grpc._channel import _RPCState + from google.cloud.exceptions import GrpcRendezvous + + exc_state = _RPCState((), None, None, StatusCode.CANCELLED, None) + exc = GrpcRendezvous(exc_state, None, None, None) + with self.assertRaises(GrpcRendezvous): + with self._callFUT(): + self._fake_method(exc) + + def test_commit_failure_non_grpc_err(self): + exc = RuntimeError('Not a gRPC error') + with self.assertRaises(RuntimeError): + with self._callFUT(): + self._fake_method(exc) + + class Test_DatastoreAPIOverGRPC(unittest.TestCase): def _getTargetClass(self): @@ -227,16 +293,6 @@ def test_run_query_invalid_argument(self): exc = GrpcRendezvous(exc_state, None, None, None) self._run_query_failure_helper(exc, BadRequest) - @unittest.skipUnless(_HAVE_GRPC, 'No gRPC') - def test_run_query_cancelled(self): - from grpc import StatusCode - from grpc._channel import _RPCState - from google.cloud.exceptions import GrpcRendezvous - - exc_state = _RPCState((), None, None, StatusCode.CANCELLED, None) - exc = GrpcRendezvous(exc_state, None, None, None) - self._run_query_failure_helper(exc, GrpcRendezvous) - def test_begin_transaction(self): return_val = object() stub = _GRPCStub(return_val) @@ -264,59 +320,6 @@ def test_commit_success(self): self.assertEqual(stub.method_calls, [(request_pb, 'Commit')]) - def _commit_failure_helper(self, exc, err_class): - stub = _GRPCStub(side_effect=exc) - datastore_api = self._makeOne(stub=stub) - - request_pb = _RequestPB() - project = 'PROJECT' - with self.assertRaises(err_class): - datastore_api.commit(project, request_pb) - - self.assertEqual(request_pb.project_id, project) - self.assertEqual(stub.method_calls, - [(request_pb, 'Commit')]) - - @unittest.skipUnless(_HAVE_GRPC, 'No gRPC') - def test_commit_failure_aborted(self): - from grpc import StatusCode - from grpc._channel import _RPCState - from google.cloud.exceptions import Conflict - from google.cloud.exceptions import GrpcRendezvous - - details = 'Bad things.' - exc_state = _RPCState((), None, None, StatusCode.ABORTED, details) - exc = GrpcRendezvous(exc_state, None, None, None) - self._commit_failure_helper(exc, Conflict) - - @unittest.skipUnless(_HAVE_GRPC, 'No gRPC') - def test_commit_failure_invalid_argument(self): - from grpc import StatusCode - from grpc._channel import _RPCState - from google.cloud.exceptions import BadRequest - from google.cloud.exceptions import GrpcRendezvous - - details = 'Too long content.' - exc_state = _RPCState((), None, None, - StatusCode.INVALID_ARGUMENT, details) - exc = GrpcRendezvous(exc_state, None, None, None) - self._commit_failure_helper(exc, BadRequest) - - @unittest.skipUnless(_HAVE_GRPC, 'No gRPC') - def test_commit_failure_cancelled(self): - from grpc import StatusCode - from grpc._channel import _RPCState - from google.cloud.exceptions import GrpcRendezvous - - exc_state = _RPCState((), None, None, StatusCode.CANCELLED, None) - exc = GrpcRendezvous(exc_state, None, None, None) - self._commit_failure_helper(exc, GrpcRendezvous) - - @unittest.skipUnless(_HAVE_GRPC, 'No gRPC') - def test_commit_failure_non_grpc_err(self): - exc = RuntimeError('Not a gRPC error') - self._commit_failure_helper(exc, RuntimeError) - def test_rollback(self): return_val = object() stub = _GRPCStub(return_val) @@ -1161,27 +1164,22 @@ def __init__(self, return_val=None, side_effect=Exception): def _method(self, request_pb, name): self.method_calls.append((request_pb, name)) - return self.return_val + if self.side_effect is Exception: + return self.return_val + else: + raise self.side_effect def Lookup(self, request_pb): return self._method(request_pb, 'Lookup') def RunQuery(self, request_pb): - result = self._method(request_pb, 'RunQuery') - if self.side_effect is Exception: - return result - else: - raise self.side_effect + return self._method(request_pb, 'RunQuery') def BeginTransaction(self, request_pb): return self._method(request_pb, 'BeginTransaction') def Commit(self, request_pb): - result = self._method(request_pb, 'Commit') - if self.side_effect is Exception: - return result - else: - raise self.side_effect + return self._method(request_pb, 'Commit') def Rollback(self, request_pb): return self._method(request_pb, 'Rollback')