From c830b62395d1b65bb816300573afd875f2bb995d Mon Sep 17 00:00:00 2001 From: Prad Nelluru Date: Tue, 4 Feb 2020 15:06:30 -0500 Subject: [PATCH] feat(pubsub): ordering keys --- .../cloud/pubsub_v1/publisher/_batch/base.py | 32 +- .../pubsub_v1/publisher/_batch/thread.py | 96 ++++-- google/cloud/pubsub_v1/publisher/client.py | 244 +++++++++++---- .../cloud/pubsub_v1/publisher/exceptions.py | 19 +- .../subscriber/_protocol/dispatcher.py | 2 + .../pubsub_v1/subscriber/_protocol/leaser.py | 36 ++- .../subscriber/_protocol/requests.py | 14 +- .../_protocol/streaming_pull_manager.py | 133 +++++---- google/cloud/pubsub_v1/subscriber/message.py | 27 +- google/cloud/pubsub_v1/types.py | 19 ++ .../pubsub_v1/publisher/batch/test_base.py | 2 +- .../pubsub_v1/publisher/batch/test_thread.py | 207 ++++++++----- .../publisher/test_publisher_client.py | 277 +++++++++++++++--- .../pubsub_v1/subscriber/test_dispatcher.py | 59 +++- .../unit/pubsub_v1/subscriber/test_leaser.py | 64 ++-- .../unit/pubsub_v1/subscriber/test_message.py | 22 +- .../subscriber/test_streaming_pull_manager.py | 132 ++++++++- 17 files changed, 1057 insertions(+), 328 deletions(-) diff --git a/google/cloud/pubsub_v1/publisher/_batch/base.py b/google/cloud/pubsub_v1/publisher/_batch/base.py index 75f430b09..53d3dee5b 100644 --- a/google/cloud/pubsub_v1/publisher/_batch/base.py +++ b/google/cloud/pubsub_v1/publisher/_batch/base.py @@ -15,6 +15,7 @@ from __future__ import absolute_import import abc +import enum import six @@ -134,6 +135,18 @@ def will_accept(self, message): # Okay, everything is good. return True + def cancel(self, cancellation_reason): + """Complete pending futures with an exception. + + This method must be called before publishing starts (ie: while the + batch is still accepting messages.) + + Args: + cancellation_reason (BatchCancellationReason): The reason why this + batch has been cancelled. + """ + raise NotImplementedError + @abc.abstractmethod def publish(self, message): """Publish a single message. @@ -154,16 +167,21 @@ def publish(self, message): raise NotImplementedError -class BatchStatus(object): - """An enum-like class representing valid statuses for a batch. - - It is acceptable for a class to use a status that is not on this - class; this represents the list of statuses where the existing - library hooks in functionality. - """ +class BatchStatus(str, enum.Enum): + """An enum-like class representing valid statuses for a batch.""" ACCEPTING_MESSAGES = "accepting messages" STARTING = "starting" IN_PROGRESS = "in progress" ERROR = "error" SUCCESS = "success" + + +class BatchCancellationReason(str, enum.Enum): + """An enum-like class representing reasons why a batch was cancelled.""" + + PRIOR_ORDERED_MESSAGE_FAILED = ( + "Batch cancelled because prior ordered message for the same key has " + "failed. This batch has been cancelled to avoid out-of-order publish." + ) + CLIENT_STOPPED = "Batch cancelled because the publisher client has been stopped." diff --git a/google/cloud/pubsub_v1/publisher/_batch/thread.py b/google/cloud/pubsub_v1/publisher/_batch/thread.py index 4101bc518..cdd913db4 100644 --- a/google/cloud/pubsub_v1/publisher/_batch/thread.py +++ b/google/cloud/pubsub_v1/publisher/_batch/thread.py @@ -62,15 +62,23 @@ class Batch(base.Batch): settings (~.pubsub_v1.types.BatchSettings): The settings for batch publishing. These should be considered immutable once the batch has been opened. - autocommit (bool): Whether to autocommit the batch when the time - has elapsed. Defaults to True unless ``settings.max_latency`` is - inf. + batch_done_callback (Callable[[bool], Any]): Callback called when the + response for a batch publish has been received. Called with one + boolean argument: successfully published or a permanent error + occurred. Temporary errors are not surfaced because they are retried + at a lower level. + commit_when_full (bool): Whether to commit the batch when the batch + is full. """ - def __init__(self, client, topic, settings, autocommit=True): + def __init__( + self, client, topic, settings, batch_done_callback=None, commit_when_full=True + ): self._client = client self._topic = topic self._settings = settings + self._batch_done_callback = batch_done_callback + self._commit_when_full = commit_when_full self._state_lock = threading.Lock() # These members are all communicated between threads; ensure that @@ -87,15 +95,6 @@ def __init__(self, client, topic, settings, autocommit=True): self._base_request_size = types.PublishRequest(topic=topic).ByteSize() self._size = self._base_request_size - # If max latency is specified, start a thread to monitor the batch and - # commit when the max latency is reached. - self._thread = None - if autocommit and self.settings.max_latency < float("inf"): - self._thread = threading.Thread( - name="Thread-MonitorBatchPublisher", target=self.monitor - ) - self._thread.start() - @staticmethod def make_lock(): """Return a threading lock. @@ -148,6 +147,27 @@ def status(self): """ return self._status + def cancel(self, cancellation_reason): + """Complete pending futures with an exception. + + This method must be called before publishing starts (ie: while the + batch is still accepting messages.) + + Args: + cancellation_reason (BatchCancellationReason): The reason why this + batch has been cancelled. + """ + + with self._state_lock: + assert ( + self._status == base.BatchStatus.ACCEPTING_MESSAGES + ), "Cancel should not be called after sending has started." + + exc = RuntimeError(cancellation_reason.value) + for future in self._futures: + future.set_exception(exc) + self._status = base.BatchStatus.ERROR + def commit(self): """Actually publish all of the messages on the active batch. @@ -162,6 +182,7 @@ def commit(self): If the current batch is **not** accepting messages, this method does nothing. """ + # Set the status to "starting" synchronously, to ensure that # this batch will necessarily not accept new messages. with self._state_lock: @@ -170,7 +191,11 @@ def commit(self): else: return - # Start a new thread to actually handle the commit. + self._start_commit_thread() + + def _start_commit_thread(self): + """Start a new thread to actually handle the commit.""" + commit_thread = threading.Thread( name="Thread-CommitBatchPublisher", target=self._commit ) @@ -195,7 +220,10 @@ def _commit(self): # If, in the intervening period between when this method was # called and now, the batch started to be committed, or # completed a commit, then no-op at this point. - _LOGGER.debug("Batch is already in progress, exiting commit") + _LOGGER.debug( + "Batch is already in progress or has been cancelled, " + "exiting commit" + ) return # Once in the IN_PROGRESS state, no other thread can publish additional @@ -215,16 +243,24 @@ def _commit(self): # Log how long the underlying request takes. start = time.time() + batch_transport_succeeded = True try: + # Performs retries for errors defined in retry_codes.publish in the + # publisher_client_config.py file. response = self._client.api.publish(self._topic, self._messages) except google.api_core.exceptions.GoogleAPIError as exc: - # We failed to publish, set the exception on all futures and - # exit. + # We failed to publish, even after retries, so set the exception on + # all futures and exit. self._status = base.BatchStatus.ERROR for future in self._futures: future.set_exception(exc) + batch_transport_succeeded = False + if self._batch_done_callback is not None: + # Failed to publish batch. + self._batch_done_callback(batch_transport_succeeded) + _LOGGER.exception("Failed to publish %s messages.", len(self._futures)) return @@ -250,26 +286,17 @@ def _commit(self): for future in self._futures: future.set_exception(exception) + # Unknown error -> batch failed to be correctly transported/ + batch_transport_succeeded = False + _LOGGER.error( "Only %s of %s messages were published.", len(response.message_ids), len(self._futures), ) - def monitor(self): - """Commit this batch after sufficient time has elapsed. - - This simply sleeps for ``self.settings.max_latency`` seconds, - and then calls commit unless the batch has already been committed. - """ - # NOTE: This blocks; it is up to the calling code to call it - # in a separate thread. - - # Sleep for however long we should be waiting. - time.sleep(self.settings.max_latency) - - _LOGGER.debug("Monitor is waking up") - return self._commit() + if self._batch_done_callback is not None: + self._batch_done_callback(batch_transport_succeeded) def publish(self, message): """Publish a single message. @@ -294,6 +321,7 @@ def publish(self, message): pubsub_v1.publisher.exceptions.MessageTooLargeError: If publishing the ``message`` would exceed the max size limit on the backend. """ + # Coerce the type, just in case. if not isinstance(message, types.PubsubMessage): message = types.PubsubMessage(**message) @@ -301,6 +329,10 @@ def publish(self, message): future = None with self._state_lock: + assert ( + self._status != base.BatchStatus.ERROR + ), "Publish after stop() or publish error." + if not self.will_accept(message): return future @@ -333,7 +365,7 @@ def publish(self, message): # Try to commit, but it must be **without** the lock held, since # ``commit()`` will try to obtain the lock. - if overflow: + if self._commit_when_full and overflow: self.commit() return future diff --git a/google/cloud/pubsub_v1/publisher/client.py b/google/cloud/pubsub_v1/publisher/client.py index 60a03bb65..ea43b09d5 100644 --- a/google/cloud/pubsub_v1/publisher/client.py +++ b/google/cloud/pubsub_v1/publisher/client.py @@ -15,8 +15,11 @@ from __future__ import absolute_import import copy +import logging import os import pkg_resources +import threading +import time import grpc import six @@ -29,10 +32,13 @@ from google.cloud.pubsub_v1.gapic import publisher_client from google.cloud.pubsub_v1.gapic.transports import publisher_grpc_transport from google.cloud.pubsub_v1.publisher._batch import thread - +from google.cloud.pubsub_v1.publisher._sequencer import ordered_sequencer +from google.cloud.pubsub_v1.publisher._sequencer import unordered_sequencer __version__ = pkg_resources.get_distribution("google-cloud-pubsub").version +_LOGGER = logging.getLogger(__name__) + _BLACKLISTED_METHODS = ( "publish", "from_service_account_file", @@ -40,6 +46,14 @@ ) +def _set_nested_value(container, value, keys): + current = container + for key in keys[:-1]: + current = current.setdefault(key, {}) + current[keys[-1]] = value + return container + + @_gapic.add_methods(publisher_client.PublisherClient, blacklist=_BLACKLISTED_METHODS) class Client(object): """A publisher client for Google Cloud Pub/Sub. @@ -49,6 +63,9 @@ class Client(object): get sensible defaults. Args: + publisher_options (~google.cloud.pubsub_v1.types.PublisherOptions): The + options for the publisher client. Note that enabling message ordering will + override the publish retry timeout to be infinite. batch_settings (~google.cloud.pubsub_v1.types.BatchSettings): The settings for batch publishing. kwargs (dict): Any additional arguments provided are sent as keyword @@ -68,6 +85,11 @@ class Client(object): from google.cloud import pubsub_v1 publisher_client = pubsub_v1.PublisherClient( + # Optional + publisher_options = pubsub_v1.types.PublisherOptions( + enable_message_ordering=False + ), + # Optional batch_settings = pubsub_v1.types.BatchSettings( max_bytes=1024, # One kilobyte @@ -94,9 +116,7 @@ class Client(object): ) """ - _batch_class = thread.Batch - - def __init__(self, batch_settings=(), **kwargs): + def __init__(self, publisher_options=(), batch_settings=(), **kwargs): # Sanity check: Is our goal to use the emulator? # If so, create a grpc insecure channel with the emulator host # as the target. @@ -125,16 +145,40 @@ def __init__(self, batch_settings=(), **kwargs): transport = publisher_grpc_transport.PublisherGrpcTransport(channel=channel) kwargs["transport"] = transport + # For a transient failure, retry publishing the message infinitely. + self.publisher_options = types.PublisherOptions(*publisher_options) + self._enable_message_ordering = self.publisher_options[0] + if self._enable_message_ordering: + # Set retry timeout to "infinite" when message ordering is enabled. + # Note that this then also impacts messages added with an empty ordering + # key. + client_config = _set_nested_value( + kwargs.pop("client_config", {}), + 2 ** 32, + [ + "interfaces", + "google.pubsub.v1.Publisher", + "retry_params", + "messaging", + "total_timeout_millis", + ], + ) + kwargs["client_config"] = client_config + # Add the metrics headers, and instantiate the underlying GAPIC # client. self.api = publisher_client.PublisherClient(**kwargs) + self._batch_class = thread.Batch self.batch_settings = types.BatchSettings(*batch_settings) # The batches on the publisher client are responsible for holding # messages. One batch exists for each topic. self._batch_lock = self._batch_class.make_lock() - self._batches = {} + # (topic, ordering_key) => sequencers object + self._sequencers = {} self._is_stopped = False + # Thread created to commit all sequencers after a timeout. + self._commit_thread = None @classmethod def from_service_account_file(cls, filename, batch_settings=(), **kwargs): @@ -167,44 +211,60 @@ def target(self): """ return publisher_client.PublisherClient.SERVICE_ADDRESS - def _batch(self, topic, create=False, autocommit=True): - """Return the current batch for the provided topic. + def _get_or_create_sequencer(self, topic, ordering_key): + """ Get an existing sequencer or create a new one given the (topic, + ordering_key) pair. + """ + sequencer_key = (topic, ordering_key) + sequencer = self._sequencers.get(sequencer_key) + if sequencer is None: + if ordering_key == "": + sequencer = unordered_sequencer.UnorderedSequencer(self, topic) + else: + sequencer = ordered_sequencer.OrderedSequencer( + self, topic, ordering_key + ) + self._sequencers[sequencer_key] = sequencer + + return sequencer - This will create a new batch if ``create=True`` or if no batch - currently exists. + def resume_publish(self, topic, ordering_key): + """ Resume publish on an ordering key that has had unrecoverable errors. Args: - topic (str): A string representing the topic. - create (bool): Whether to create a new batch. Defaults to - :data:`False`. If :data:`True`, this will create a new batch - even if one already exists. - autocommit (bool): Whether to autocommit this batch. This is - primarily useful for debugging and testing, since it allows - the caller to avoid some side effects that batch creation - might have (e.g. spawning a worker to publish a batch). + topic (str): The topic to publish messages to. + ordering_key: A string that identifies related messages for which + publish order should be respected. - Returns: - ~.pubsub_v1._batch.Batch: The batch object. + Raises: + RuntimeError: + If called after publisher has been stopped by a `stop()` method + call. + ValueError: + If the topic/ordering key combination has not been seen before + by this client. """ - # If there is no matching batch yet, then potentially create one - # and place it on the batches dictionary. - if not create: - batch = self._batches.get(topic) - if batch is None: - create = True - - if create: - batch = self._batch_class( - autocommit=autocommit, - client=self, - settings=self.batch_settings, - topic=topic, - ) - self._batches[topic] = batch + with self._batch_lock: + if self._is_stopped: + raise RuntimeError("Cannot resume publish on a stopped publisher.") + + if not self._enable_message_ordering: + raise ValueError( + "Cannot resume publish on a topic/ordering key if ordering " + "is not enabled." + ) - return batch + sequencer_key = (topic, ordering_key) + sequencer = self._sequencers.get(sequencer_key) + if sequencer is None: + _LOGGER.debug( + "Error: The topic/ordering key combination has not " + "been seen before." + ) + else: + sequencer.unpause() - def publish(self, topic, data, **attrs): + def publish(self, topic, data, ordering_key="", **attrs): """Publish a single message. .. note:: @@ -234,6 +294,11 @@ def publish(self, topic, data, **attrs): topic (str): The topic to publish messages to. data (bytes): A bytestring representing the message body. This must be a bytestring. + ordering_key: A string that identifies related messages for which + publish order should be respected. Message ordering must be + enabled for this client to use this feature. + EXPERIMENTAL: This feature is currently available in a closed + alpha. Please contact the Cloud Pub/Sub team to use it. attrs (Mapping[str, str]): A dictionary of attributes to be sent as metadata. (These may be text strings or byte strings.) @@ -245,8 +310,11 @@ def publish(self, topic, data, **attrs): Raises: RuntimeError: - If called after publisher has been stopped - by a `stop()` method call. + If called after publisher has been stopped by a `stop()` method + call. + + pubsub_v1.publisher.exceptions.MessageTooLargeError: If publishing + the ``message`` would exceed the max size limit on the backend. """ # Sanity check: Is the data being sent as a bytestring? # If it is literally anything else, complain loudly about it. @@ -255,6 +323,12 @@ def publish(self, topic, data, **attrs): "Data being published to Pub/Sub must be sent as a bytestring." ) + if not self._enable_message_ordering and ordering_key != "": + raise ValueError( + "Cannot publish a message with an ordering key when message " + "ordering is not enabled." + ) + # Coerce all attributes to text strings. for k, v in copy.copy(attrs).items(): if isinstance(v, six.text_type): @@ -268,21 +342,74 @@ def publish(self, topic, data, **attrs): ) # Create the Pub/Sub message object. - message = types.PubsubMessage(data=data, attributes=attrs) + message = types.PubsubMessage( + data=data, ordering_key=ordering_key, attributes=attrs + ) - # Delegate the publishing to the batch. with self._batch_lock: if self._is_stopped: raise RuntimeError("Cannot publish on a stopped publisher.") - batch = self._batch(topic) - future = None - while future is None: - future = batch.publish(message) - if future is None: - batch = self._batch(topic, create=True) + sequencer = self._get_or_create_sequencer(topic, ordering_key) + + # Delegate the publishing to the sequencer. + future = sequencer.publish(message) + + # Create a timer thread if necessary to enforce the batching + # timeout. + self._ensure_commit_timer_runs_no_lock() + + return future + + def ensure_cleanup_and_commit_timer_runs(self): + """ Ensure a cleanup/commit timer thread is running. + + If a cleanup/commit timer thread is already running, this does nothing. + """ + with self._batch_lock: + self._ensure_commit_timer_runs_no_lock() + + def _ensure_commit_timer_runs_no_lock(self): + """ Ensure a commit timer thread is running, without taking + _batch_lock. + + _batch_lock must be held before calling this method. + """ + if not self._commit_thread and self.batch_settings.max_latency < float("inf"): + self._start_commit_thread() + + def _start_commit_thread(self): + """Start a new thread to actually wait and commit the sequencers.""" + self._commit_thread = threading.Thread( + name="Thread-PubSubBatchCommitter", target=self._wait_and_commit_sequencers + ) + self._commit_thread.start() + + def _wait_and_commit_sequencers(self): + """ Wait up to the batching timeout, and commit all sequencers. + """ + # Sleep for however long we should be waiting. + time.sleep(self.batch_settings.max_latency) + _LOGGER.debug("Commit thread is waking up") - return future + with self._batch_lock: + if self._is_stopped: + return + self._commit_sequencers() + self._commit_thread = None + + def _commit_sequencers(self): + """ Clean up finished sequencers and commit the rest. """ + finished_sequencer_keys = [ + key + for key, sequencer in self._sequencers.items() + if sequencer.is_finished() + ] + for sequencer_key in finished_sequencer_keys: + del self._sequencers[sequencer_key] + + for sequencer in self._sequencers.values(): + sequencer.commit() def stop(self): """Immediately publish all outstanding messages. @@ -297,6 +424,11 @@ def stop(self): This method is non-blocking. Use `Future()` objects returned by `publish()` to make sure all publish requests completed, either in success or error. + + Raises: + RuntimeError: + If called after publisher has been stopped by a `stop()` method + call. """ with self._batch_lock: if self._is_stopped: @@ -304,5 +436,19 @@ def stop(self): self._is_stopped = True - for batch in self._batches.values(): - batch.commit() + for sequencer in self._sequencers.values(): + sequencer.stop() + + # Used only for testing. + def _set_batch(self, topic, batch, ordering_key=""): + sequencer = self._get_or_create_sequencer(topic, ordering_key) + sequencer._set_batch(batch) + + # Used only for testing. + def _set_batch_class(self, batch_class): + self._batch_class = batch_class + + # Used only for testing. + def _set_sequencer(self, topic, sequencer, ordering_key=""): + sequencer_key = (topic, ordering_key) + self._sequencers[sequencer_key] = sequencer diff --git a/google/cloud/pubsub_v1/publisher/exceptions.py b/google/cloud/pubsub_v1/publisher/exceptions.py index be176bac2..856be955a 100644 --- a/google/cloud/pubsub_v1/publisher/exceptions.py +++ b/google/cloud/pubsub_v1/publisher/exceptions.py @@ -26,4 +26,21 @@ class MessageTooLargeError(ValueError): """Attempt to publish a message that would exceed the server max size limit.""" -__all__ = ("MessageTooLargeError", "PublishError", "TimeoutError") +class PublishToPausedOrderingKeyException(Exception): + """ Publish attempted to paused ordering key. To resume publishing, call + the resumePublish method on the publisher Client object with this + ordering key. Ordering keys are paused if an unrecoverable error + occurred during publish of a batch for that key. + """ + + def __init__(self, ordering_key): + self.ordering_key = ordering_key + super(PublishToPausedOrderingKeyException, self).__init__() + + +__all__ = ( + "MessageTooLargeError", + "PublishError", + "TimeoutError", + "PublishToPausedOrderingKeyException", +) diff --git a/google/cloud/pubsub_v1/subscriber/_protocol/dispatcher.py b/google/cloud/pubsub_v1/subscriber/_protocol/dispatcher.py index b1d8429cb..6a82ba046 100644 --- a/google/cloud/pubsub_v1/subscriber/_protocol/dispatcher.py +++ b/google/cloud/pubsub_v1/subscriber/_protocol/dispatcher.py @@ -155,6 +155,8 @@ def drop(self, items): items(Sequence[DropRequest]): The items to drop. """ self._manager.leaser.remove(items) + ordering_keys = (k.ordering_key for k in items if k.ordering_key) + self._manager.activate_ordering_keys(ordering_keys) self._manager.maybe_resume_consumer() def lease(self, items): diff --git a/google/cloud/pubsub_v1/subscriber/_protocol/leaser.py b/google/cloud/pubsub_v1/subscriber/_protocol/leaser.py index 8a683e4e7..b60379444 100644 --- a/google/cloud/pubsub_v1/subscriber/_protocol/leaser.py +++ b/google/cloud/pubsub_v1/subscriber/_protocol/leaser.py @@ -30,7 +30,9 @@ _LEASE_WORKER_NAME = "Thread-LeaseMaintainer" -_LeasedMessage = collections.namedtuple("_LeasedMessage", ["added_time", "size"]) +_LeasedMessage = collections.namedtuple( + "_LeasedMessage", ["sent_time", "size", "ordering_key"] +) class Leaser(object): @@ -45,6 +47,7 @@ def __init__(self, manager): # intertwined. Protects the _leased_messages and _bytes attributes. self._add_remove_lock = threading.Lock() + # Dict of ack_id -> _LeasedMessage self._leased_messages = {} """dict[str, float]: A mapping of ack IDs to the local time when the ack ID was initially leased in seconds since the epoch.""" @@ -76,12 +79,31 @@ def add(self, items): # the size counter. if item.ack_id not in self._leased_messages: self._leased_messages[item.ack_id] = _LeasedMessage( - added_time=time.time(), size=item.byte_size + sent_time=float("inf"), + size=item.byte_size, + ordering_key=item.ordering_key, ) self._bytes += item.byte_size else: _LOGGER.debug("Message %s is already lease managed", item.ack_id) + def start_lease_expiry_timer(self, ack_ids): + """Start the lease expiry timer for `items`. + + Args: + items (Sequence[str]): Sequence of ack-ids for which to start + lease expiry timers. + """ + with self._add_remove_lock: + for ack_id in ack_ids: + lease_info = self._leased_messages.get(ack_id) + # Lease info might not exist for this ack_id because it has already + # been removed by remove(). + if lease_info: + self._leased_messages[ack_id] = lease_info._replace( + sent_time=time.time() + ) + def remove(self, items): """Remove messages from lease management.""" with self._add_remove_lock: @@ -116,14 +138,14 @@ def maintain_leases(self): # we're iterating over it. leased_messages = copy.copy(self._leased_messages) - # Drop any leases that are well beyond max lease time. This - # ensures that in the event of a badly behaving actor, we can - # drop messages and allow Pub/Sub to resend them. + # Drop any leases that are beyond the max lease time. This ensures + # that in the event of a badly behaving actor, we can drop messages + # and allow the Pub/Sub server to resend them. cutoff = time.time() - self._manager.flow_control.max_lease_duration to_drop = [ - requests.DropRequest(ack_id, item.size) + requests.DropRequest(ack_id, item.size, item.ordering_key) for ack_id, item in six.iteritems(leased_messages) - if item.added_time < cutoff + if item.sent_time < cutoff ] if to_drop: diff --git a/google/cloud/pubsub_v1/subscriber/_protocol/requests.py b/google/cloud/pubsub_v1/subscriber/_protocol/requests.py index ac1df0af8..58d53a61d 100644 --- a/google/cloud/pubsub_v1/subscriber/_protocol/requests.py +++ b/google/cloud/pubsub_v1/subscriber/_protocol/requests.py @@ -21,13 +21,19 @@ # Namedtuples for management requests. Used by the Message class to communicate # items of work back to the policy. AckRequest = collections.namedtuple( - "AckRequest", ["ack_id", "byte_size", "time_to_ack"] + "AckRequest", ["ack_id", "byte_size", "time_to_ack", "ordering_key"] ) -DropRequest = collections.namedtuple("DropRequest", ["ack_id", "byte_size"]) +DropRequest = collections.namedtuple( + "DropRequest", ["ack_id", "byte_size", "ordering_key"] +) -LeaseRequest = collections.namedtuple("LeaseRequest", ["ack_id", "byte_size"]) +LeaseRequest = collections.namedtuple( + "LeaseRequest", ["ack_id", "byte_size", "ordering_key"] +) ModAckRequest = collections.namedtuple("ModAckRequest", ["ack_id", "seconds"]) -NackRequest = collections.namedtuple("NackRequest", ["ack_id", "byte_size"]) +NackRequest = collections.namedtuple( + "NackRequest", ["ack_id", "byte_size", "ordering_key"] +) diff --git a/google/cloud/pubsub_v1/subscriber/_protocol/streaming_pull_manager.py b/google/cloud/pubsub_v1/subscriber/_protocol/streaming_pull_manager.py index 26764b1a9..0a3d9141f 100644 --- a/google/cloud/pubsub_v1/subscriber/_protocol/streaming_pull_manager.py +++ b/google/cloud/pubsub_v1/subscriber/_protocol/streaming_pull_manager.py @@ -21,7 +21,6 @@ import grpc import six -from six.moves import queue from google.api_core import bidi from google.api_core import exceptions @@ -30,6 +29,7 @@ from google.cloud.pubsub_v1.subscriber._protocol import heartbeater from google.cloud.pubsub_v1.subscriber._protocol import histogram from google.cloud.pubsub_v1.subscriber._protocol import leaser +from google.cloud.pubsub_v1.subscriber._protocol import messages_on_hold from google.cloud.pubsub_v1.subscriber._protocol import requests import google.cloud.pubsub_v1.subscriber.message import google.cloud.pubsub_v1.subscriber.scheduler @@ -123,12 +123,11 @@ def __init__( else: self._scheduler = scheduler - # A FIFO queue for the messages that have been received from the server, - # but not yet added to the lease management (and not sent to user callback), - # because the FlowControl limits have been hit. - self._messages_on_hold = queue.Queue() + # A collection for the messages that have been received from the server, + # but not yet sent to the user callback. + self._messages_on_hold = messages_on_hold.MessagesOnHold() - # the total number of bytes consumed by the messages currently on hold + # The total number of bytes consumed by the messages currently on hold self._on_hold_bytes = 0 # A lock ensuring that pausing / resuming the consumer are both atomic @@ -225,7 +224,7 @@ def load(self): # be subtracted from the leaser's values. return max( [ - (self._leaser.message_count - self._messages_on_hold.qsize()) + (self._leaser.message_count - self._messages_on_hold.size) / self._flow_control.max_messages, (self._leaser.bytes - self._on_hold_bytes) / self._flow_control.max_bytes, @@ -240,6 +239,25 @@ def add_close_callback(self, callback): """ self._close_callbacks.append(callback) + def activate_ordering_keys(self, ordering_keys): + """Send the next message in the queue for each of the passed-in + ordering keys, if they exist. Clean up state for keys that no longer + have any queued messages. + + Since the load went down by one message, it's probably safe to send the + user another message for the same key. Since the released message may be + bigger than the previous one, this may increase the load above the maximum. + This decision is by design because it simplifies MessagesOnHold. + + Args: + ordering_keys(Sequence[str]): A sequence of ordering keys to + activate. May be empty. + """ + with self._pause_resume_lock: + self._messages_on_hold.activate_ordering_keys( + ordering_keys, self._schedule_message_on_hold + ) + def maybe_pause_consumer(self): """Check the current load and pause the consumer if needed.""" with self._pause_resume_lock: @@ -290,30 +308,44 @@ def _maybe_release_messages(self): The method assumes the caller has acquired the ``_pause_resume_lock``. """ - while True: - if self.load >= _MAX_LOAD: - break # already overloaded - - try: - msg = self._messages_on_hold.get_nowait() - except queue.Empty: + released_ack_ids = [] + while self.load < _MAX_LOAD: + msg = self._messages_on_hold.get() + if not msg: break - self._on_hold_bytes -= msg.size + self._schedule_message_on_hold(msg) + released_ack_ids.append(msg.ack_id) + self._leaser.start_lease_expiry_timer(released_ack_ids) - if self._on_hold_bytes < 0: - _LOGGER.warning( - "On hold bytes was unexpectedly negative: %s", self._on_hold_bytes - ) - self._on_hold_bytes = 0 + def _schedule_message_on_hold(self, msg): + """Schedule a message on hold to be sent to the user and change + on-hold-bytes. - _LOGGER.debug( - "Released held message, scheduling callback for it, " - "still on hold %s (bytes %s).", - self._messages_on_hold.qsize(), - self._on_hold_bytes, + The method assumes the caller has acquired the ``_pause_resume_lock``. + + Args: + msg (google.cloud.pubsub_v1.message.Message): The message to + schedule to be sent to the user. + """ + assert msg, "Message must not be None." + + # On-hold bytes goes down, increasing load. + self._on_hold_bytes -= msg.size + + if self._on_hold_bytes < 0: + _LOGGER.warning( + "On hold bytes was unexpectedly negative: %s", self._on_hold_bytes ) - self._scheduler.schedule(self._callback, msg) + self._on_hold_bytes = 0 + + _LOGGER.debug( + "Released held message, scheduling callback for it, " + "still on hold %s (bytes %s).", + self._messages_on_hold.size, + self._on_hold_bytes, + ) + self._scheduler.schedule(self._callback, msg) def _send_unary_request(self, request): """Send a request using a separate unary request instead of over the @@ -552,7 +584,7 @@ def _on_response(self, response): _LOGGER.debug( "Processing %s received message(s), currenty on hold %s (bytes %s).", len(response.received_messages), - self._messages_on_hold.qsize(), + self._messages_on_hold.size, self._on_hold_bytes, ) @@ -565,37 +597,26 @@ def _on_response(self, response): ] self._dispatcher.modify_ack_deadline(items) - invoke_callbacks_for = [] + with self._pause_resume_lock: + for received_message in response.received_messages: + message = google.cloud.pubsub_v1.subscriber.message.Message( + received_message.message, + received_message.ack_id, + received_message.delivery_attempt, + self._scheduler.queue, + ) + self._messages_on_hold.put(message) + self._on_hold_bytes += message.size + req = requests.LeaseRequest( + ack_id=message.ack_id, + byte_size=message.size, + ordering_key=message.ordering_key, + ) + self.leaser.add([req]) - for received_message in response.received_messages: - message = google.cloud.pubsub_v1.subscriber.message.Message( - received_message.message, - received_message.ack_id, - received_message.delivery_attempt, - self._scheduler.queue, - ) - # Making a decision based on the load, and modifying the data that - # affects the load -> needs a lock, as that state can be modified - # by different threads. - with self._pause_resume_lock: - if self.load < _MAX_LOAD: - invoke_callbacks_for.append(message) - else: - self._messages_on_hold.put(message) - self._on_hold_bytes += message.size - - req = requests.LeaseRequest(ack_id=message.ack_id, byte_size=message.size) - self.leaser.add([req]) - self.maybe_pause_consumer() + self._maybe_release_messages() - _LOGGER.debug( - "Scheduling callbacks for %s new messages, new total on hold %s (bytes %s).", - len(invoke_callbacks_for), - self._messages_on_hold.qsize(), - self._on_hold_bytes, - ) - for msg in invoke_callbacks_for: - self._scheduler.schedule(self._callback, msg) + self.maybe_pause_consumer() def _should_recover(self, exception): """Determine if an error on the RPC stream should be recovered. diff --git a/google/cloud/pubsub_v1/subscriber/message.py b/google/cloud/pubsub_v1/subscriber/message.py index 6dc7bc443..cafc34b80 100644 --- a/google/cloud/pubsub_v1/subscriber/message.py +++ b/google/cloud/pubsub_v1/subscriber/message.py @@ -26,6 +26,7 @@ _MESSAGE_REPR = """\ Message {{ data: {!r} + ordering_key: {!r} attributes: {} }}""" @@ -112,7 +113,7 @@ def __repr__(self): pretty_attrs = _indent(pretty_attrs) # We don't actually want the first line indented. pretty_attrs = pretty_attrs.lstrip() - return _MESSAGE_REPR.format(abbv_data, pretty_attrs) + return _MESSAGE_REPR.format(abbv_data, str(self.ordering_key), pretty_attrs) @property def attributes(self): @@ -156,6 +157,11 @@ def publish_time(self): ) return datetime_helpers._UTC_EPOCH + delta + @property + def ordering_key(self): + """str: the ordering key used to publish the message.""" + return self._message.ordering_key + @property def size(self): """Return the size of the underlying message, in bytes.""" @@ -207,7 +213,10 @@ def ack(self): time_to_ack = math.ceil(time.time() - self._received_timestamp) self._request_queue.put( requests.AckRequest( - ack_id=self._ack_id, byte_size=self.size, time_to_ack=time_to_ack + ack_id=self._ack_id, + byte_size=self.size, + time_to_ack=time_to_ack, + ordering_key=self.ordering_key, ) ) @@ -220,12 +229,14 @@ def drop(self): .. warning:: For most use cases, the only reason to drop a message from - lease management is on :meth:`ack` or :meth:`nack`; these methods - both call this one. You probably do not want to call this method - directly. + lease management is on `ack` or `nack`; this library + automatically drop()s the message on `ack` or `nack`. You probably + do not want to call this method directly. """ self._request_queue.put( - requests.DropRequest(ack_id=self._ack_id, byte_size=self.size) + requests.DropRequest( + ack_id=self._ack_id, byte_size=self.size, ordering_key=self.ordering_key + ) ) def modify_ack_deadline(self, seconds): @@ -253,5 +264,7 @@ def nack(self): This will cause the message to be re-delivered to the subscription. """ self._request_queue.put( - requests.NackRequest(ack_id=self._ack_id, byte_size=self.size) + requests.NackRequest( + ack_id=self._ack_id, byte_size=self.size, ordering_key=self.ordering_key + ) ) diff --git a/google/cloud/pubsub_v1/types.py b/google/cloud/pubsub_v1/types.py index 2d238b42f..28019f478 100644 --- a/google/cloud/pubsub_v1/types.py +++ b/google/cloud/pubsub_v1/types.py @@ -30,6 +30,25 @@ from google.cloud.pubsub_v1.proto import pubsub_pb2 +# Define the default publisher options. +# +# This class is used when creating a publisher client to pass in options +# to enable/disable features. +PublisherOptions = collections.namedtuple( + "PublisherConfig", ["enable_message_ordering"] +) +PublisherOptions.__new__.__defaults__ = (False,) # enable_message_ordering: False + +if sys.version_info >= (3, 5): + PublisherOptions.__doc__ = "The options for the publisher client." + PublisherOptions.enable_message_ordering.__doc__ = ( + "Whether to order messages in a batch by a supplied ordering key." + "EXPERIMENTAL: Message ordering is an alpha feature that requires " + "special permissions to use. Please contact the Cloud Pub/Sub team for " + "more information." + ) + + # Define the default values for batching. # # This class is used when creating a publisher or subscriber client, and diff --git a/tests/unit/pubsub_v1/publisher/batch/test_base.py b/tests/unit/pubsub_v1/publisher/batch/test_base.py index b19a5a1f1..96f18451d 100644 --- a/tests/unit/pubsub_v1/publisher/batch/test_base.py +++ b/tests/unit/pubsub_v1/publisher/batch/test_base.py @@ -35,7 +35,7 @@ def create_batch(status=None, settings=types.BatchSettings()): """ creds = mock.Mock(spec=credentials.Credentials) client = publisher.Client(credentials=creds) - batch = Batch(client, "topic_name", settings, autocommit=False) + batch = Batch(client, "topic_name", settings) if status: batch._status = status return batch diff --git a/tests/unit/pubsub_v1/publisher/batch/test_thread.py b/tests/unit/pubsub_v1/publisher/batch/test_thread.py index f51b314af..ce288a48e 100644 --- a/tests/unit/pubsub_v1/publisher/batch/test_thread.py +++ b/tests/unit/pubsub_v1/publisher/batch/test_thread.py @@ -25,6 +25,7 @@ from google.cloud.pubsub_v1 import types from google.cloud.pubsub_v1.publisher import exceptions from google.cloud.pubsub_v1.publisher._batch.base import BatchStatus +from google.cloud.pubsub_v1.publisher._batch.base import BatchCancellationReason from google.cloud.pubsub_v1.publisher._batch import thread from google.cloud.pubsub_v1.publisher._batch.thread import Batch @@ -34,16 +35,21 @@ def create_client(): return publisher.Client(credentials=creds) -def create_batch(autocommit=False, topic="topic_name", **batch_settings): +def create_batch( + topic="topic_name", + batch_done_callback=None, + commit_when_full=True, + **batch_settings +): """Return a batch object suitable for testing. Args: - autocommit (bool): Whether the batch should commit after - ``max_latency`` seconds. By default, this is ``False`` - for unit testing. - topic (str): The name of the topic the batch should publish - the messages to. - batch_settings (dict): Arguments passed on to the + topic (str): Topic name. + batch_done_callback (Callable[bool]): A callable that is called when + the batch is done, either with a success or a failure flag. + commit_when_full (bool): Whether to commit the batch when the batch + has reached byte-size or number-of-messages limits. + batch_settings (Mapping[str, str]): Arguments passed on to the :class:``~.pubsub_v1.types.BatchSettings`` constructor. Returns: @@ -51,29 +57,13 @@ def create_batch(autocommit=False, topic="topic_name", **batch_settings): """ client = create_client() settings = types.BatchSettings(**batch_settings) - return Batch(client, topic, settings, autocommit=autocommit) - - -def test_init(): - """Establish that a monitor thread is usually created on init.""" - client = create_client() - - # Do not actually create a thread, but do verify that one was created; - # it should be running the batch's "monitor" method (which commits the - # batch once time elapses). - with mock.patch.object(threading, "Thread", autospec=True) as Thread: - batch = Batch(client, "topic_name", types.BatchSettings()) - Thread.assert_called_once_with( - name="Thread-MonitorBatchPublisher", target=batch.monitor - ) - - # New batches start able to accept messages by default. - assert batch.status == BatchStatus.ACCEPTING_MESSAGES - - -def test_init_infinite_latency(): - batch = create_batch(max_latency=float("inf")) - assert batch._thread is None + return Batch( + client, + topic, + settings, + batch_done_callback=batch_done_callback, + commit_when_full=commit_when_full, + ) @mock.patch.object(threading, "Lock") @@ -86,20 +76,18 @@ def test_make_lock(Lock): def test_client(): client = create_client() settings = types.BatchSettings() - batch = Batch(client, "topic_name", settings, autocommit=False) + batch = Batch(client, "topic_name", settings) assert batch.client is client def test_commit(): batch = create_batch() - with mock.patch.object(threading, "Thread", autospec=True) as Thread: - batch.commit() - # A thread should have been created to do the actual commit. - Thread.assert_called_once_with( - name="Thread-CommitBatchPublisher", target=batch._commit - ) - Thread.return_value.start.assert_called_once_with() + with mock.patch.object( + Batch, "_start_commit_thread", autospec=True + ) as _start_commit_thread: + batch.commit() + _start_commit_thread.assert_called_once() # The batch's status needs to be something other than "accepting messages", # since the commit started. @@ -202,7 +190,7 @@ def test_blocking__commit_already_started(_LOGGER): assert batch._status == BatchStatus.IN_PROGRESS _LOGGER.debug.assert_called_once_with( - "Batch is already in progress, exiting commit" + "Batch is already in progress or has been cancelled, exiting commit" ) @@ -273,34 +261,6 @@ def test_block__commmit_retry_error(): assert future.exception() == error -def test_monitor(): - batch = create_batch(max_latency=5.0) - with mock.patch.object(time, "sleep") as sleep: - with mock.patch.object(type(batch), "_commit") as _commit: - batch.monitor() - - # The monitor should have waited the given latency. - sleep.assert_called_once_with(5.0) - - # Since `monitor` runs in its own thread, it should call - # the blocking commit implementation. - _commit.assert_called_once_with() - - -def test_monitor_already_committed(): - batch = create_batch(max_latency=5.0) - status = "something else" - batch._status = status - with mock.patch.object(time, "sleep") as sleep: - batch.monitor() - - # The monitor should have waited the given latency. - sleep.assert_called_once_with(5.0) - - # The status should not have changed. - assert batch._status == status - - def test_publish_updating_batch_size(): batch = create_batch(topic="topic_foo") messages = ( @@ -419,3 +379,116 @@ def test_publish_dict(): ) assert batch.messages == [expected_message] assert batch._futures == [future] + + +def test_cancel(): + batch = create_batch() + futures = ( + batch.publish({"data": b"This is my message."}), + batch.publish({"data": b"This is another message."}), + ) + + batch.cancel(BatchCancellationReason.PRIOR_ORDERED_MESSAGE_FAILED) + + # Assert all futures are cancelled with an error. + for future in futures: + exc = future.exception() + assert type(exc) is RuntimeError + assert exc.args[0] == BatchCancellationReason.PRIOR_ORDERED_MESSAGE_FAILED.value + + +def test_do_not_commit_when_full_when_flag_is_off(): + max_messages = 4 + # Set commit_when_full flag to False + batch = create_batch(max_messages=max_messages, commit_when_full=False) + messages = ( + types.PubsubMessage(data=b"foobarbaz"), + types.PubsubMessage(data=b"spameggs"), + types.PubsubMessage(data=b"1335020400"), + ) + + with mock.patch.object(batch, "commit") as commit: + # Publish 3 messages. + futures = [batch.publish(message) for message in messages] + assert len(futures) == 3 + + # When a fourth message is published, commit should not be called. + future = batch.publish(types.PubsubMessage(data=b"last one")) + assert commit.call_count == 0 + assert future is None + + +class BatchDoneCallbackTracker(object): + def __init__(self): + self.called = False + self.success = None + + def __call__(self, success): + self.called = True + self.success = success + + +def test_batch_done_callback_called_on_success(): + batch_done_callback_tracker = BatchDoneCallbackTracker() + batch = create_batch(batch_done_callback=batch_done_callback_tracker) + + # Ensure messages exist. + message = types.PubsubMessage(data=b"foobarbaz") + batch.publish(message) + + # One response for one published message. + publish_response = types.PublishResponse(message_ids=["a"]) + + with mock.patch.object( + type(batch.client.api), "publish", return_value=publish_response + ): + batch._commit() + + assert batch_done_callback_tracker.called + assert batch_done_callback_tracker.success + + +def test_batch_done_callback_called_on_publish_failure(): + batch_done_callback_tracker = BatchDoneCallbackTracker() + batch = create_batch(batch_done_callback=batch_done_callback_tracker) + + # Ensure messages exist. + message = types.PubsubMessage(data=b"foobarbaz") + batch.publish(message) + + # One response for one published message. + publish_response = types.PublishResponse(message_ids=["a"]) + + # Induce publish error. + error = google.api_core.exceptions.InternalServerError("uh oh") + + with mock.patch.object( + type(batch.client.api), + "publish", + return_value=publish_response, + side_effect=error, + ): + batch._commit() + + assert batch_done_callback_tracker.called + assert not batch_done_callback_tracker.success + + +def test_batch_done_callback_called_on_publish_response_invalid(): + batch_done_callback_tracker = BatchDoneCallbackTracker() + batch = create_batch(batch_done_callback=batch_done_callback_tracker) + + # Ensure messages exist. + message = types.PubsubMessage(data=b"foobarbaz") + batch.publish(message) + + # No message ids returned in successful publish response -> invalid. + publish_response = types.PublishResponse(message_ids=[]) + + with mock.patch.object( + type(batch.client.api), "publish", return_value=publish_response + ): + batch._commit() + + assert batch_done_callback_tracker.called + assert not batch_done_callback_tracker.success diff --git a/tests/unit/pubsub_v1/publisher/test_publisher_client.py b/tests/unit/pubsub_v1/publisher/test_publisher_client.py index a06d2d0cf..4a5d4058f 100644 --- a/tests/unit/pubsub_v1/publisher/test_publisher_client.py +++ b/tests/unit/pubsub_v1/publisher/test_publisher_client.py @@ -13,16 +13,20 @@ # limitations under the License. from __future__ import absolute_import +from __future__ import division from google.auth import credentials import mock import pytest +import time from google.cloud.pubsub_v1.gapic import publisher_client from google.cloud.pubsub_v1 import publisher from google.cloud.pubsub_v1 import types +from google.cloud.pubsub_v1.publisher._sequencer import ordered_sequencer + def test_init(): creds = mock.Mock(spec=credentials.Credentials) @@ -63,40 +67,29 @@ def test_init_emulator(monkeypatch): assert channel.target().decode("utf8") == "/foo/bar/" -def test_batch_create(): +def test_message_ordering_enabled(): creds = mock.Mock(spec=credentials.Credentials) client = publisher.Client(credentials=creds) + assert not client._enable_message_ordering - assert len(client._batches) == 0 - topic = "topic/path" - batch = client._batch(topic, autocommit=False) - assert client._batches == {topic: batch} + client = publisher.Client( + publisher_options=types.PublisherOptions(enable_message_ordering=True), + credentials=creds, + ) + assert client._enable_message_ordering -def test_batch_exists(): +def test_message_ordering_changes_retry_deadline(): creds = mock.Mock(spec=credentials.Credentials) - client = publisher.Client(credentials=creds) - - topic = "topic/path" - client._batches[topic] = mock.sentinel.batch - # A subsequent request should return the same batch. - batch = client._batch(topic, autocommit=False) - assert batch is mock.sentinel.batch - assert client._batches == {topic: batch} - - -def test_batch_create_and_exists(): - creds = mock.Mock(spec=credentials.Credentials) client = publisher.Client(credentials=creds) + assert client.api._method_configs["Publish"].retry._deadline == 60 - topic = "topic/path" - client._batches[topic] = mock.sentinel.batch - - # A subsequent request should return the same batch. - batch = client._batch(topic, create=True, autocommit=False) - assert batch is not mock.sentinel.batch - assert client._batches == {topic: batch} + client = publisher.Client( + publisher_options=types.PublisherOptions(enable_message_ordering=True), + credentials=creds, + ) + assert client.api._method_configs["Publish"].retry._deadline == 2 ** 32 / 1000 def test_publish(): @@ -110,7 +103,7 @@ def test_publish(): batch.publish.side_effect = (mock.sentinel.future1, mock.sentinel.future2) topic = "topic/path" - client._batches[topic] = batch + client._set_batch(topic, batch) # Begin publishing. future1 = client.publish(topic, b"spam") @@ -138,6 +131,24 @@ def test_publish_data_not_bytestring_error(): client.publish(topic, 42) +def test_publish_message_ordering_not_enabled_error(): + creds = mock.Mock(spec=credentials.Credentials) + client = publisher.Client(credentials=creds) + topic = "topic/path" + with pytest.raises(ValueError): + client.publish(topic, b"bytestring body", ordering_key="ABC") + + +def test_publish_empty_ordering_key_when_message_ordering_enabled(): + creds = mock.Mock(spec=credentials.Credentials) + client = publisher.Client( + publisher_options=types.PublisherOptions(enable_message_ordering=True), + credentials=creds, + ) + topic = "topic/path" + assert client.publish(topic, b"bytestring body", ordering_key="") is not None + + def test_publish_attrs_bytestring(): creds = mock.Mock(spec=credentials.Credentials) client = publisher.Client(credentials=creds) @@ -148,7 +159,7 @@ def test_publish_attrs_bytestring(): batch.will_accept.return_value = True topic = "topic/path" - client._batches[topic] = batch + client._set_batch(topic, batch) # Begin publishing. future = client.publish(topic, b"foo", bar=b"baz") @@ -174,11 +185,11 @@ def test_publish_new_batch_needed(): batch2.publish.return_value = mock.sentinel.future topic = "topic/path" - client._batches[topic] = batch1 + client._set_batch(topic, batch1) # Actually mock the batch class now. batch_class = mock.Mock(spec=(), return_value=batch2) - client._batch_class = batch_class + client._set_batch_class(batch_class) # Publish a message. future = client.publish(topic, b"foo", bar=b"baz") @@ -186,7 +197,11 @@ def test_publish_new_batch_needed(): # Check the mocks. batch_class.assert_called_once_with( - autocommit=True, client=client, settings=client.batch_settings, topic=topic + client=mock.ANY, + topic=topic, + settings=client.batch_settings, + batch_done_callback=None, + commit_when_full=True, ) message_pb = types.PubsubMessage(data=b"foo", attributes={"bar": u"baz"}) batch1.publish.assert_called_once_with(message_pb) @@ -205,28 +220,20 @@ def test_stop(): creds = mock.Mock(spec=credentials.Credentials) client = publisher.Client(credentials=creds) - batch = client._batch("topic1", autocommit=False) - batch2 = client._batch("topic2", autocommit=False) - - pubsub_msg = types.PubsubMessage(data=b"msg") - - patch = mock.patch.object(batch, "commit") - patch2 = mock.patch.object(batch2, "commit") + batch1 = mock.Mock(spec=client._batch_class) + topic = "topic/path" + client._set_batch(topic, batch1) - with patch as commit_mock, patch2 as commit_mock2: - batch.publish(pubsub_msg) - batch2.publish(pubsub_msg) + client.stop() - client.stop() + assert batch1.commit.call_count == 1 - # check if commit() called - commit_mock.assert_called() - commit_mock2.assert_called() - - # check that closed publisher doesn't accept new messages with pytest.raises(RuntimeError): client.publish("topic1", b"msg2") + with pytest.raises(RuntimeError): + client.resume_publish("topic", "ord_key") + with pytest.raises(RuntimeError): client.stop() @@ -265,3 +272,181 @@ def test_gapic_class_method_on_instance(): client = publisher.Client(credentials=creds) answer = client.topic_path("foo", "bar") assert answer == "projects/foo/topics/bar" + + +def test_commit_thread_created_on_publish(): + creds = mock.Mock(spec=credentials.Credentials) + # Max latency is not infinite so a commit thread is created. + batch_settings = types.BatchSettings(max_latency=600) + client = publisher.Client(batch_settings=batch_settings, credentials=creds) + + with mock.patch.object( + client, "_start_commit_thread", autospec=True + ) as _start_commit_thread: + # First publish should create a commit thread. + assert client.publish("topic", b"bytestring body", ordering_key="") is not None + _start_commit_thread.assert_called_once() + + # Since _start_commit_thread is a mock, no actual thread has been + # created, so let's put a sentinel there to mimic real behavior. + client._commit_thread = mock.Mock() + + # Second publish should not create a commit thread since one (the mock) + # already exists. + assert client.publish("topic", b"bytestring body", ordering_key="") is not None + # Call count should remain 1. + _start_commit_thread.assert_called_once() + + +def test_commit_thread_not_created_on_publish_if_max_latency_is_inf(): + creds = mock.Mock(spec=credentials.Credentials) + # Max latency is infinite so a commit thread is not created. + batch_settings = types.BatchSettings(max_latency=float("inf")) + client = publisher.Client(batch_settings=batch_settings, credentials=creds) + + assert client.publish("topic", b"bytestring body", ordering_key="") is not None + assert client._commit_thread is None + + +def test_wait_and_commit_sequencers(): + creds = mock.Mock(spec=credentials.Credentials) + # Max latency is infinite so a commit thread is not created. + # We don't want a commit thread to interfere with this test. + batch_settings = types.BatchSettings(max_latency=float("inf")) + client = publisher.Client(batch_settings=batch_settings, credentials=creds) + + # Mock out time so no sleep is actually done. + with mock.patch.object(time, "sleep"): + with mock.patch.object( + publisher.Client, "_commit_sequencers" + ) as _commit_sequencers: + assert ( + client.publish("topic", b"bytestring body", ordering_key="") is not None + ) + # Call _wait_and_commit_sequencers to simulate what would happen if a + # commit thread actually ran. + client._wait_and_commit_sequencers() + assert _commit_sequencers.call_count == 1 + + +def test_stopped_client_does_not_commit_sequencers(): + creds = mock.Mock(spec=credentials.Credentials) + # Max latency is infinite so a commit thread is not created. + # We don't want a commit thread to interfere with this test. + batch_settings = types.BatchSettings(max_latency=float("inf")) + client = publisher.Client(batch_settings=batch_settings, credentials=creds) + + # Mock out time so no sleep is actually done. + with mock.patch.object(time, "sleep"): + with mock.patch.object( + publisher.Client, "_commit_sequencers" + ) as _commit_sequencers: + assert ( + client.publish("topic", b"bytestring body", ordering_key="") is not None + ) + + client.stop() + + # Call _wait_and_commit_sequencers to simulate what would happen if a + # commit thread actually ran after the client was stopped. + client._wait_and_commit_sequencers() + # Should not be called since Client is stopped. + assert _commit_sequencers.call_count == 0 + + +def test_publish_with_ordering_key(): + creds = mock.Mock(spec=credentials.Credentials) + publisher_options = types.PublisherOptions(enable_message_ordering=True) + client = publisher.Client(publisher_options, credentials=creds) + + # Use a mock in lieu of the actual batch class. + batch = mock.Mock(spec=client._batch_class) + # Set the mock up to claim indiscriminately that it accepts all messages. + batch.will_accept.return_value = True + batch.publish.side_effect = (mock.sentinel.future1, mock.sentinel.future2) + + topic = "topic/path" + ordering_key = "k1" + client._set_batch(topic, batch, ordering_key=ordering_key) + + # Begin publishing. + future1 = client.publish(topic, b"spam", ordering_key=ordering_key) + future2 = client.publish(topic, b"foo", bar="baz", ordering_key=ordering_key) + + assert future1 is mock.sentinel.future1 + assert future2 is mock.sentinel.future2 + + # Check mock. + batch.publish.assert_has_calls( + [ + mock.call(types.PubsubMessage(data=b"spam", ordering_key="k1")), + mock.call( + types.PubsubMessage( + data=b"foo", attributes={"bar": "baz"}, ordering_key="k1" + ) + ), + ] + ) + + +def test_ordered_sequencer_cleaned_up(): + creds = mock.Mock(spec=credentials.Credentials) + # Max latency is infinite so a commit thread is not created. + # We don't want a commit thread to interfere with this test. + batch_settings = types.BatchSettings(max_latency=float("inf")) + publisher_options = types.PublisherOptions(enable_message_ordering=True) + client = publisher.Client( + publisher_options=publisher_options, + batch_settings=batch_settings, + credentials=creds, + ) + + topic = "topic" + ordering_key = "ord_key" + sequencer = mock.Mock(spec=ordered_sequencer.OrderedSequencer) + sequencer.is_finished.return_value = False + client._set_sequencer(topic=topic, sequencer=sequencer, ordering_key=ordering_key) + + assert len(client._sequencers) == 1 + # 'sequencer' is not finished yet so don't remove it. + client._commit_sequencers() + assert len(client._sequencers) == 1 + + sequencer.is_finished.return_value = True + # 'sequencer' is finished so remove it. + client._commit_sequencers() + assert len(client._sequencers) == 0 + + +def test_resume_publish(): + creds = mock.Mock(spec=credentials.Credentials) + publisher_options = types.PublisherOptions(enable_message_ordering=True) + client = publisher.Client(publisher_options, credentials=creds) + + topic = "topic" + ordering_key = "ord_key" + sequencer = mock.Mock(spec=ordered_sequencer.OrderedSequencer) + client._set_sequencer(topic=topic, sequencer=sequencer, ordering_key=ordering_key) + + client.resume_publish(topic, ordering_key) + assert sequencer.unpause.called_once() + + +def test_resume_publish_no_sequencer_found(): + creds = mock.Mock(spec=credentials.Credentials) + publisher_options = types.PublisherOptions(enable_message_ordering=True) + client = publisher.Client(publisher_options, credentials=creds) + + # Check no exception is thrown if a sequencer with the (topic, ordering_key) + # pair does not exist. + client.resume_publish("topic", "ord_key") + + +def test_resume_publish_ordering_keys_not_enabled(): + creds = mock.Mock(spec=credentials.Credentials) + publisher_options = types.PublisherOptions(enable_message_ordering=False) + client = publisher.Client(publisher_options, credentials=creds) + + # Throw on calling resume_publish() when enable_message_ordering is False. + with pytest.raises(ValueError): + client.resume_publish("topic", "ord_key") diff --git a/tests/unit/pubsub_v1/subscriber/test_dispatcher.py b/tests/unit/pubsub_v1/subscriber/test_dispatcher.py index 592a03c64..43822e96e 100644 --- a/tests/unit/pubsub_v1/subscriber/test_dispatcher.py +++ b/tests/unit/pubsub_v1/subscriber/test_dispatcher.py @@ -29,11 +29,11 @@ @pytest.mark.parametrize( "item,method_name", [ - (requests.AckRequest(0, 0, 0), "ack"), - (requests.DropRequest(0, 0), "drop"), - (requests.LeaseRequest(0, 0), "lease"), + (requests.AckRequest(0, 0, 0, ""), "ack"), + (requests.DropRequest(0, 0, ""), "drop"), + (requests.LeaseRequest(0, 0, ""), "lease"), (requests.ModAckRequest(0, 0), "modify_ack_deadline"), - (requests.NackRequest(0, 0), "nack"), + (requests.NackRequest(0, 0, ""), "nack"), ], ) def test_dispatch_callback(item, method_name): @@ -57,7 +57,7 @@ def test_dispatch_callback_inactive(): manager.is_active = False dispatcher_ = dispatcher.Dispatcher(manager, mock.sentinel.queue) - dispatcher_.dispatch_callback([requests.AckRequest(0, 0, 0)]) + dispatcher_.dispatch_callback([requests.AckRequest(0, 0, 0, "")]) manager.send.assert_not_called() @@ -68,7 +68,11 @@ def test_ack(): ) dispatcher_ = dispatcher.Dispatcher(manager, mock.sentinel.queue) - items = [requests.AckRequest(ack_id="ack_id_string", byte_size=0, time_to_ack=20)] + items = [ + requests.AckRequest( + ack_id="ack_id_string", byte_size=0, time_to_ack=20, ordering_key="" + ) + ] dispatcher_.ack(items) manager.send.assert_called_once_with( @@ -86,7 +90,11 @@ def test_ack_no_time(): ) dispatcher_ = dispatcher.Dispatcher(manager, mock.sentinel.queue) - items = [requests.AckRequest(ack_id="ack_id_string", byte_size=0, time_to_ack=None)] + items = [ + requests.AckRequest( + ack_id="ack_id_string", byte_size=0, time_to_ack=None, ordering_key="" + ) + ] dispatcher_.ack(items) manager.send.assert_called_once_with( @@ -104,7 +112,9 @@ def test_ack_splitting_large_payload(): items = [ # use realistic lengths for ACK IDs (max 176 bytes) - requests.AckRequest(ack_id=str(i).zfill(176), byte_size=0, time_to_ack=20) + requests.AckRequest( + ack_id=str(i).zfill(176), byte_size=0, time_to_ack=20, ordering_key="" + ) for i in range(5001) ] dispatcher_.ack(items) @@ -130,23 +140,46 @@ def test_lease(): ) dispatcher_ = dispatcher.Dispatcher(manager, mock.sentinel.queue) - items = [requests.LeaseRequest(ack_id="ack_id_string", byte_size=10)] + items = [ + requests.LeaseRequest(ack_id="ack_id_string", byte_size=10, ordering_key="") + ] dispatcher_.lease(items) manager.leaser.add.assert_called_once_with(items) manager.maybe_pause_consumer.assert_called_once() -def test_drop(): +def test_drop_unordered_messages(): + manager = mock.create_autospec( + streaming_pull_manager.StreamingPullManager, instance=True + ) + dispatcher_ = dispatcher.Dispatcher(manager, mock.sentinel.queue) + + items = [ + requests.DropRequest(ack_id="ack_id_string", byte_size=10, ordering_key="") + ] + dispatcher_.drop(items) + + manager.leaser.remove.assert_called_once_with(items) + assert list(manager.activate_ordering_keys.call_args.args[0]) == [] + manager.maybe_resume_consumer.assert_called_once() + + +def test_drop_ordered_messages(): manager = mock.create_autospec( streaming_pull_manager.StreamingPullManager, instance=True ) dispatcher_ = dispatcher.Dispatcher(manager, mock.sentinel.queue) - items = [requests.DropRequest(ack_id="ack_id_string", byte_size=10)] + items = [ + requests.DropRequest(ack_id="ack_id_string", byte_size=10, ordering_key=""), + requests.DropRequest(ack_id="ack_id_string", byte_size=10, ordering_key="key1"), + requests.DropRequest(ack_id="ack_id_string", byte_size=10, ordering_key="key2"), + ] dispatcher_.drop(items) manager.leaser.remove.assert_called_once_with(items) + assert list(manager.activate_ordering_keys.call_args.args[0]) == ["key1", "key2"] manager.maybe_resume_consumer.assert_called_once() @@ -156,7 +189,9 @@ def test_nack(): ) dispatcher_ = dispatcher.Dispatcher(manager, mock.sentinel.queue) - items = [requests.NackRequest(ack_id="ack_id_string", byte_size=10)] + items = [ + requests.NackRequest(ack_id="ack_id_string", byte_size=10, ordering_key="") + ] dispatcher_.nack(items) manager.send.assert_called_once_with( diff --git a/tests/unit/pubsub_v1/subscriber/test_leaser.py b/tests/unit/pubsub_v1/subscriber/test_leaser.py index c8b217473..ec954b89d 100644 --- a/tests/unit/pubsub_v1/subscriber/test_leaser.py +++ b/tests/unit/pubsub_v1/subscriber/test_leaser.py @@ -29,14 +29,14 @@ def test_add_and_remove(): leaser_ = leaser.Leaser(mock.sentinel.manager) - leaser_.add([requests.LeaseRequest(ack_id="ack1", byte_size=50)]) - leaser_.add([requests.LeaseRequest(ack_id="ack2", byte_size=25)]) + leaser_.add([requests.LeaseRequest(ack_id="ack1", byte_size=50, ordering_key="")]) + leaser_.add([requests.LeaseRequest(ack_id="ack2", byte_size=25, ordering_key="")]) assert leaser_.message_count == 2 assert set(leaser_.ack_ids) == set(["ack1", "ack2"]) assert leaser_.bytes == 75 - leaser_.remove([requests.DropRequest(ack_id="ack1", byte_size=50)]) + leaser_.remove([requests.DropRequest(ack_id="ack1", byte_size=50, ordering_key="")]) assert leaser_.message_count == 1 assert set(leaser_.ack_ids) == set(["ack2"]) @@ -48,8 +48,8 @@ def test_add_already_managed(caplog): leaser_ = leaser.Leaser(mock.sentinel.manager) - leaser_.add([requests.LeaseRequest(ack_id="ack1", byte_size=50)]) - leaser_.add([requests.LeaseRequest(ack_id="ack1", byte_size=50)]) + leaser_.add([requests.LeaseRequest(ack_id="ack1", byte_size=50, ordering_key="")]) + leaser_.add([requests.LeaseRequest(ack_id="ack1", byte_size=50, ordering_key="")]) assert "already lease managed" in caplog.text @@ -59,7 +59,7 @@ def test_remove_not_managed(caplog): leaser_ = leaser.Leaser(mock.sentinel.manager) - leaser_.remove([requests.DropRequest(ack_id="ack1", byte_size=50)]) + leaser_.remove([requests.DropRequest(ack_id="ack1", byte_size=50, ordering_key="")]) assert "not managed" in caplog.text @@ -69,8 +69,8 @@ def test_remove_negative_bytes(caplog): leaser_ = leaser.Leaser(mock.sentinel.manager) - leaser_.add([requests.LeaseRequest(ack_id="ack1", byte_size=50)]) - leaser_.remove([requests.DropRequest(ack_id="ack1", byte_size=75)]) + leaser_.add([requests.LeaseRequest(ack_id="ack1", byte_size=50, ordering_key="")]) + leaser_.remove([requests.DropRequest(ack_id="ack1", byte_size=75, ordering_key="")]) assert leaser_.bytes == 0 assert "unexpectedly negative" in caplog.text @@ -125,7 +125,9 @@ def test_maintain_leases_ack_ids(): manager = create_manager() leaser_ = leaser.Leaser(manager) make_sleep_mark_manager_as_inactive(leaser_) - leaser_.add([requests.LeaseRequest(ack_id="my ack id", byte_size=50)]) + leaser_.add( + [requests.LeaseRequest(ack_id="my ack id", byte_size=50, ordering_key="")] + ) leaser_.maintain_leases() @@ -150,28 +152,52 @@ def test_maintain_leases_outdated_items(time): leaser_ = leaser.Leaser(manager) make_sleep_mark_manager_as_inactive(leaser_) - # Add these items at the beginning of the timeline + # Add and start expiry timer at the beginning of the timeline. time.return_value = 0 - leaser_.add([requests.LeaseRequest(ack_id="ack1", byte_size=50)]) + leaser_.add([requests.LeaseRequest(ack_id="ack1", byte_size=50, ordering_key="")]) + leaser_.start_lease_expiry_timer(["ack1"]) + + # Add a message but don't start the lease expiry timer. + leaser_.add([requests.LeaseRequest(ack_id="ack2", byte_size=50, ordering_key="")]) - # Add another item at towards end of the timeline + # Add a message and start expiry timer towards the end of the timeline. time.return_value = manager.flow_control.max_lease_duration - 1 - leaser_.add([requests.LeaseRequest(ack_id="ack2", byte_size=50)]) + leaser_.add([requests.LeaseRequest(ack_id="ack3", byte_size=50, ordering_key="")]) + leaser_.start_lease_expiry_timer(["ack3"]) + + # Add a message towards the end of the timeline, but DO NOT start expiry + # timer. + leaser_.add([requests.LeaseRequest(ack_id="ack4", byte_size=50, ordering_key="")]) - # Now make sure time reports that we are at the end of our timeline. + # Now make sure time reports that we are past the end of our timeline. time.return_value = manager.flow_control.max_lease_duration + 1 leaser_.maintain_leases() - # Only ack2 should be renewed. ack1 should've been dropped - manager.dispatcher.modify_ack_deadline.assert_called_once_with( - [requests.ModAckRequest(ack_id="ack2", seconds=10)] - ) + # ack2, ack3, and ack4 should be renewed. ack1 should've been dropped + modacks = manager.dispatcher.modify_ack_deadline.call_args.args[0] + expected = [ + requests.ModAckRequest(ack_id="ack2", seconds=10), + requests.ModAckRequest(ack_id="ack3", seconds=10), + requests.ModAckRequest(ack_id="ack4", seconds=10), + ] + # Use sorting to allow for ordering variance. + assert sorted(modacks) == sorted(expected) + manager.dispatcher.drop.assert_called_once_with( - [requests.DropRequest(ack_id="ack1", byte_size=50)] + [requests.DropRequest(ack_id="ack1", byte_size=50, ordering_key="")] ) +def test_start_lease_expiry_timer_unknown_ack_id(): + manager = create_manager() + leaser_ = leaser.Leaser(manager) + + # Nothing happens when this method is called with an ack-id that hasn't been + # added yet. + leaser_.start_lease_expiry_timer(["ack1"]) + + @mock.patch("threading.Thread", autospec=True) def test_start(thread): manager = mock.create_autospec( diff --git a/tests/unit/pubsub_v1/subscriber/test_message.py b/tests/unit/pubsub_v1/subscriber/test_message.py index fd23deef0..0c8a6d181 100644 --- a/tests/unit/pubsub_v1/subscriber/test_message.py +++ b/tests/unit/pubsub_v1/subscriber/test_message.py @@ -33,7 +33,7 @@ PUBLISHED_SECONDS = datetime_helpers.to_milliseconds(PUBLISHED) // 1000 -def create_message(data, ack_id="ACKID", delivery_attempt=0, **attrs): +def create_message(data, ack_id="ACKID", delivery_attempt=0, ordering_key="", **attrs): with mock.patch.object(time, "time") as time_: time_.return_value = RECEIVED_SECONDS msg = message.Message( @@ -44,6 +44,7 @@ def create_message(data, ack_id="ACKID", delivery_attempt=0, **attrs): publish_time=timestamp_pb2.Timestamp( seconds=PUBLISHED_SECONDS, nanos=PUBLISHED_MICROS * 1000 ), + ordering_key=ordering_key, ), ack_id=ack_id, delivery_attempt=delivery_attempt, @@ -89,6 +90,11 @@ def test_publish_time(): assert msg.publish_time == PUBLISHED +def test_ordering_key(): + msg = create_message(b"foo", ordering_key="key1") + assert msg.ordering_key == "key1" + + def check_call_types(mock, *args, **kwargs): """Checks a mock's call types. @@ -118,7 +124,10 @@ def test_ack(): msg.ack() put.assert_called_once_with( requests.AckRequest( - ack_id="bogus_ack_id", byte_size=30, time_to_ack=mock.ANY + ack_id="bogus_ack_id", + byte_size=30, + time_to_ack=mock.ANY, + ordering_key="", ) ) check_call_types(put, requests.AckRequest) @@ -129,7 +138,7 @@ def test_drop(): with mock.patch.object(msg._request_queue, "put") as put: msg.drop() put.assert_called_once_with( - requests.DropRequest(ack_id="bogus_ack_id", byte_size=30) + requests.DropRequest(ack_id="bogus_ack_id", byte_size=30, ordering_key="") ) check_call_types(put, requests.DropRequest) @@ -149,19 +158,22 @@ def test_nack(): with mock.patch.object(msg._request_queue, "put") as put: msg.nack() put.assert_called_once_with( - requests.NackRequest(ack_id="bogus_ack_id", byte_size=30) + requests.NackRequest(ack_id="bogus_ack_id", byte_size=30, ordering_key="") ) check_call_types(put, requests.NackRequest) def test_repr(): data = b"foo" - msg = create_message(data, snow="cones", orange="juice") + ordering_key = "ord_key" + msg = create_message(data, ordering_key=ordering_key, snow="cones", orange="juice") data_line = " data: {!r}".format(data) + ordering_key_line = " ordering_key: {!r}".format(ordering_key) expected_repr = "\n".join( ( "Message {", data_line, + ordering_key_line, " attributes: {", ' "orange": "juice",', ' "snow": "cones"', diff --git a/tests/unit/pubsub_v1/subscriber/test_streaming_pull_manager.py b/tests/unit/pubsub_v1/subscriber/test_streaming_pull_manager.py index 6f8a04ac9..0886a4508 100644 --- a/tests/unit/pubsub_v1/subscriber/test_streaming_pull_manager.py +++ b/tests/unit/pubsub_v1/subscriber/test_streaming_pull_manager.py @@ -19,7 +19,6 @@ import mock import pytest -from six.moves import queue from google.api_core import bidi from google.api_core import exceptions @@ -31,6 +30,7 @@ from google.cloud.pubsub_v1.subscriber._protocol import dispatcher from google.cloud.pubsub_v1.subscriber._protocol import heartbeater from google.cloud.pubsub_v1.subscriber._protocol import leaser +from google.cloud.pubsub_v1.subscriber._protocol import messages_on_hold from google.cloud.pubsub_v1.subscriber._protocol import requests from google.cloud.pubsub_v1.subscriber._protocol import streaming_pull_manager import grpc @@ -95,6 +95,7 @@ def test_constructor_and_default_state(): assert manager._client == mock.sentinel.client assert manager._subscription == mock.sentinel.subscription assert manager._scheduler is not None + assert manager._messages_on_hold is not None def test_constructor_with_options(): @@ -166,18 +167,24 @@ def test_lease_load_and_pause(): # This should mean that our messages count is at 10%, and our bytes # are at 15%; load should return the higher (0.15), and shouldn't cause # the consumer to pause. - manager.leaser.add([requests.LeaseRequest(ack_id="one", byte_size=150)]) + manager.leaser.add( + [requests.LeaseRequest(ack_id="one", byte_size=150, ordering_key="")] + ) assert manager.load == 0.15 manager.maybe_pause_consumer() manager._consumer.pause.assert_not_called() # After this message is added, the messages should be higher at 20% # (versus 16% for bytes). - manager.leaser.add([requests.LeaseRequest(ack_id="two", byte_size=10)]) + manager.leaser.add( + [requests.LeaseRequest(ack_id="two", byte_size=10, ordering_key="")] + ) assert manager.load == 0.2 # Returning a number above 100% is fine, and it should cause this to pause. - manager.leaser.add([requests.LeaseRequest(ack_id="three", byte_size=1000)]) + manager.leaser.add( + [requests.LeaseRequest(ack_id="three", byte_size=1000, ordering_key="")] + ) assert manager.load == 1.16 manager.maybe_pause_consumer() manager._consumer.pause.assert_called_once() @@ -194,8 +201,8 @@ def test_drop_and_resume(): # Add several messages until we're over the load threshold. manager.leaser.add( [ - requests.LeaseRequest(ack_id="one", byte_size=750), - requests.LeaseRequest(ack_id="two", byte_size=250), + requests.LeaseRequest(ack_id="one", byte_size=750, ordering_key=""), + requests.LeaseRequest(ack_id="two", byte_size=250, ordering_key=""), ] ) @@ -207,7 +214,9 @@ def test_drop_and_resume(): # Drop the 200 byte message, which should put us under the resume # threshold. - manager.leaser.remove([requests.DropRequest(ack_id="two", byte_size=250)]) + manager.leaser.remove( + [requests.DropRequest(ack_id="two", byte_size=250, ordering_key="")] + ) manager.maybe_resume_consumer() manager._consumer.resume.assert_called_once() @@ -245,7 +254,7 @@ def test__maybe_release_messages_on_overload(): manager._maybe_release_messages() - assert manager._messages_on_hold.qsize() == 1 + assert manager._messages_on_hold.size == 1 manager._leaser.add.assert_not_called() manager._scheduler.schedule.assert_not_called() @@ -274,8 +283,8 @@ def test__maybe_release_messages_below_overload(): # the actual call of MUT manager._maybe_release_messages() - assert manager._messages_on_hold.qsize() == 1 - msg = manager._messages_on_hold.get_nowait() + assert manager._messages_on_hold.size == 1 + msg = manager._messages_on_hold.get() assert msg.ack_id == "ack_baz" schedule_calls = manager._scheduler.schedule.mock_calls @@ -692,7 +701,7 @@ def test__on_response_no_leaser_overload(): assert isinstance(call[1][1], message.Message) # the leaser load limit not hit, no messages had to be put on hold - assert manager._messages_on_hold.qsize() == 0 + assert manager._messages_on_hold.size == 0 def test__on_response_with_leaser_overload(): @@ -741,11 +750,10 @@ def test__on_response_with_leaser_overload(): assert call_args[1].message_id == "1" # the rest of the messages should have been put on hold - assert manager._messages_on_hold.qsize() == 2 + assert manager._messages_on_hold.size == 2 while True: - try: - msg = manager._messages_on_hold.get_nowait() - except queue.Empty: + msg = manager._messages_on_hold.get() + if msg is None: break else: assert isinstance(msg, message.Message) @@ -767,6 +775,87 @@ def test__on_response_none_data(caplog): assert "callback invoked with None" in caplog.text +def test__on_response_with_ordering_keys(): + manager, _, dispatcher, leaser, _, scheduler = make_running_manager() + manager._callback = mock.sentinel.callback + + # Set up the messages. + response = types.StreamingPullResponse( + received_messages=[ + types.ReceivedMessage( + ack_id="fack", + message=types.PubsubMessage( + data=b"foo", message_id="1", ordering_key="" + ), + ), + types.ReceivedMessage( + ack_id="back", + message=types.PubsubMessage( + data=b"bar", message_id="2", ordering_key="key1" + ), + ), + types.ReceivedMessage( + ack_id="zack", + message=types.PubsubMessage( + data=b"baz", message_id="3", ordering_key="key1" + ), + ), + ] + ) + + # Make leaser with zero initial messages, so we don't test lease management + # behavior. + fake_leaser_add(leaser, init_msg_count=0, assumed_msg_size=10) + + # Actually run the method and prove that modack and schedule are called in + # the expected way. + manager._on_response(response) + + # All messages should be added to the lease management and have their ACK + # deadline extended, even those not dispatched to callbacks. + dispatcher.modify_ack_deadline.assert_called_once_with( + [ + requests.ModAckRequest("fack", 10), + requests.ModAckRequest("back", 10), + requests.ModAckRequest("zack", 10), + ] + ) + + # The first two messages should be scheduled, The third should be put on + # hold because it's blocked by the completion of the second, which has the + # same ordering key. + schedule_calls = scheduler.schedule.mock_calls + assert len(schedule_calls) == 2 + call_args = schedule_calls[0][1] + assert call_args[0] == mock.sentinel.callback + assert isinstance(call_args[1], message.Message) + assert call_args[1].message_id == "1" + + call_args = schedule_calls[1][1] + assert call_args[0] == mock.sentinel.callback + assert isinstance(call_args[1], message.Message) + assert call_args[1].message_id == "2" + + # Message 3 should have been put on hold. + assert manager._messages_on_hold.size == 1 + # No messages available because message 2 (with "key1") has not completed yet. + assert manager._messages_on_hold.get() is None + + # Complete message 2 (with "key1"). + manager.activate_ordering_keys(["key1"]) + + # Completing message 2 should release message 3. + schedule_calls = scheduler.schedule.mock_calls + assert len(schedule_calls) == 3 + call_args = schedule_calls[2][1] + assert call_args[0] == mock.sentinel.callback + assert isinstance(call_args[1], message.Message) + assert call_args[1].message_id == "3" + + # No messages available in the queue. + assert manager._messages_on_hold.get() is None + + def test_retryable_stream_errors(): # Make sure the config matches our hard-coded tuple of exceptions. interfaces = subscriber_client_config.config["interfaces"] @@ -824,3 +913,16 @@ def test__on_rpc_done(thread): thread.assert_called_once_with( name=mock.ANY, target=manager.close, kwargs={"reason": mock.sentinel.error} ) + + +def test_activate_ordering_keys(): + manager = make_manager() + manager._messages_on_hold = mock.create_autospec( + messages_on_hold.MessagesOnHold, instance=True + ) + + manager.activate_ordering_keys(["key1", "key2"]) + + manager._messages_on_hold.activate_ordering_keys.assert_called_once_with( + ["key1", "key2"], mock.ANY + )