From a5da3fbacaaca20bcad6566130387c555200ba98 Mon Sep 17 00:00:00 2001 From: Tres Seaver Date: Tue, 7 Dec 2021 16:26:10 -0500 Subject: [PATCH] fix: stop / start stream after filter mismatch Closes #367. Supersedes #497. --- google/cloud/firestore_v1/collection.py | 8 +-- google/cloud/firestore_v1/document.py | 2 +- google/cloud/firestore_v1/query.py | 4 +- google/cloud/firestore_v1/watch.py | 67 +++++++++++++------------ tests/unit/v1/test_cross_language.py | 11 ++-- tests/unit/v1/test_watch.py | 18 ++----- 6 files changed, 45 insertions(+), 65 deletions(-) diff --git a/google/cloud/firestore_v1/collection.py b/google/cloud/firestore_v1/collection.py index 585f46f04f..3488275dd7 100644 --- a/google/cloud/firestore_v1/collection.py +++ b/google/cloud/firestore_v1/collection.py @@ -237,9 +237,5 @@ def on_snapshot(collection_snapshot, changes, read_time): # Terminate this watch collection_watch.unsubscribe() """ - return Watch.for_query( - self._query(), - callback, - document.DocumentSnapshot, - document.DocumentReference, - ) + query = self._query() + return Watch.for_query(query, callback, document.DocumentSnapshot) diff --git a/google/cloud/firestore_v1/document.py b/google/cloud/firestore_v1/document.py index 205fda44ca..acdab69e7a 100644 --- a/google/cloud/firestore_v1/document.py +++ b/google/cloud/firestore_v1/document.py @@ -489,4 +489,4 @@ def on_snapshot(document_snapshot, changes, read_time): # Terminate this watch doc_watch.unsubscribe() """ - return Watch.for_document(self, callback, DocumentSnapshot, DocumentReference) + return Watch.for_document(self, callback, DocumentSnapshot) diff --git a/google/cloud/firestore_v1/query.py b/google/cloud/firestore_v1/query.py index 59f85c69aa..25ac92cc2f 100644 --- a/google/cloud/firestore_v1/query.py +++ b/google/cloud/firestore_v1/query.py @@ -329,9 +329,7 @@ def on_snapshot(docs, changes, read_time): # Terminate this watch query_watch.unsubscribe() """ - return Watch.for_query( - self, callback, document.DocumentSnapshot, document.DocumentReference - ) + return Watch.for_query(self, callback, document.DocumentSnapshot) @staticmethod def _get_collection_reference_class() -> Type[ diff --git a/google/cloud/firestore_v1/watch.py b/google/cloud/firestore_v1/watch.py index 6efb10ecf1..ba45832e84 100644 --- a/google/cloud/firestore_v1/watch.py +++ b/google/cloud/firestore_v1/watch.py @@ -175,7 +175,6 @@ def __init__( comparator, snapshot_callback, document_snapshot_cls, - document_reference_cls, ): """ Args: @@ -192,35 +191,21 @@ def __init__( read_time (string): The ISO 8601 time at which this snapshot was obtained. - document_snapshot_cls: instance of DocumentSnapshot - document_reference_cls: instance of DocumentReference + document_snapshot_cls: factory for instances of DocumentSnapshot """ self._document_reference = document_reference self._firestore = firestore - self._api = firestore._firestore_api self._targets = target self._comparator = comparator - self.DocumentSnapshot = document_snapshot_cls - self.DocumentReference = document_reference_cls + self._document_snapshot_cls = document_snapshot_cls self._snapshot_callback = snapshot_callback + self._api = firestore._firestore_api self._closing = threading.Lock() self._closed = False self._set_documents_pfx(firestore._database_string) self.resume_token = None - rpc_request = self._get_rpc_request - - self._rpc = ResumableBidiRpc( - start_rpc=self._api._transport.listen, - should_recover=_should_recover, - should_terminate=_should_terminate, - initial_request=rpc_request, - metadata=self._firestore._rpc_metadata, - ) - - self._rpc.add_done_callback(self._on_rpc_done) - # Initialize state for on_snapshot # The sorted tree of QueryDocumentSnapshots as sent in the last # snapshot. We only look at the keys. @@ -242,17 +227,29 @@ def __init__( # aren't docs. self.has_pushed = False + self._init_stream() + + def _init_stream(self): + + rpc_request = self._get_rpc_request + + self._rpc = ResumableBidiRpc( + start_rpc=self._api._transport.listen, + should_recover=_should_recover, + should_terminate=_should_terminate, + initial_request=rpc_request, + metadata=self._firestore._rpc_metadata, + ) + + self._rpc.add_done_callback(self._on_rpc_done) + # The server assigns and updates the resume token. self._consumer = BackgroundConsumer(self._rpc, self.on_snapshot) self._consumer.start() @classmethod def for_document( - cls, - document_ref, - snapshot_callback, - document_snapshot_cls, - document_reference_cls, + cls, document_ref, snapshot_callback, document_snapshot_cls, ): """ Creates a watch snapshot listener for a document. snapshot_callback @@ -276,13 +273,10 @@ def for_document( document_watch_comparator, snapshot_callback, document_snapshot_cls, - document_reference_cls, ) @classmethod - def for_query( - cls, query, snapshot_callback, document_snapshot_cls, document_reference_cls, - ): + def for_query(cls, query, snapshot_callback, document_snapshot_cls): parent_path, _ = query._parent._parent_info() query_target = Target.QueryTarget( parent=parent_path, structured_query=query._to_protobuf() @@ -295,12 +289,13 @@ def for_query( query._comparator, snapshot_callback, document_snapshot_cls, - document_reference_cls, ) def _get_rpc_request(self): if self.resume_token is not None: self._targets["resume_token"] = self.resume_token + else: + self._targets.pop("resume_token", None) return ListenRequest( database=self._firestore._database_string, add_target=self._targets @@ -490,7 +485,7 @@ def on_snapshot(self, proto): document_name = self._strip_document_pfx(document.name) document_ref = self._firestore.document(document_name) - snapshot = self.DocumentSnapshot( + snapshot = self._document_snapshot_cls( reference=document_ref, data=data, exists=True, @@ -520,11 +515,17 @@ def on_snapshot(self, proto): elif which == "filter": _LOGGER.debug("on_snapshot: filter update") if pb.filter.count != self._current_size(): - # We need to remove all the current results. + # First, shut down current stream + _LOGGER.info("Filter mismatch -- restarting stream.") + thread = threading.Thread( + name=_RPC_ERROR_THREAD_NAME, target=self.close, + ) + thread.start() + thread.join() # wait for shutdown to complete + # Then, remove all the current results. self._reset_docs() - # The filter didn't match, so re-issue the query. - # TODO: reset stream method? - # self._reset_stream(); + # Finally, restart stream. + self._init_stream() else: _LOGGER.debug("UNKNOWN TYPE. UHOH") diff --git a/tests/unit/v1/test_cross_language.py b/tests/unit/v1/test_cross_language.py index 85495ceb0a..64cfacfb58 100644 --- a/tests/unit/v1/test_cross_language.py +++ b/tests/unit/v1/test_cross_language.py @@ -216,7 +216,6 @@ def test_listen_testprotos(test_proto): # pragma: NO COVER # 'docs' (list of 'google.firestore_v1.Document'), # 'changes' (list lof local 'DocChange', and 'read_time' timestamp. from google.cloud.firestore_v1 import Client - from google.cloud.firestore_v1 import DocumentReference from google.cloud.firestore_v1 import DocumentSnapshot from google.cloud.firestore_v1 import Watch import google.auth.credentials @@ -226,6 +225,9 @@ def test_listen_testprotos(test_proto): # pragma: NO COVER credentials = mock.Mock(spec=google.auth.credentials.Credentials) client = Client(project="project", credentials=credentials) + # conformance data has db string as this + db_str = "projects/projectID/databases/(default)" + client._database_string_internal = db_str with mock.patch("google.cloud.firestore_v1.watch.ResumableBidiRpc"): with mock.patch("google.cloud.firestore_v1.watch.BackgroundConsumer"): # conformance data sets WATCH_TARGET_ID to 1 @@ -237,12 +239,7 @@ def callback(keys, applied_changes, read_time): collection = DummyCollection(client=client) query = DummyQuery(parent=collection) - watch = Watch.for_query( - query, callback, DocumentSnapshot, DocumentReference - ) - # conformance data has db string as this - db_str = "projects/projectID/databases/(default)" - watch._firestore._database_string_internal = db_str + watch = Watch.for_query(query, callback, DocumentSnapshot) wrapped_responses = [ firestore.ListenResponse.wrap(proto) for proto in testcase.responses diff --git a/tests/unit/v1/test_watch.py b/tests/unit/v1/test_watch.py index 70a56409e7..e3e0adfce0 100644 --- a/tests/unit/v1/test_watch.py +++ b/tests/unit/v1/test_watch.py @@ -183,7 +183,6 @@ def snapshot_callback(*args): comparator=comparator, snapshot_callback=snapshot_callback, document_snapshot_cls=DummyDocumentSnapshot, - document_reference_cls=DummyDocumentReference, ) @@ -224,16 +223,11 @@ def snapshot_callback(*args): # pragma: NO COVER snapshots.append(args) docref = DummyDocumentReference() - snapshot_class_instance = DummyDocumentSnapshot - document_reference_class_instance = DummyDocumentReference with mock.patch("google.cloud.firestore_v1.watch.ResumableBidiRpc"): with mock.patch("google.cloud.firestore_v1.watch.BackgroundConsumer"): inst = Watch.for_document( - docref, - snapshot_callback, - snapshot_class_instance, - document_reference_class_instance, + docref, snapshot_callback, document_snapshot_cls=DummyDocumentSnapshot, ) inst._consumer.start.assert_called_once_with() @@ -246,8 +240,6 @@ def test_watch_for_query(snapshots): def snapshot_callback(*args): # pragma: NO COVER snapshots.append(args) - snapshot_class_instance = DummyDocumentSnapshot - document_reference_class_instance = DummyDocumentReference client = DummyFirestore() parent = DummyCollection(client) query = DummyQuery(parent=parent) @@ -258,8 +250,7 @@ def snapshot_callback(*args): # pragma: NO COVER inst = Watch.for_query( query, snapshot_callback, - snapshot_class_instance, - document_reference_class_instance, + document_snapshot_cls=DummyDocumentSnapshot, ) inst._consumer.start.assert_called_once_with() @@ -278,8 +269,6 @@ def test_watch_for_query_nested(snapshots): def snapshot_callback(*args): # pragma: NO COVER snapshots.append(args) - snapshot_class_instance = DummyDocumentSnapshot - document_reference_class_instance = DummyDocumentReference client = DummyFirestore() root = DummyCollection(client) grandparent = DummyDocument("document", parent=root) @@ -292,8 +281,7 @@ def snapshot_callback(*args): # pragma: NO COVER inst = Watch.for_query( query, snapshot_callback, - snapshot_class_instance, - document_reference_class_instance, + document_snapshot_cls=DummyDocumentSnapshot, ) inst._consumer.start.assert_called_once_with()