Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: stop / start stream after filter mismatch #502

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions google/cloud/firestore_v1/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,5 @@ def on_snapshot(collection_snapshot, changes, read_time):
# Terminate this watch
collection_watch.unsubscribe()
"""
return Watch.for_query(
self._query(),
callback,
document.DocumentSnapshot,
document.DocumentReference,
)
query = self._query()
return Watch.for_query(query, callback, document.DocumentSnapshot)
2 changes: 1 addition & 1 deletion google/cloud/firestore_v1/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,4 +489,4 @@ def on_snapshot(document_snapshot, changes, read_time):
# Terminate this watch
doc_watch.unsubscribe()
"""
return Watch.for_document(self, callback, DocumentSnapshot, DocumentReference)
return Watch.for_document(self, callback, DocumentSnapshot)
4 changes: 1 addition & 3 deletions google/cloud/firestore_v1/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,9 +329,7 @@ def on_snapshot(docs, changes, read_time):
# Terminate this watch
query_watch.unsubscribe()
"""
return Watch.for_query(
self, callback, document.DocumentSnapshot, document.DocumentReference
)
return Watch.for_query(self, callback, document.DocumentSnapshot)

@staticmethod
def _get_collection_reference_class() -> Type[
Expand Down
67 changes: 34 additions & 33 deletions google/cloud/firestore_v1/watch.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,6 @@ def __init__(
comparator,
snapshot_callback,
document_snapshot_cls,
document_reference_cls,
):
"""
Args:
Expand All @@ -192,35 +191,21 @@ def __init__(
read_time (string): The ISO 8601 time at which this
snapshot was obtained.

document_snapshot_cls: instance of DocumentSnapshot
document_reference_cls: instance of DocumentReference
document_snapshot_cls: factory for instances of DocumentSnapshot
"""
self._document_reference = document_reference
self._firestore = firestore
self._api = firestore._firestore_api
self._targets = target
self._comparator = comparator
self.DocumentSnapshot = document_snapshot_cls
self.DocumentReference = document_reference_cls
self._document_snapshot_cls = document_snapshot_cls
self._snapshot_callback = snapshot_callback
self._api = firestore._firestore_api
self._closing = threading.Lock()
self._closed = False
self._set_documents_pfx(firestore._database_string)

self.resume_token = None

rpc_request = self._get_rpc_request

self._rpc = ResumableBidiRpc(
start_rpc=self._api._transport.listen,
should_recover=_should_recover,
should_terminate=_should_terminate,
initial_request=rpc_request,
metadata=self._firestore._rpc_metadata,
)

self._rpc.add_done_callback(self._on_rpc_done)

# Initialize state for on_snapshot
# The sorted tree of QueryDocumentSnapshots as sent in the last
# snapshot. We only look at the keys.
Expand All @@ -242,17 +227,29 @@ def __init__(
# aren't docs.
self.has_pushed = False

self._init_stream()

def _init_stream(self):

rpc_request = self._get_rpc_request

self._rpc = ResumableBidiRpc(
start_rpc=self._api._transport.listen,
should_recover=_should_recover,
should_terminate=_should_terminate,
initial_request=rpc_request,
metadata=self._firestore._rpc_metadata,
)

self._rpc.add_done_callback(self._on_rpc_done)

# The server assigns and updates the resume token.
self._consumer = BackgroundConsumer(self._rpc, self.on_snapshot)
self._consumer.start()

@classmethod
def for_document(
cls,
document_ref,
snapshot_callback,
document_snapshot_cls,
document_reference_cls,
cls, document_ref, snapshot_callback, document_snapshot_cls,
):
"""
Creates a watch snapshot listener for a document. snapshot_callback
Expand All @@ -276,13 +273,10 @@ def for_document(
document_watch_comparator,
snapshot_callback,
document_snapshot_cls,
document_reference_cls,
)

@classmethod
def for_query(
cls, query, snapshot_callback, document_snapshot_cls, document_reference_cls,
):
def for_query(cls, query, snapshot_callback, document_snapshot_cls):
parent_path, _ = query._parent._parent_info()
query_target = Target.QueryTarget(
parent=parent_path, structured_query=query._to_protobuf()
Expand All @@ -295,12 +289,13 @@ def for_query(
query._comparator,
snapshot_callback,
document_snapshot_cls,
document_reference_cls,
)

def _get_rpc_request(self):
if self.resume_token is not None:
self._targets["resume_token"] = self.resume_token
else:
self._targets.pop("resume_token", None)

return ListenRequest(
database=self._firestore._database_string, add_target=self._targets
Expand Down Expand Up @@ -490,7 +485,7 @@ def on_snapshot(self, proto):
document_name = self._strip_document_pfx(document.name)
document_ref = self._firestore.document(document_name)

snapshot = self.DocumentSnapshot(
snapshot = self._document_snapshot_cls(
reference=document_ref,
data=data,
exists=True,
Expand Down Expand Up @@ -520,11 +515,17 @@ def on_snapshot(self, proto):
elif which == "filter":
_LOGGER.debug("on_snapshot: filter update")
if pb.filter.count != self._current_size():
# We need to remove all the current results.
# First, shut down current stream
_LOGGER.info("Filter mismatch -- restarting stream.")
thread = threading.Thread(
name=_RPC_ERROR_THREAD_NAME, target=self.close,
)
thread.start()
thread.join() # wait for shutdown to complete
# Then, remove all the current results.
self._reset_docs()
# The filter didn't match, so re-issue the query.
# TODO: reset stream method?
# self._reset_stream();
# Finally, restart stream.
self._init_stream()

else:
_LOGGER.debug("UNKNOWN TYPE. UHOH")
Expand Down
11 changes: 4 additions & 7 deletions tests/unit/v1/test_cross_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,6 @@ def test_listen_testprotos(test_proto): # pragma: NO COVER
# 'docs' (list of 'google.firestore_v1.Document'),
# 'changes' (list lof local 'DocChange', and 'read_time' timestamp.
from google.cloud.firestore_v1 import Client
from google.cloud.firestore_v1 import DocumentReference
from google.cloud.firestore_v1 import DocumentSnapshot
from google.cloud.firestore_v1 import Watch
import google.auth.credentials
Expand All @@ -226,6 +225,9 @@ def test_listen_testprotos(test_proto): # pragma: NO COVER

credentials = mock.Mock(spec=google.auth.credentials.Credentials)
client = Client(project="project", credentials=credentials)
# conformance data has db string as this
db_str = "projects/projectID/databases/(default)"
client._database_string_internal = db_str
with mock.patch("google.cloud.firestore_v1.watch.ResumableBidiRpc"):
with mock.patch("google.cloud.firestore_v1.watch.BackgroundConsumer"):
# conformance data sets WATCH_TARGET_ID to 1
Expand All @@ -237,12 +239,7 @@ def callback(keys, applied_changes, read_time):

collection = DummyCollection(client=client)
query = DummyQuery(parent=collection)
watch = Watch.for_query(
query, callback, DocumentSnapshot, DocumentReference
)
# conformance data has db string as this
db_str = "projects/projectID/databases/(default)"
watch._firestore._database_string_internal = db_str
watch = Watch.for_query(query, callback, DocumentSnapshot)

wrapped_responses = [
firestore.ListenResponse.wrap(proto) for proto in testcase.responses
Expand Down
18 changes: 3 additions & 15 deletions tests/unit/v1/test_watch.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,6 @@ def snapshot_callback(*args):
comparator=comparator,
snapshot_callback=snapshot_callback,
document_snapshot_cls=DummyDocumentSnapshot,
document_reference_cls=DummyDocumentReference,
)


Expand Down Expand Up @@ -224,16 +223,11 @@ def snapshot_callback(*args): # pragma: NO COVER
snapshots.append(args)

docref = DummyDocumentReference()
snapshot_class_instance = DummyDocumentSnapshot
document_reference_class_instance = DummyDocumentReference

with mock.patch("google.cloud.firestore_v1.watch.ResumableBidiRpc"):
with mock.patch("google.cloud.firestore_v1.watch.BackgroundConsumer"):
inst = Watch.for_document(
docref,
snapshot_callback,
snapshot_class_instance,
document_reference_class_instance,
docref, snapshot_callback, document_snapshot_cls=DummyDocumentSnapshot,
)

inst._consumer.start.assert_called_once_with()
Expand All @@ -246,8 +240,6 @@ def test_watch_for_query(snapshots):
def snapshot_callback(*args): # pragma: NO COVER
snapshots.append(args)

snapshot_class_instance = DummyDocumentSnapshot
document_reference_class_instance = DummyDocumentReference
client = DummyFirestore()
parent = DummyCollection(client)
query = DummyQuery(parent=parent)
Expand All @@ -258,8 +250,7 @@ def snapshot_callback(*args): # pragma: NO COVER
inst = Watch.for_query(
query,
snapshot_callback,
snapshot_class_instance,
document_reference_class_instance,
document_snapshot_cls=DummyDocumentSnapshot,
)

inst._consumer.start.assert_called_once_with()
Expand All @@ -278,8 +269,6 @@ def test_watch_for_query_nested(snapshots):
def snapshot_callback(*args): # pragma: NO COVER
snapshots.append(args)

snapshot_class_instance = DummyDocumentSnapshot
document_reference_class_instance = DummyDocumentReference
client = DummyFirestore()
root = DummyCollection(client)
grandparent = DummyDocument("document", parent=root)
Expand All @@ -292,8 +281,7 @@ def snapshot_callback(*args): # pragma: NO COVER
inst = Watch.for_query(
query,
snapshot_callback,
snapshot_class_instance,
document_reference_class_instance,
document_snapshot_cls=DummyDocumentSnapshot,
)

inst._consumer.start.assert_called_once_with()
Expand Down