diff --git a/datastore/google/cloud/datastore/_gax.py b/datastore/google/cloud/datastore/_gax.py index 7475340dcbe5..8037e7ff78dd 100644 --- a/datastore/google/cloud/datastore/_gax.py +++ b/datastore/google/cloud/datastore/_gax.py @@ -152,24 +152,6 @@ def run_query(self, project, request_pb): with _grpc_catch_rendezvous(): return self._stub.RunQuery(request_pb) - def begin_transaction(self, project, request_pb): - """Perform a ``beginTransaction`` request. - - :type project: str - :param project: The project to connect to. This is - usually your project name in the cloud console. - - :type request_pb: - :class:`.datastore_pb2.BeginTransactionRequest` - :param request_pb: The request protobuf object. - - :rtype: :class:`.datastore_pb2.BeginTransactionResponse` - :returns: The returned protobuf response object. - """ - request_pb.project_id = project - with _grpc_catch_rendezvous(): - return self._stub.BeginTransaction(request_pb) - class GAPICDatastoreAPI(datastore_client.DatastoreClient): """An API object that sends proto-over-gRPC requests. diff --git a/datastore/google/cloud/datastore/_http.py b/datastore/google/cloud/datastore/_http.py index c3c1f0bcda36..62499c8fcffc 100644 --- a/datastore/google/cloud/datastore/_http.py +++ b/datastore/google/cloud/datastore/_http.py @@ -198,24 +198,6 @@ def run_query(self, project, request_pb): self.connection.api_base_url, request_pb, _datastore_pb2.RunQueryResponse) - def begin_transaction(self, project, request_pb): - """Perform a ``beginTransaction`` request. - - :type project: str - :param project: The project to connect to. This is - usually your project name in the cloud console. - - :type request_pb: - :class:`.datastore_pb2.BeginTransactionRequest` - :param request_pb: The request protobuf object. - - :rtype: :class:`.datastore_pb2.BeginTransactionResponse` - :returns: The returned protobuf response object. - """ - return _rpc(self.connection.http, project, 'beginTransaction', - self.connection.api_base_url, - request_pb, _datastore_pb2.BeginTransactionResponse) - class Connection(connection_module.Connection): """A connection to the Google Cloud Datastore via the Protobuf API. @@ -335,20 +317,6 @@ def run_query(self, project, query_pb, namespace=None, request.query.CopyFrom(query_pb) return self._datastore_api.run_query(project, request) - def begin_transaction(self, project): - """Begin a transaction. - - Maps the ``DatastoreService.BeginTransaction`` protobuf RPC. - - :type project: str - :param project: The project to which the transaction applies. - - :rtype: :class:`.datastore_pb2.BeginTransactionResponse` - :returns: The serialized transaction that was begun. - """ - request = _datastore_pb2.BeginTransactionRequest() - return self._datastore_api.begin_transaction(project, request) - class HTTPDatastoreAPI(object): """An API object that sends proto-over-HTTP requests. @@ -362,6 +330,21 @@ class HTTPDatastoreAPI(object): def __init__(self, client): self.client = client + def begin_transaction(self, project): + """Perform a ``beginTransaction`` request. + + :type project: str + :param project: The project to connect to. This is + usually your project name in the cloud console. + + :rtype: :class:`.datastore_pb2.BeginTransactionResponse` + :returns: The returned protobuf response object. + """ + request_pb = _datastore_pb2.BeginTransactionRequest() + return _rpc(self.client._http, project, 'beginTransaction', + self.client._base_url, + request_pb, _datastore_pb2.BeginTransactionResponse) + def commit(self, project, mode, mutations, transaction=None): """Perform a ``commit`` request. diff --git a/datastore/google/cloud/datastore/transaction.py b/datastore/google/cloud/datastore/transaction.py index 651bbfcd0ba3..c1cd6a01321a 100644 --- a/datastore/google/cloud/datastore/transaction.py +++ b/datastore/google/cloud/datastore/transaction.py @@ -184,7 +184,7 @@ def begin(self): """ super(Transaction, self).begin() try: - response_pb = self._client._connection.begin_transaction( + response_pb = self._client._datastore_api.begin_transaction( self.project) self._id = response_pb.transaction except: # noqa: E722 do not use bare except, specify exception instead diff --git a/datastore/unit_tests/test__gax.py b/datastore/unit_tests/test__gax.py index ffc1bad756fd..0061ea106df9 100644 --- a/datastore/unit_tests/test__gax.py +++ b/datastore/unit_tests/test__gax.py @@ -240,20 +240,6 @@ def test_run_query_invalid_argument(self): exc = GrpcRendezvous(exc_state, None, None, None) self._run_query_failure_helper(exc, BadRequest) - def test_begin_transaction(self): - return_val = object() - stub = _GRPCStub(return_val) - datastore_api, _ = self._make_one(stub=stub) - - request_pb = mock.Mock(project_id=None, spec=['project_id']) - project = 'PROJECT' - result = datastore_api.begin_transaction(project, request_pb) - self.assertIs(result, return_val) - self.assertEqual(request_pb.project_id, project) - self.assertEqual( - stub.method_calls, - [(request_pb, 'BeginTransaction')]) - @unittest.skipUnless(_HAVE_GRPC, 'No gRPC') class TestGAPICDatastoreAPI(unittest.TestCase): @@ -338,6 +324,3 @@ def Lookup(self, request_pb): def RunQuery(self, request_pb): return self._method(request_pb, 'RunQuery') - - def BeginTransaction(self, request_pb): - return self._method(request_pb, 'BeginTransaction') diff --git a/datastore/unit_tests/test__http.py b/datastore/unit_tests/test__http.py index 482120c49969..67262584f67d 100644 --- a/datastore/unit_tests/test__http.py +++ b/datastore/unit_tests/test__http.py @@ -634,6 +634,23 @@ def test_run_query_w_namespace_nonempty_result(self): self.assertEqual(request.partition_id.namespace_id, namespace) self.assertEqual(request.query, q_pb) + +class TestHTTPDatastoreAPI(unittest.TestCase): + + @staticmethod + def _get_target_class(): + from google.cloud.datastore._http import HTTPDatastoreAPI + + return HTTPDatastoreAPI + + def _make_one(self, *args, **kwargs): + return self._get_target_class()(*args, **kwargs) + + def test_constructor(self): + client = object() + ds_api = self._make_one(client) + self.assertIs(ds_api.client, client) + def test_begin_transaction(self): from google.cloud.proto.datastore.v1 import datastore_pb2 @@ -648,13 +665,13 @@ def test_begin_transaction(self): _http=http, _base_url='test.invalid', spec=['_http', '_base_url']) # Make request. - conn = self._make_one(client) - response = conn.begin_transaction(project) + ds_api = self._make_one(client) + response = ds_api.begin_transaction(project) # Check the result and verify the callers. self.assertEqual(response, rsp_pb) uri = _build_expected_url( - conn.api_base_url, project, 'beginTransaction') + client._base_url, project, 'beginTransaction') cw = http._called_with _verify_protobuf_call(self, cw, uri) request = datastore_pb2.BeginTransactionRequest() @@ -662,23 +679,6 @@ def test_begin_transaction(self): # The RPC-over-HTTP request does not set the project in the request. self.assertEqual(request.project_id, u'') - -class TestHTTPDatastoreAPI(unittest.TestCase): - - @staticmethod - def _get_target_class(): - from google.cloud.datastore._http import HTTPDatastoreAPI - - return HTTPDatastoreAPI - - def _make_one(self, *args, **kwargs): - return self._get_target_class()(*args, **kwargs) - - def test_constructor(self): - client = object() - ds_api = self._make_one(client) - self.assertIs(ds_api.client, client) - def test_commit_wo_transaction(self): from google.cloud.proto.datastore.v1 import datastore_pb2 from google.cloud.datastore.helpers import _new_value_pb diff --git a/datastore/unit_tests/test_transaction.py b/datastore/unit_tests/test_transaction.py index 210f9d71cdd7..a9a4194c7dca 100644 --- a/datastore/unit_tests/test_transaction.py +++ b/datastore/unit_tests/test_transaction.py @@ -30,8 +30,7 @@ def _make_one(self, client, **kw): def test_ctor_defaults(self): project = 'PROJECT' - connection = _Connection() - client = _Client(project, connection) + client = _Client(project) xact = self._make_one(client) self.assertEqual(xact.project, project) self.assertIs(xact._client, client) @@ -45,8 +44,8 @@ def test_current(self): project = 'PROJECT' id_ = 678 - connection = _Connection(id_) - client = _Client(project, connection) + ds_api = _make_datastore_api(xact_id=id_) + client = _Client(project, datastore_api=ds_api) xact1 = self._make_one(client) xact2 = self._make_one(client) self.assertIsNone(xact1.current()) @@ -68,30 +67,35 @@ def test_current(self): self.assertIsNone(xact1.current()) self.assertIsNone(xact2.current()) - client._datastore_api.rollback.assert_not_called() - commit_method = client._datastore_api.commit + ds_api.rollback.assert_not_called() + commit_method = ds_api.commit self.assertEqual(commit_method.call_count, 2) mode = datastore_pb2.CommitRequest.TRANSACTIONAL commit_method.assert_called_with(project, mode, [], transaction=id_) + begin_txn = ds_api.begin_transaction + self.assertEqual(begin_txn.call_count, 2) + begin_txn.assert_called_with(project) + def test_begin(self): project = 'PROJECT' - connection = _Connection(234) - client = _Client(project, connection) + id_ = 889 + ds_api = _make_datastore_api(xact_id=id_) + client = _Client(project, datastore_api=ds_api) xact = self._make_one(client) xact.begin() - self.assertEqual(xact.id, 234) - self.assertEqual(connection._begun, project) + self.assertEqual(xact.id, id_) + ds_api.begin_transaction.assert_called_once_with(project) def test_begin_tombstoned(self): project = 'PROJECT' - id_ = 234 - connection = _Connection(id_) - client = _Client(project, connection) + id_ = 1094 + ds_api = _make_datastore_api(xact_id=id_) + client = _Client(project, datastore_api=ds_api) xact = self._make_one(client) xact.begin() self.assertEqual(xact.id, id_) - self.assertEqual(connection._begun, project) + ds_api.begin_transaction.assert_called_once_with(project) xact.rollback() client._datastore_api.rollback.assert_called_once_with(project, id_) @@ -101,36 +105,37 @@ def test_begin_tombstoned(self): def test_begin_w_begin_transaction_failure(self): project = 'PROJECT' - connection = _Connection(234) - client = _Client(project, connection) + id_ = 712 + ds_api = _make_datastore_api(xact_id=id_) + ds_api.begin_transaction = mock.Mock(side_effect=RuntimeError, spec=[]) + client = _Client(project, datastore_api=ds_api) xact = self._make_one(client) - connection._side_effect = RuntimeError with self.assertRaises(RuntimeError): xact.begin() self.assertIsNone(xact.id) - self.assertEqual(connection._begun, project) + ds_api.begin_transaction.assert_called_once_with(project) def test_rollback(self): project = 'PROJECT' - id_ = 234 - connection = _Connection(id_) - client = _Client(project, connection) + id_ = 239 + ds_api = _make_datastore_api(xact_id=id_) + client = _Client(project, datastore_api=ds_api) xact = self._make_one(client) xact.begin() xact.rollback() client._datastore_api.rollback.assert_called_once_with(project, id_) self.assertIsNone(xact.id) + ds_api.begin_transaction.assert_called_once_with(project) def test_commit_no_partial_keys(self): from google.cloud.proto.datastore.v1 import datastore_pb2 project = 'PROJECT' - id_ = 234 - connection = _Connection(id_) - - client = _Client(project, connection) + id_ = 1002930 + ds_api = _make_datastore_api(xact_id=id_) + client = _Client(project, datastore_api=ds_api) xact = self._make_one(client) xact.begin() xact.commit() @@ -139,6 +144,7 @@ def test_commit_no_partial_keys(self): client._datastore_api.commit.assert_called_once_with( project, mode, [], transaction=id_) self.assertIsNone(xact.id) + ds_api.begin_transaction.assert_called_once_with(project) def test_commit_w_partial_keys(self): from google.cloud.proto.datastore.v1 import datastore_pb2 @@ -147,10 +153,9 @@ def test_commit_w_partial_keys(self): kind = 'KIND' id1 = 123 key = _make_key(kind, id1, project) - ds_api = _make_datastore_api(key) id2 = 234 - connection = _Connection(id2) - client = _Client(project, connection, datastore_api=ds_api) + ds_api = _make_datastore_api(key, xact_id=id2) + client = _Client(project, datastore_api=ds_api) xact = self._make_one(client) xact.begin() entity = _Entity() @@ -162,23 +167,25 @@ def test_commit_w_partial_keys(self): project, mode, xact.mutations, transaction=id2) self.assertIsNone(xact.id) self.assertEqual(entity.key.path, [{'kind': kind, 'id': id1}]) + ds_api.begin_transaction.assert_called_once_with(project) def test_context_manager_no_raise(self): from google.cloud.proto.datastore.v1 import datastore_pb2 project = 'PROJECT' - id_ = 234 - connection = _Connection(id_) - client = _Client(project, connection) + id_ = 912830 + ds_api = _make_datastore_api(xact_id=id_) + client = _Client(project, datastore_api=ds_api) xact = self._make_one(client) with xact: self.assertEqual(xact.id, id_) - self.assertEqual(connection._begun, project) + ds_api.begin_transaction.assert_called_once_with(project) mode = datastore_pb2.CommitRequest.TRANSACTIONAL client._datastore_api.commit.assert_called_once_with( project, mode, [], transaction=id_) self.assertIsNone(xact.id) + self.assertEqual(ds_api.begin_transaction.call_count, 1) def test_context_manager_w_raise(self): @@ -186,15 +193,15 @@ class Foo(Exception): pass project = 'PROJECT' - id_ = 234 - connection = _Connection(id_) - client = _Client(project, connection) + id_ = 614416 + ds_api = _make_datastore_api(xact_id=id_) + client = _Client(project, datastore_api=ds_api) xact = self._make_one(client) xact._mutation = object() try: with xact: self.assertEqual(xact.id, id_) - self.assertEqual(connection._begun, project) + ds_api.begin_transaction.assert_called_once_with(project) raise Foo() except Foo: self.assertIsNone(xact.id) @@ -203,6 +210,7 @@ class Foo(Exception): client._datastore_api.commit.assert_not_called() self.assertIsNone(xact.id) + self.assertEqual(ds_api.begin_transaction.call_count, 1) def _make_key(kind, id_, project): @@ -216,22 +224,6 @@ def _make_key(kind, id_, project): return key -class _Connection(object): - _begun = None - _side_effect = None - - def __init__(self, xact_id=123): - self._xact_id = xact_id - - def begin_transaction(self, project): - self._begun = project - if self._side_effect is None: - return mock.Mock( - transaction=self._xact_id, spec=['transaction']) - else: - raise self._side_effect - - class _Entity(dict): def __init__(self): @@ -243,10 +235,8 @@ def __init__(self): class _Client(object): - def __init__(self, project, connection, - datastore_api=None, namespace=None): + def __init__(self, project, datastore_api=None, namespace=None): self.project = project - self._connection = connection if datastore_api is None: datastore_api = _make_datastore_api() self._datastore_api = datastore_api @@ -288,7 +278,15 @@ def _make_commit_response(*keys): return datastore_pb2.CommitResponse(mutation_results=mutation_results) -def _make_datastore_api(*keys): +def _make_datastore_api(*keys, **kwargs): commit_method = mock.Mock( return_value=_make_commit_response(*keys), spec=[]) - return mock.Mock(commit=commit_method, spec=['commit', 'rollback']) + + xact_id = kwargs.pop('xact_id', 123) + txn_pb = mock.Mock( + transaction=xact_id, spec=['transaction']) + begin_txn = mock.Mock(return_value=txn_pb, spec=[]) + + return mock.Mock( + commit=commit_method, begin_transaction=begin_txn, + spec=['begin_transaction', 'commit', 'rollback'])