From bad077e4592866b643cc90a897b3aeebaae3097b Mon Sep 17 00:00:00 2001
From: fjetter
Date: Fri, 22 Oct 2021 14:47:45 +0200
Subject: [PATCH 1/3] Do not drop BatchedSend payload if worker reconnects

---
 distributed/batched.py            | 13 +++++++---
 distributed/tests/test_batched.py | 22 ++++++++++++++++
 distributed/tests/test_worker.py  | 43 +++++++++++++++++++++++++++----
 distributed/worker.py             |  1 +
 4 files changed, 70 insertions(+), 9 deletions(-)

diff --git a/distributed/batched.py b/distributed/batched.py
index 960f4fa828e..ea98c8a48ff 100644
--- a/distributed/batched.py
+++ b/distributed/batched.py
@@ -58,6 +58,7 @@ def start(self, comm):
         self.comm = comm
+        self.please_stop = False
         self.loop.add_callback(self._background_send)
 
     def closed(self):
@@ -98,6 +99,7 @@ def _background_send(self):
                 else:
                     self.recent_message_log.append("large-message")
                 self.byte_count += nbytes
+                payload = []  # lose ref
             except CommClosedError:
                 logger.info("Batched Comm Closed %r", self.comm, exc_info=True)
                 break
@@ -111,7 +113,9 @@ def _background_send(self):
                 logger.exception("Error in batched write")
                 break
             finally:
-                payload = None  # lose ref
+                # If anything fails we should not lose the payload. If a new
+                # comm is provided we can still resubmit the messages.
+                self.buffer = payload + self.buffer
         else:  # nobreak. We've been gracefully closed.
             self.stopped.set()
@@ -121,7 +125,6 @@ def _background_send(self):
         # there was an exception when using `comm`.
         # We can't close gracefully via `.close()` since we can't send messages.
         # So we just abort.
-        # This means that any messages in our buffer our lost.
         # To propagate exceptions, we rely on subsequent `BatchedSend.send`
         # calls to raise CommClosedErrors.
         self.stopped.set()
@@ -152,6 +155,7 @@ def close(self, timeout=None):
         self.please_stop = True
         self.waker.set()
         yield self.stopped.wait(timeout=timeout)
+        payload = []
         if not self.comm.closed():
             try:
                 if self.buffer:
@@ -160,14 +164,15 @@ def close(self, timeout=None):
                         payload, serializers=self.serializers, on_error="raise"
                     )
             except CommClosedError:
-                pass
+                # If we're closing and there is an error, there is little
+                # we can do to recover.
+ logger.error("Lost %i payload messages.", len(payload)) yield self.comm.close() def abort(self): if self.comm is None: return self.please_stop = True - self.buffer = [] self.waker.set() if not self.comm.closed(): self.comm.abort() diff --git a/distributed/tests/test_batched.py b/distributed/tests/test_batched.py index ee84ec3224e..296be3c60fb 100644 --- a/distributed/tests/test_batched.py +++ b/distributed/tests/test_batched.py @@ -249,3 +249,25 @@ async def test_serializers(): assert "function" in value assert comm.closed() + + +@pytest.mark.asyncio +async def test_retain_buffer_commclosed(): + async with EchoServer() as e: + with captured_logger("distributed.batched") as caplog: + comm = await connect(e.address) + + b = BatchedSend(interval="1s", serializers=["msgpack"]) + b.start(comm) + b.send("foo") + assert b.buffer + await comm.close() + await asyncio.sleep(1) + + assert "Batched Comm Closed" in caplog.getvalue() + assert b.buffer + + new_comm = await connect(e.address) + b.start(new_comm) + assert await new_comm.read() == ("foo",) + assert not b.buffer diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 63d34015914..84d8fd7a20f 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -2568,11 +2568,9 @@ def fast_on_a(lock): assert "Unexpected worker completed task" in s_logs.getvalue() - # Ensure that all in-memory tasks on A have been restored on the - # scheduler after reconnect - for ts in a.tasks.values(): - if ts.state == "memory": - assert a.address in {ws.address for ws in s.tasks[ts.key].who_has} + sts = s.tasks[f3.key] + assert sts.state == "memory" + assert s.workers[a.address] in sts.who_has del f1, f2, f3 while any(w.tasks for w in [a, b]): @@ -3196,3 +3194,38 @@ async def test_deadlock_cancelled_after_inflight_before_gather_from_worker( args, kwargs = mocked_gather.call_args await Worker.gather_dep(b, *args, **kwargs) await fut3 + + +@gen_cluster(nthreads=[("", 1)]) +async def test_dont_loose_payload_reconnect(s, w): + """Ensure that payload of a BatchedSend is not lost if a worker reconnects""" + s.count = 0 + + def receive(worker, msg): + s.count += 1 + + s.stream_handlers["receive-msg"] = receive + w.batched_stream.next_deadline = w.loop.time() + 10_000 + + for x in range(100): + w.batched_stream.send({"op": "receive-msg", "msg": x}) + + await s.stream_comms[w.address].comm.close() + while not w.batched_stream.comm.closed(): + await asyncio.sleep(0.1) + before = w.batched_stream.buffer.copy() + w.batched_stream.next_deadline = w.loop.time() + assert len(w.batched_stream.buffer) == 100 + with captured_logger("distributed.batched") as caplog: + await w.batched_stream._background_send() + + assert "Batched Comm Closed" in caplog.getvalue() + after = w.batched_stream.buffer.copy() + + # Payload that couldn't be submitted is prepended + assert len(after) >= len(before) + assert after[: len(before)] == before + + await w.heartbeat() + while not s.count == 100: + await asyncio.sleep(0.1) diff --git a/distributed/worker.py b/distributed/worker.py index af44f1f6609..833d23da0d8 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1238,6 +1238,7 @@ async def handle_scheduler(self, comm): comm, every_cycle=[self.ensure_communicating, self.ensure_computing] ) except Exception as e: + self.batched_stream.please_stop = True logger.exception(e) raise finally: From 7adc0c0179ed7b1c4f15fcf3914c4b9aa2a61855 Mon Sep 17 00:00:00 2001 From: fjetter Date: Thu, 4 Nov 2021 11:31:38 +0100 Subject: [PATCH 2/3] 
---
 distributed/tests/test_batched.py | 128 +++++++++++++++++++++++++++++-
 distributed/worker.py             |   1 -
 2 files changed, 126 insertions(+), 3 deletions(-)

diff --git a/distributed/tests/test_batched.py b/distributed/tests/test_batched.py
index 296be3c60fb..1f023c1515e 100644
--- a/distributed/tests/test_batched.py
+++ b/distributed/tests/test_batched.py
@@ -43,9 +43,11 @@ async def test_BatchedSend():
         comm = await connect(e.address)
 
         b = BatchedSend(interval=10)
+        assert "<BatchedSend: closed>" == str(b)
+        assert "<BatchedSend: closed>" == repr(b)
+        b.start(comm)
         assert str(len(b.buffer)) in str(b)
         assert str(len(b.buffer)) in repr(b)
-        b.start(comm)
 
         await asyncio.sleep(0.020)
 
@@ -79,6 +81,124 @@ async def test_send_before_start():
 
     assert result == ("hello", "world")
 
 
+@pytest.mark.asyncio
+async def test_closed_if_not_started():
+    async with EchoServer() as e:
+        comm = await connect(e.address)
+        b = BatchedSend(interval=10)
+        assert b.closed()
+        b.start(comm)
+        assert not b.closed()
+        await b.close()
+        assert b.closed()
+
+
+@pytest.mark.asyncio
+async def test_start_twice_with_closing():
+    async with EchoServer() as e:
+        comm = await connect(e.address)
+        comm2 = await connect(e.address)
+
+        b = BatchedSend(interval=10)
+        b.start(comm)
+
+        # Same comm is fine
+        b.start(comm)
+
+        await b.close()
+
+        b.start(comm2)
+
+        b.send("hello")
+        b.send("world")
+
+        result = await comm2.read()
+        assert result == ("hello", "world")
+
+
+@pytest.mark.asyncio
+async def test_start_twice_with_abort():
+    async with EchoServer() as e:
+        comm = await connect(e.address)
+        comm2 = await connect(e.address)
+
+        b = BatchedSend(interval=10)
+        b.start(comm)
+
+        # Same comm is fine
+        b.start(comm)
+
+        b.abort()
+
+        b.start(comm2)
+
+        b.send("hello")
+        b.send("world")
+
+        result = await comm2.read()
+        assert result == ("hello", "world")
+
+
+@pytest.mark.asyncio
+async def test_start_twice_with_abort_drops_payload():
+    async with EchoServer() as e:
+        comm = await connect(e.address)
+        comm2 = await connect(e.address)
+
+        b = BatchedSend(interval=10)
+        b.start(comm)
+        b.send("hello")
+        b.send("world")
+
+        # Same comm is fine
+        b.start(comm)
+
+        b.abort()
+
+        b.start(comm2)
+
+        # The aborted buffer was dropped, so nothing ever arrives
+        with pytest.raises(asyncio.TimeoutError):
+            await asyncio.wait_for(comm2.read(), 0.01)
+
+
+@pytest.mark.asyncio
+async def test_start_closed_comm():
+    async with EchoServer() as e:
+        comm = await connect(e.address)
+        await comm.close()
+
+        b = BatchedSend(interval="10ms")
+        with pytest.raises(RuntimeError, match="Comm already closed."):
+            b.start(comm)
+
+
+@pytest.mark.asyncio
+async def test_start_twice_without_closing():
+    async with EchoServer() as e:
+        comm = await connect(e.address)
+        comm2 = await connect(e.address)
+
+        b = BatchedSend(interval=10)
+        b.start(comm)
+
+        # Same comm is fine
+        b.start(comm)
+
+        # A different comm is only allowed if the previous one was closed
+        with pytest.raises(RuntimeError, match="BatchedSend already started"):
+            b.start(comm2)
+
+        b.send("hello")
+        b.send("world")
+
+        result = await comm.read()
+        assert result == ("hello", "world")
+
+        # This comm hasn't been used, so no message should have arrived
+        with pytest.raises(asyncio.TimeoutError):
+            await asyncio.wait_for(comm2.read(), 0.01)
+
+
 @pytest.mark.asyncio
 async def test_send_after_stream_start():
     async with EchoServer() as e:
@@ -113,8 +234,11 @@ async def test_send_before_close():
             await asyncio.sleep(0.01)
             assert time() < start + 5
 
+    msg = "123"
     with pytest.raises(CommClosedError):
-        b.send("123")
+        b.send(msg)
+
+    assert msg not in b.buffer
 
 
 @pytest.mark.asyncio
diff --git a/distributed/worker.py b/distributed/worker.py
index 833d23da0d8..af44f1f6609 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -1238,7 +1238,6 @@ async def handle_scheduler(self, comm):
                 comm, every_cycle=[self.ensure_communicating, self.ensure_computing]
             )
         except Exception as e:
-            self.batched_stream.please_stop = True
             logger.exception(e)
             raise
         finally:

From 5795c5bb8b2a9b162cbc575ac04037510c22fb33 Mon Sep 17 00:00:00 2001
From: fjetter
Date: Thu, 4 Nov 2021 18:54:11 +0100
Subject: [PATCH 3/3] Some sleep in TCP write

---
 distributed/batched.py           | 58 ++++++++++++++++++++------
 distributed/comm/tcp.py          | 14 ++++++-
 distributed/scheduler.py         |  2 +-
 distributed/tests/test_worker.py | 71 +++++++++++++++++++++++++-------
 4 files changed, 116 insertions(+), 29 deletions(-)

diff --git a/distributed/batched.py b/distributed/batched.py
index ea98c8a48ff..382bdfaeeea 100644
--- a/distributed/batched.py
+++ b/distributed/batched.py
@@ -1,5 +1,6 @@
 import logging
 from collections import deque
+from uuid import uuid4
 
 from tornado import gen, locks
 from tornado.ioloop import IOLoop
@@ -7,6 +8,7 @@
 import dask
 from dask.utils import parse_timedelta
 
+from .comm import Comm
 from .core import CommClosedError
 
 logger = logging.getLogger(__name__)
@@ -43,6 +45,7 @@ def __init__(self, interval, loop=None, serializers=None):
         self.interval = parse_timedelta(interval, default="ms")
         self.waker = locks.Event()
         self.stopped = locks.Event()
+        self.stopped.set()
         self.please_stop = False
         self.buffer = []
         self.comm = None
@@ -56,13 +59,30 @@ def __init__(self, interval, loop=None, serializers=None):
         self.serializers = serializers
         self._consecutive_failures = 0
 
-    def start(self, comm):
-        self.comm = comm
-        self.please_stop = False
-        self.loop.add_callback(self._background_send)
+    def start(self, comm: Comm):
+        """
+        Start the BatchedSend by providing an open Comm object.
+
+        Calling this again on an already started BatchedSend will raise a
+        `RuntimeError` if the provided Comm is different from the current
+        one. If the provided Comm is identical, this is a no-op.
+
+        In case the BatchedSend was already closed, the newly provided Comm
+        is used to submit any messages still accumulated in the buffer.
+        """
+        if self.closed():
+            if comm.closed():
+                raise RuntimeError("Comm already closed.")
+            self.comm = comm
+            self.please_stop = False
+            self.loop.add_callback(self._background_send)
+        elif self.comm is not comm:
+            raise RuntimeError("BatchedSend already started.")
 
     def closed(self):
-        return self.comm and self.comm.closed()
+        """True if the BatchedSend hasn't been started or has already been
+        closed."""
+        return self.comm is None or self.comm.closed()
 
     def __repr__(self):
         if self.closed():
@@ -99,7 +122,8 @@ def _background_send(self):
                 else:
                     self.recent_message_log.append("large-message")
                 self.byte_count += nbytes
-                payload = []  # lose ref
+
+                payload.clear()  # lose ref
             except CommClosedError:
                 logger.info("Batched Comm Closed %r", self.comm, exc_info=True)
                 break
@@ -121,21 +145,27 @@ def _background_send(self):
             self.stopped.set()
             return
 
+        self.stopped.set()
         # If we've reached here, it means `break` was hit above and
         # there was an exception when using `comm`.
         # We can't close gracefully via `.close()` since we can't send messages.
         # So we just abort.
         # To propagate exceptions, we rely on subsequent `BatchedSend.send`
         # calls to raise CommClosedErrors.
-        self.stopped.set()
-        self.abort()
+
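+        # We only abort the underlying comm here and leave ``self.buffer``
+        # intact, so that a subsequent ``start()`` with a fresh comm can
+        # resubmit the pending messages.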
+        if self.comm:
+            self.comm.abort()
+        yield self.close()
 
     def send(self, *msgs):
         """Schedule a message for sending to the other side
 
-        This completes quickly and synchronously
+        This completes quickly and synchronously.
+
+        If the BatchedSend or Comm is already closed, this raises a
+        CommClosedError and no further messages are accepted into the buffer.
         """
-        if self.comm is not None and self.comm.closed():
+        if self.closed():
             raise CommClosedError(f"Comm {self.comm!r} already closed.")
 
         self.message_count += len(msgs)
@@ -146,7 +176,7 @@ def send(self, *msgs):
 
     @gen.coroutine
     def close(self, timeout=None):
-        """Flush existing messages and then close comm
+        """Flush existing messages and then close the Comm
 
         If set, raises `tornado.util.TimeoutError` after a timeout.
         """
@@ -155,8 +185,8 @@ def close(self, timeout=None):
         self.please_stop = True
         self.waker.set()
         yield self.stopped.wait(timeout=timeout)
-        payload = []
         if not self.comm.closed():
+            payload = []
             try:
                 if self.buffer:
                     self.buffer, payload = [], self.buffer
@@ -170,9 +200,13 @@ def close(self, timeout=None):
         yield self.comm.close()
 
     def abort(self):
+        """Close the BatchedSend immediately, without waiting for any pending
+        operations to complete. Buffered data will be lost."""
         if self.comm is None:
             return
+        self.buffer = []
         self.please_stop = True
         self.waker.set()
         if not self.comm.closed():
             self.comm.abort()
+        self.comm = None
diff --git a/distributed/comm/tcp.py b/distributed/comm/tcp.py
index cad01427ceb..335f2cd6af9 100644
--- a/distributed/comm/tcp.py
+++ b/distributed/comm/tcp.py
@@ -1,3 +1,5 @@
+import asyncio
+
 import ctypes
 import errno
 import functools
@@ -290,7 +292,15 @@ async def write(self, msg, serializers=None, on_error="message"):
                     stream._total_write_index += each_frame_nbytes
 
             # start writing frames
-            stream.write(b"")
+            await stream.write(b"")
+            # FIXME: How do I test this? Why is the stream closed _sometimes_?
+            # Diving into tornado, so far I can only confirm that once the
+            # write future has been awaited, the entire buffer has been
+            # written to the socket. It is unclear whether one loop iteration
+            # is sufficient in general or merely for the local tests I've run.
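+            # A zero-delay sleep yields control to the event loop for one
+            # iteration, which gives tornado a chance to notice a peer close
+            # before we report the write as successful.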
+            await asyncio.sleep(0)
+            if stream.closed():
+                raise StreamClosedError()
         except StreamClosedError as e:
             self.stream = None
             self._closed = True
@@ -333,6 +343,8 @@ def abort(self):
             stream.close()
 
     def closed(self):
+        if self.stream and self.stream.closed():
+            self.abort()
         return self._closed
 
     @property
diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index a44ba452769..4c4408ed9d5 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -5510,7 +5510,7 @@ async def handle_worker(self, comm=None, worker=None):
             await self.handle_stream(comm=comm, extra={"worker": worker})
         finally:
             if worker in self.stream_comms:
-                worker_comm.abort()
+                await worker_comm.close()
                 await self.remove_worker(address=worker)
 
     def add_plugin(
diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py
index 84d8fd7a20f..a3eef130c7a 100644
--- a/distributed/tests/test_worker.py
+++ b/distributed/tests/test_worker.py
@@ -3197,35 +3197,76 @@ async def test_deadlock_cancelled_after_inflight_before_gather_from_worker(
 
 
 @gen_cluster(nthreads=[("", 1)])
-async def test_dont_loose_payload_reconnect(s, w):
+@pytest.mark.parametrize(
+    "sender",
+    [
+        "worker",
+        "scheduler",
+    ],
+)
+async def test_dont_loose_payload_reconnect_worker_sends(s, w, sender):
     """Ensure that the payload of a BatchedSend is not lost if a worker reconnects"""
-    s.count = 0
+    while w.heartbeat_active:
+        await asyncio.sleep(0.1)
+    w.heartbeat_active = True
+
+    if sender == "worker":
+        sender = w
+        sender_stream = w.batched_stream
+        receiver = s
+        receiver_stream = s.stream_comms[w.address]
+    else:
+        sender = s
+        sender_stream = s.stream_comms[w.address]
+        receiver = w
+        receiver_stream = w.batched_stream
+
+    receiver.count = 0
 
     def receive(worker, msg):
-        s.count += 1
+        receiver.count += 1
+
+    receiver.stream_handlers["receive-msg"] = receive
 
-    s.stream_handlers["receive-msg"] = receive
-    w.batched_stream.next_deadline = w.loop.time() + 10_000
+    # Wait until the buffer is empty so that we start cleanly (e.g.
+    # heartbeats, status updates, etc. have been flushed)
+    while sender_stream.buffer:
+        await asyncio.sleep(0.01)
 
     for x in range(100):
-        w.batched_stream.send({"op": "receive-msg", "msg": x})
+        sender_stream.send({"op": "receive-msg", "msg": x})
+
+    receiver_stream.comm.abort()
+
+    # batch_count increases with every send attempt, so once it has grown we
+    # know that the background send ran at least once
+    before_batch_count = sender_stream.batch_count
 
-    await s.stream_comms[w.address].comm.close()
-    while not w.batched_stream.comm.closed():
+    before = sender_stream.buffer.copy()
+    assert len(sender_stream.buffer) == 100
+
+    while sender_stream.batch_count == before_batch_count:
         await asyncio.sleep(0.1)
-    before = w.batched_stream.buffer.copy()
-    w.batched_stream.next_deadline = w.loop.time()
-    assert len(w.batched_stream.buffer) == 100
-    with captured_logger("distributed.batched") as caplog:
-        await w.batched_stream._background_send()
-
-    assert "Batched Comm Closed" in caplog.getvalue()
-    after = w.batched_stream.buffer.copy()
 
+    # At the time of the send we already know the comm has failed; the caller
+    # should handle this exception and trigger a reconnect
+    # TODO: Is the transition engine robust to this??
+    new_message = {"op": "receive-msg", "msg": 100}
+    with pytest.raises(CommClosedError):
+        sender_stream.send(new_message)
+    assert new_message not in sender_stream.buffer
+
+    after = sender_stream.buffer.copy()
 
     # Payload that couldn't be submitted is prepended
     assert len(after) >= len(before)
     assert after[: len(before)] == before
 
+    # No message has been received yet
+    assert receiver.count == 0
+
+    # Now reconnect and everything should stabilize again
+    w.heartbeat_active = False
     await w.heartbeat()
-    while not s.count == 100:
+    while receiver.count != 100:
         await asyncio.sleep(0.1)
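
Illustration (not part of the patches): the buffer-retention behaviour the series implements can be exercised in a few lines. This sketch mirrors test_retain_buffer_commclosed above; it assumes the patched BatchedSend from these diffs and the EchoServer helper defined in distributed/tests/test_batched.py.

    import asyncio

    from distributed.batched import BatchedSend
    from distributed.comm import connect
    from distributed.tests.test_batched import EchoServer  # test helper


    async def main():
        async with EchoServer() as server:
            comm = await connect(server.address)
            b = BatchedSend(interval="1s", serializers=["msgpack"])
            b.start(comm)
            b.send("foo")

            # The comm dies before the first flush; with this series the
            # message stays in the buffer instead of being dropped.
            await comm.close()
            await asyncio.sleep(1)
            assert b.buffer

            # Handing the BatchedSend a fresh comm resubmits the buffer.
            new_comm = await connect(server.address)
            b.start(new_comm)
            assert await new_comm.read() == ("foo",)
            assert not b.buffer
            await b.close()


    asyncio.run(main())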