From 8d03921e8eabb84040449efb803f09427cbd6758 Mon Sep 17 00:00:00 2001 From: Tres Seaver Date: Tue, 13 Oct 2020 15:32:39 -0400 Subject: [PATCH] feat: add retry/timeout to 'client.Client.get_all' Toward #221 --- google/cloud/firestore_v1/client.py | 22 +++++++++++++ tests/unit/v1/test_client.py | 49 +++++++++++++++++++++++++++++ 2 files changed, 71 insertions(+) diff --git a/google/cloud/firestore_v1/client.py b/google/cloud/firestore_v1/client.py index e6c9f45c97..1503bfc93f 100644 --- a/google/cloud/firestore_v1/client.py +++ b/google/cloud/firestore_v1/client.py @@ -24,6 +24,8 @@ :class:`~google.cloud.firestore_v1.document.DocumentReference` """ +from google.api_core import retry as retries # type: ignore + from google.cloud.firestore_v1.base_client import ( BaseClient, DEFAULT_DATABASE, @@ -202,11 +204,25 @@ def document(self, *document_path: Tuple[str]) -> DocumentReference: *self._document_path_helper(*document_path), client=self ) + @staticmethod + def _make_retry_timeout_kwargs(retry, timeout): + kwargs = {} + + if retry is not None: + kwargs["retry"] = retry + + if timeout is not None: + kwargs["timeout"] = timeout + + return kwargs + def get_all( self, references: list, field_paths: Iterable[str] = None, transaction: Transaction = None, + retry: retries.Retry = None, + timeout: float = None, ) -> Generator[Any, Any, None]: """Retrieve a batch of documents. @@ -237,6 +253,9 @@ def get_all( transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): An existing transaction that these ``references`` will be retrieved in. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. Yields: .DocumentSnapshot: The next document snapshot that fulfills the @@ -244,6 +263,8 @@ def get_all( """ document_paths, reference_map = _reference_info(references) mask = _get_doc_mask(field_paths) + kwargs = self._make_retry_timeout_kwargs(retry, timeout) + response_iterator = self._firestore_api.batch_get_documents( request={ "database": self._database_string, @@ -252,6 +273,7 @@ def get_all( "transaction": _helpers.get_transaction_id(transaction), }, metadata=self._rpc_metadata, + **kwargs, ) for get_doc_response in response_iterator: diff --git a/tests/unit/v1/test_client.py b/tests/unit/v1/test_client.py index b943fd1e14..ce0cd4dc25 100644 --- a/tests/unit/v1/test_client.py +++ b/tests/unit/v1/test_client.py @@ -303,6 +303,55 @@ def test_get_all(self): metadata=client._rpc_metadata, ) + def test_get_all_w_retry_timeout(self): + from google.api_core.retry import Retry + from google.cloud.firestore_v1.types import common + from google.cloud.firestore_v1.document import DocumentSnapshot + + data1 = {"a": u"cheese"} + data2 = {"b": True, "c": 18} + retry = Retry(predicate=object()) + timeout = 123.0 + info = self._info_for_get_all(data1, data2) + client, document1, document2, response1, response2 = info + + # Exercise the mocked ``batch_get_documents``. + field_paths = ["a", "b"] + snapshots = self._get_all_helper( + client, + [document1, document2], + [response1, response2], + field_paths=field_paths, + retry=retry, + timeout=timeout, + ) + self.assertEqual(len(snapshots), 2) + + snapshot1 = snapshots[0] + self.assertIsInstance(snapshot1, DocumentSnapshot) + self.assertIs(snapshot1._reference, document1) + self.assertEqual(snapshot1._data, data1) + + snapshot2 = snapshots[1] + self.assertIsInstance(snapshot2, DocumentSnapshot) + self.assertIs(snapshot2._reference, document2) + self.assertEqual(snapshot2._data, data2) + + # Verify the call to the mock. + doc_paths = [document1._document_path, document2._document_path] + mask = common.DocumentMask(field_paths=field_paths) + client._firestore_api.batch_get_documents.assert_called_once_with( + request={ + "database": client._database_string, + "documents": doc_paths, + "mask": mask, + "transaction": None, + }, + retry=retry, + timeout=timeout, + metadata=client._rpc_metadata, + ) + def test_get_all_with_transaction(self): from google.cloud.firestore_v1.document import DocumentSnapshot