Skip to content

Commit 349e26e

Browse files
committed
Merge branch 'master' of github.com:dask/distributed into hlg_pack_move_to_dask
2 parents bb7a411 + 383ea03 commit 349e26e

File tree

2 files changed

+33
-36
lines changed

2 files changed

+33
-36
lines changed

distributed/comm/tcp.py

+32-34
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from .addressing import parse_host_port, unparse_host_port
2828
from .core import Comm, Connector, Listener, CommClosedError, FatalCommClosedError
2929
from .utils import to_frames, from_frames, get_tcp_server_address, ensure_concrete_host
30+
from ..protocol.utils import pack_frames_prelude, unpack_frames
3031

3132

3233
logger = logging.getLogger(__name__)
@@ -187,19 +188,16 @@ async def read(self, deserializers=None):
187188
if stream is None:
188189
raise CommClosedError
189190

191+
fmt = "Q"
192+
fmt_size = struct.calcsize(fmt)
193+
190194
try:
191-
n_frames = await stream.read_bytes(8)
192-
n_frames = struct.unpack("Q", n_frames)[0]
193-
lengths = await stream.read_bytes(8 * n_frames)
194-
lengths = struct.unpack("Q" * n_frames, lengths)
195-
196-
frames = []
197-
for length in lengths:
198-
frame = bytearray(length)
199-
if length:
200-
n = await stream.read_into(frame)
201-
assert n == length, (n, length)
202-
frames.append(frame)
195+
frames_nbytes = await stream.read_bytes(fmt_size)
196+
(frames_nbytes,) = struct.unpack(fmt, frames_nbytes)
197+
198+
frames = bytearray(frames_nbytes)
199+
n = await stream.read_into(frames)
200+
assert n == frames_nbytes, (n, frames_nbytes)
203201
except StreamClosedError as e:
204202
self.stream = None
205203
self._closed = True
@@ -214,6 +212,8 @@ async def read(self, deserializers=None):
214212
raise
215213
else:
216214
try:
215+
frames = unpack_frames(frames)
216+
217217
msg = await from_frames(
218218
frames,
219219
deserialize=self.deserialize,
@@ -243,30 +243,28 @@ async def write(self, msg, serializers=None, on_error="message"):
243243
**self.handshake_options,
244244
},
245245
)
246+
frames_nbytes = sum(map(nbytes, frames))
246247

247-
try:
248-
nframes = len(frames)
249-
lengths = [nbytes(frame) for frame in frames]
250-
length_bytes = struct.pack(f"Q{nframes}Q", nframes, *lengths)
248+
header = pack_frames_prelude(frames)
249+
header = struct.pack("Q", nbytes(header) + frames_nbytes) + header
251250

252-
frames = [length_bytes, *frames]
253-
lengths = [len(length_bytes), *lengths]
251+
frames = [header, *frames]
252+
frames_nbytes += nbytes(header)
254253

255-
if sum(lengths) < 2 ** 17: # 128kiB
256-
# small enough, send in one go
257-
stream.write(b"".join(frames))
258-
else:
259-
# avoid large memcpy, send in many
260-
for frame, frame_bytes in zip(frames, lengths):
261-
# Can't wait for the write() Future as it may be lost
262-
# ("If write is called again before that Future has resolved,
263-
# the previous future will be orphaned and will never resolve")
264-
if frame_bytes:
265-
future = stream.write(frame)
266-
bytes_since_last_yield += frame_bytes
267-
if bytes_since_last_yield > 32e6:
268-
await future
269-
bytes_since_last_yield = 0
254+
if frames_nbytes < 2 ** 17: # 128kiB
255+
# small enough, send in one go
256+
frames = [b"".join(frames)]
257+
258+
try:
259+
# trick to enque all frames for writing beforehand
260+
for each_frame in frames:
261+
each_frame_nbytes = nbytes(each_frame)
262+
if each_frame_nbytes:
263+
stream._write_buffer.append(each_frame)
264+
stream._total_write_index += each_frame_nbytes
265+
266+
# start writing frames
267+
stream.write(b"")
270268
except StreamClosedError as e:
271269
self.stream = None
272270
self._closed = True
@@ -282,7 +280,7 @@ async def write(self, msg, serializers=None, on_error="message"):
282280
self.abort()
283281
raise
284282

285-
return sum(lengths)
283+
return frames_nbytes
286284

287285
@gen.coroutine
288286
def close(self):

distributed/protocol/utils.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,7 @@ def unpack_frames(b):
125125
start = fmt_size * (1 + n_frames)
126126
for length in lengths:
127127
end = start + length
128-
frame = b[start:end]
129-
frames.append(frame)
128+
frames.append(b[start:end])
130129
start = end
131130

132131
return frames

0 commit comments

Comments
 (0)