Do not drop BatchedSend payload if worker reconnects #5457

Closed
65 changes: 52 additions & 13 deletions distributed/batched.py
@@ -1,12 +1,14 @@
import logging
from collections import deque
from uuid import uuid4

from tornado import gen, locks
from tornado.ioloop import IOLoop

import dask
from dask.utils import parse_timedelta

from .comm import Comm
from .core import CommClosedError

logger = logging.getLogger(__name__)
@@ -43,6 +45,7 @@ def __init__(self, interval, loop=None, serializers=None):
self.interval = parse_timedelta(interval, default="ms")
self.waker = locks.Event()
self.stopped = locks.Event()
self.stopped.set()
self.please_stop = False
self.buffer = []
self.comm = None
@@ -56,12 +59,33 @@ def __init__(self, interval, loop=None, serializers=None):
self.serializers = serializers
self._consecutive_failures = 0

def start(self, comm):
self.comm = comm
self.loop.add_callback(self._background_send)
def start(self, comm: Comm):
"""
Start the BatchedSend by providing an open Comm object.

Calling this again on an already started BatchedSend will raise a
`RuntimeError` if the provided Comm is different from the current one. If
the provided Comm is identical, this is a no-op.

In case the BatchedSend was already closed, this will use the newly
provided Comm to submit any accumulated messages in the buffer.
"""
if self.closed():
if comm.closed():
raise RuntimeError("Comm already closed.")
Collaborator

Suggested change:
- raise RuntimeError("Comm already closed.")
+ raise RuntimeError(f"Tried to start BatchedSend with an already-closed comm: {comm!r}.")

self.comm = comm
self.please_stop = False
Collaborator

Suggested change:
- self.please_stop = False
+ self.please_stop = False
+ self.stopped.clear()

Although I also find the use of stopped a little odd; see the comment on cancel.

self.loop.add_callback(self._background_send)
Collaborator
If self.closed is True, can we be certain this callback is not already running? Two _background_sends running at the same time might make a mess. I think there's currently a race condition:

  • _background_send is awaiting self.waker.wait
  • comm gets closed (nothing happens to BatchedSend state at this point; a closed comm is only detected once a write to it fails)
  • start is called with a new comm, launching a new _background_send coroutine
  • both coroutines are now blocking on self.waker.wait; it's a race as to which gets awakened first
  • both can run validly; from their perspective, nothing even happened (the comm was switched out under their noses while they were sleeping)

I'm not actually sure if, with the exact way the code is written now, two coroutines running at once can actually do something bad, but it still seems like a bad and brittle situation. "only one _background_send is running at once" feels like a key invariant of this class to me.

This is another way in which having an asyncio Task handle on the _background_send might make things easier to reason about #5457 (comment)
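A rough sketch of that Task-handle idea (hypothetical; the class and attribute names here are illustrative and not part of this PR):

import asyncio


class BatchedSendSketch:
    """Illustrative only: hold the background coroutine as a Task so
    start() can tell whether one is already running."""

    def __init__(self):
        self.comm = None
        self._background_task = None  # asyncio.Task or None

    def start(self, comm):
        if self._background_task is not None and not self._background_task.done():
            # A sender coroutine is still running; only a no-op restart
            # with the same comm is allowed.
            if comm is not self.comm:
                raise RuntimeError("BatchedSend already started.")
            return
        self.comm = comm
        # Invariant: at most one _background_send task exists at a time.
        self._background_task = asyncio.ensure_future(self._background_send())

    async def _background_send(self):
        ...  # send loop elided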

elif self.comm is not comm:
raise RuntimeError("BatchedSend already started.")

def closed(self):
Collaborator

please_stop and stopped should probably factor into this as well.

Collaborator

I still think this check is insufficient for telling whether we're closed or not. While we're in the process of closing (please_stop is True, self.stopped is set, _background_send is no longer running, etc.), it may still return False.

From the way this is used, though, in both start and write we probably want to treat "in the process of closing" as closed, not as running. Restarting a BatchedSend that's closing should be an error. If writing to a closed BatchedSend is an error, then so should be writing to one that's in the process of closing.
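A minimal sketch of that stricter check (hypothetical; it reuses the attributes the class already has):

    def closed(self):
        """True if the BatchedSend hasn't been started, is closing, or is closed."""
        return (
            self.comm is None
            or self.comm.closed()
            or self.please_stop       # close() has been requested
            or self.stopped.is_set()  # the background coroutine has exited
        )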

return self.comm and self.comm.closed()
"""True if the BatchedSend hasn't been started or has been closed
already."""
if self.comm is None or self.comm.closed():
return True
else:
return False

def __repr__(self):
if self.closed():
@@ -98,6 +122,8 @@ def _background_send(self):
else:
self.recent_message_log.append("large-message")
self.byte_count += nbytes

payload.clear() # lose ref
except CommClosedError:
logger.info("Batched Comm Closed %r", self.comm, exc_info=True)
break
@@ -111,28 +137,35 @@ def _background_send(self):
logger.exception("Error in batched write")
break
finally:
payload = None # lose ref
# If anything failed we should not lose the payload. If a new comm
# is provided we can still resubmit the messages.
self.buffer = payload + self.buffer
Collaborator

Are you sure we should retain payload? Couldn't those messages have been successfully sent, and the error occurred after? Then we might be duplicating them when we reconnect.

Though maybe we'd rather duplicate messages than drop them. In that case, let's add a note saying so.

Member Author

At this point I am assuming a certain behaviour of the comm: either the comm writes all or nothing. That's likely not always true, but I believe we cannot do much about it at this level of abstraction. IMHO, that guarantee should be implemented by our protocol and/or the Comm interface.

Either way, I'm happy if you have any suggestions to improve this.

Collaborator

After reading through the TCP comm code, I don't think there's anywhere for an exception to happen after all of the payload has been sent. An exception could still happen when part has been sent and part hasn't, but either way, since we have to assume something here, I think it's more useful to assume that the payload hasn't been sent.

@jcrist and I were discussing this, and given the way the BatchedSend interface works, it actually needs to implement some sort of protocol with an ack from the receiver for each sent batch to guarantee messages can't be dropped. Since send on a BatchedSend is nonblocking, the caller is basically handing off full responsibility for the message to BatchedSend. If the message fails to send, it's too late to raise an error and let the caller figure out what to do about it; once we've been given a message, we have to ensure it's delivered. So the logical thing to do would be for the receiving side to ack each message, and only when ack'd does the sender drop the payload (with some deduplication, of course).

OTOH there are lots of protocols out there for doing things like this, more performantly, robustly, and with better testing than we'll ever have. As with other things (framing in serialization), maybe the better solution is to stop duplicating functionality at the application level that should be the transport layer's job.

Could we get rid of BatchedSend entirely with some well-tuned TCP buffering settings + a serialization scheme that was more efficient for many small messages?
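For illustration, a sketch of such an ack-based buffer (a hypothetical protocol; none of these names exist in distributed): the sender tags each batch with a sequence number and keeps it until the receiver acknowledges it, so unacked batches can be replayed after a reconnect instead of being dropped.

from collections import deque


class AckedBufferSketch:
    """Illustrative only: retain sent batches until they are acknowledged."""

    def __init__(self):
        self._seq = 0
        self._unacked = deque()  # (seq, batch) pairs awaiting an ack

    async def send_batch(self, comm, batch):
        # Tag the batch with a sequence number and keep it until acked.
        self._seq += 1
        self._unacked.append((self._seq, batch))
        await comm.write({"seq": self._seq, "msgs": batch})

    def handle_ack(self, acked_seq):
        # The receiver has confirmed everything up to acked_seq; drop it.
        while self._unacked and self._unacked[0][0] <= acked_seq:
            self._unacked.popleft()

    def pending_messages(self):
        # Messages to replay (oldest first) after a reconnect. Duplicates
        # are possible, but nothing is silently dropped.
        return [msg for _, batch in self._unacked for msg in batch]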

else:
# nobreak. We've been gracefully closed.
self.stopped.set()
return

self.stopped.set()
# If we've reached here, it means `break` was hit above and
# there was an exception when using `comm`.
# We can't close gracefully via `.close()` since we can't send messages.
# So we just abort.
# This means that any messages in our buffer are lost.
# To propagate exceptions, we rely on subsequent `BatchedSend.send`
# calls to raise CommClosedErrors.
self.stopped.set()
self.abort()

if self.comm:
self.comm.abort()
yield self.close()
Collaborator

Why close and not abort?

  1. It's weird that calling close from user code will trigger the coroutine to call back into close
  2. Most of the things close does are already guaranteed to have happened by this point:
    • self.please_stop must be True to reach this line of code
    • self.stopped was just set a couple lines above
    • The comm was just aborted (not sure how that compares to await comm.close() though)

Not that abort is any better. The only thing it would do that hasn't already been done by this point is set self.comm = None.

I actually think we should move all shut-down/clean-up logic into the coroutine here, remove abort entirely, and have close() just be a function that tells the coroutine to please_stop and waits until it does. There's a lot of inconsistency from having shutdown logic in three different places.
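A sketch of what close() could look like under that suggestion (hypothetical; it assumes the coroutine itself performs the final flush, closes the comm, and sets self.stopped before returning):

    async def close(self, timeout=None):
        """Ask the background coroutine to stop and wait for it to finish."""
        if self.closed():
            return
        self.please_stop = True
        self.next_deadline = None
        self.waker.set()
        # All flushing and comm teardown happens inside _background_send;
        # close() only signals and waits.
        await self.stopped.wait(timeout=timeout)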


def send(self, *msgs):
"""Schedule a message for sending to the other side

This completes quickly and synchronously
This completes quickly and synchronously.

If the BatchedSend or Comm is already closed, this raises a
CommClosedError and does not accept any further messages to the buffer.
"""
if self.comm is not None and self.comm.closed():
if self.closed():
raise CommClosedError(f"Comm {self.comm!r} already closed.")

self.message_count += len(msgs)
Expand All @@ -143,7 +176,7 @@ def send(self, *msgs):

@gen.coroutine
def close(self, timeout=None):
"""Flush existing messages and then close comm
"""Flush existing messages and then close Comm

If set, raises `tornado.util.TimeoutError` after a timeout.
"""
@@ -153,21 +186,27 @@ def close(self, timeout=None):
self.waker.set()
yield self.stopped.wait(timeout=timeout)
if not self.comm.closed():
payload = []
try:
if self.buffer:
self.buffer, payload = [], self.buffer
yield self.comm.write(
payload, serializers=self.serializers, on_error="raise"
)
except CommClosedError:
pass
# If we're closing and there is an error, there is little we
# can do to recover.
logger.error("Lost %i payload messages.", len(payload))
yield self.comm.close()
Comment on lines 188 to 200
Collaborator

I find it odd that we try to write within close. AFAICT, this code path can only be triggered when write has been called after close, and the BatchedSend is in the "trying to close" state (please_stop is True). That's the only way self.buffer could be non-empty and stopped could be set, because otherwise the coroutine would have gone through another loop, sent the buffer, and cleared it out.

Basically, I'd prefer it if

self.next_deadline = None
self.waker.set()
await self.stopped.wait()

could guarantee that the buffer is flushed. The coroutine's whole job is to flush out the buffer when it's awakened and past its deadline; why duplicate that logic elsewhere?

I believe that tightening up the logic for closed() (to have "trying to close" count as closed) would make this possible. It would prevent write from enqueuing any messages after we've started to close, so we can guarantee that once close has started, no new messages can arrive in the buffer, and there's nothing to flush out.
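Under that tightened definition of closed(), send would reject messages as soon as closing begins; a sketch of that behaviour (not what this diff currently does):

    def send(self, *msgs):
        """Schedule messages for sending, refusing once closing has begun."""
        # With "closing" counted as closed, no message can enter the buffer
        # after close() starts, so close() never has anything left to flush.
        if self.closed() or self.please_stop:
            raise CommClosedError(f"Comm {self.comm!r} already closed.")
        self.message_count += len(msgs)
        self.buffer.extend(msgs)
        if self.next_deadline is None:
            self.waker.set()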


def abort(self):
Collaborator

Should abort be part of the public interface? When/how should callers use abort vs close? What can and cannot be done with a BatchedSend after abort (or close) has been called? These are all things I'd like to see documented.

"""Close the BatchedSend immediately, without waiting for any pending
operations to complete. Buffered data will be lost."""
if self.comm is None:
return
self.please_stop = True
self.buffer = []
self.please_stop = True
self.waker.set()
if not self.comm.closed():
self.comm.abort()
Collaborator

Comment for both here and close(): why should the BatchedSend close/abort the comm when it's closed? Does the caller expect that, by calling start, it has handed off lifecycle responsibility for the comm to the BatchedSend? If so, that should be documented ("once you call start, you must not directly use the comm object").

self.comm = None
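If handing off comm ownership is the intent, the start() docstring could spell it out; a possible wording (purely a suggestion, not part of this PR):

    def start(self, comm: Comm):
        """Start the BatchedSend by providing an open Comm object.

        Once start() has been called, the BatchedSend owns the comm's
        lifecycle: callers must not read from, write to, close, or abort
        the comm directly. Use BatchedSend.close() or BatchedSend.abort()
        instead; both will tear down the underlying comm as needed.
        """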
14 changes: 13 additions & 1 deletion distributed/comm/tcp.py
@@ -1,3 +1,5 @@
import asyncio

import ctypes
import errno
import functools
@@ -290,7 +292,15 @@ async def write(self, msg, serializers=None, on_error="message"):
stream._total_write_index += each_frame_nbytes

# start writing frames
stream.write(b"")
await stream.write(b"")
Collaborator

I'm confused, why were we not awaiting this before?!

I traversed through the blames and it looks like https://github.com/dask/distributed/pull/661/files#diff-581957c552b88dd04319efca1429c0c0827daff509a115542e70b661c4c04914R235 from 5 years ago is where we stopped awaiting this future. There was even some conversation about it: #661 (comment), xref #653.

Seems like there were some issues around concurrent writes to the same stream from multiple coroutines producing orphaned futures. I wonder if that's still the case?

# FIXME: How do I test this? Why is the stream closed _sometimes_?
# Diving into tornado, so far, I can only confirm that once the
# write future has been awaited, the entire buffer has been written
# to the socket. Not sure if one loop iteration is sufficient in
# general or just sufficient for the local tests I've been running
await asyncio.sleep(0)
if stream.closed():
raise StreamClosedError()
Comment on lines +295 to +303
Member Author

Fixing this problem, reusing the buffer, etc. is relatively easy to do assuming we can rely on Comm.write to raise if something is wrong. Tornado ensures that if we await BaseIOStream.write, the data is written to the socket. I dove into the code and can confirm this. To the best of my knowledge, that's working properly.
However, there is a time window where the data is written to the socket, the connection is closed, and Comm.write returns control, but the data was never submitted. I can catch these situations in my local unit tests by waiting a loop iteration, but I'm not sure what the mechanism behind this is and whether or not it is reliable.

cc @jcrist @gjoseph92

FWIW, I think even without this asyncio.sleep(0) this fix should address >99% of our problems. For a full fix we might need to reconsider the possibility for workers to reconnect statefully or introduce another way to guarantee message delivery.

Collaborator

@fjetter to rephrase what you're saying:

  • Awaiting the stream.write is a big improvement, because it means that comm.write doesn't return until the message has been successfully written to the socket buffer.
  • That's still not enough to guarantee the message has actually been sent, because the socket could get closed after the message is handed off to the buffer, but before it's actually sent over the network.

> For a full fix we might need to reconsider the possibility for workers to reconnect statefully or introduce another way to guarantee message delivery.

I agree. Like I mentioned in another comment, I think we need a better protocol (either application-level or more careful use of TCP) to guarantee all messages are delivered through reconnects.

Member Author

That's what I figured, but I was confused why I never encountered this for ordinary comms.

except StreamClosedError as e:
self.stream = None
self._closed = True
@@ -333,6 +343,8 @@ def abort(self):
stream.close()

def closed(self):
if self.stream and self.stream.closed():
self.abort()
Collaborator

Seems odd to me that checking closed on a comm could have a side effect?

Member Author

Well, technically this sets self.stream = None and sets _closed = True. I wouldn't consider these side effects.

return self._closed
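A side-effect-free variant would simply report the state without mutating it (a sketch of the alternative being discussed; it leaves the self.stream cleanup to abort/close):

    def closed(self):
        # Report the closed state without aborting; cleanup stays in abort().
        return self._closed or (self.stream is not None and self.stream.closed())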

@property
2 changes: 1 addition & 1 deletion distributed/scheduler.py
@@ -5510,7 +5510,7 @@ async def handle_worker(self, comm=None, worker=None):
await self.handle_stream(comm=comm, extra={"worker": worker})
finally:
if worker in self.stream_comms:
worker_comm.abort()
await worker_comm.close()
await self.remove_worker(address=worker)

def add_plugin(
150 changes: 148 additions & 2 deletions distributed/tests/test_batched.py
@@ -43,9 +43,11 @@ async def test_BatchedSend():
comm = await connect(e.address)

b = BatchedSend(interval=10)
assert "<BatchedSend: closed>" == str(b)
assert "<BatchedSend: closed>" == repr(b)
b.start(comm)
assert str(len(b.buffer)) in str(b)
assert str(len(b.buffer)) in repr(b)
b.start(comm)

await asyncio.sleep(0.020)

@@ -79,6 +81,125 @@ async def test_send_before_start():
assert result == ("hello", "world")


@pytest.mark.asyncio
async def test_closed_if_not_started():
async with EchoServer() as e:
comm = await connect(e.address)
b = BatchedSend(interval=10)
assert b.closed()
b.start(comm)
assert not b.closed()
await b.close()
assert b.closed()


@pytest.mark.asyncio
async def test_start_twice_with_closing():
async with EchoServer() as e:
comm = await connect(e.address)
comm2 = await connect(e.address)

b = BatchedSend(interval=10)
b.start(comm)

# Same comm is fine
b.start(comm)

await b.close()

b.start(comm2)

b.send("hello")
b.send("world")

result = await comm2.read()
assert result == ("hello", "world")


@pytest.mark.asyncio
async def test_start_twice_with_abort():
async with EchoServer() as e:
comm = await connect(e.address)
comm2 = await connect(e.address)

b = BatchedSend(interval=10)
b.start(comm)

# Same comm is fine
b.start(comm)

b.abort()

b.start(comm2)

b.send("hello")
b.send("world")

result = await comm2.read()
assert result == ("hello", "world")


@pytest.mark.asyncio
async def test_start_twice_with_abort_drops_payload():
async with EchoServer() as e:
comm = await connect(e.address)
comm2 = await connect(e.address)

b = BatchedSend(interval=10)
b.start(comm)
b.send("hello")
b.send("world")

# Same comm is fine
b.start(comm)

b.abort()

b.start(comm2)

with pytest.raises(asyncio.TimeoutError):
res = await asyncio.wait_for(comm2.read(), 0.01)
assert not res


@pytest.mark.asyncio
async def test_start_closed_comm():
async with EchoServer() as e:
comm = await connect(e.address)
await comm.close()

b = BatchedSend(interval="10ms")
with pytest.raises(RuntimeError, match="Comm already closed."):
b.start(comm)


@pytest.mark.asyncio
async def test_start_twice_without_closing():
async with EchoServer() as e:
comm = await connect(e.address)
comm2 = await connect(e.address)

b = BatchedSend(interval=10)
b.start(comm)

# Same comm is fine
b.start(comm)

# different comm only allowed if already closed
with pytest.raises(RuntimeError, match="BatchedSend already started"):
b.start(comm2)

b.send("hello")
b.send("world")

result = await comm.read()
assert result == ("hello", "world")

# This comm hasn't been used so there should be no message received
with pytest.raises(asyncio.TimeoutError):
await asyncio.wait_for(comm2.read(), 0.01)


@pytest.mark.asyncio
async def test_send_after_stream_start():
async with EchoServer() as e:
@@ -113,8 +234,11 @@ async def test_send_before_close():
await asyncio.sleep(0.01)
assert time() < start + 5

msg = "123"
with pytest.raises(CommClosedError):
b.send("123")
b.send(msg)

assert msg not in b.buffer


@pytest.mark.asyncio
@@ -249,3 +373,25 @@ async def test_serializers():
assert "function" in value

assert comm.closed()


@pytest.mark.asyncio
async def test_retain_buffer_commclosed():
async with EchoServer() as e:
with captured_logger("distributed.batched") as caplog:
comm = await connect(e.address)

b = BatchedSend(interval="1s", serializers=["msgpack"])
b.start(comm)
b.send("foo")
assert b.buffer
await comm.close()
await asyncio.sleep(1)

assert "Batched Comm Closed" in caplog.getvalue()
assert b.buffer

new_comm = await connect(e.address)
b.start(new_comm)
assert await new_comm.read() == ("foo",)
assert not b.buffer