Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure writable frames #3967

Merged
merged 13 commits into from
Jul 17, 2020
8 changes: 8 additions & 0 deletions distributed/protocol/tests/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,14 @@ def test_dumps_serialize_numpy(x):
np.testing.assert_equal(x, y)


def test_dumps_numpy_writable():
a1 = np.arange(1000)
a1.flags.writeable = False
(a2,) = loads(dumps([to_serialize(a1)]))
assert (a1 == a2).all()
assert a2.flags.writeable


@pytest.mark.parametrize(
"x",
[
Expand Down
34 changes: 24 additions & 10 deletions distributed/protocol/tests/test_protocol_utils.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,33 @@
import pytest

from distributed.protocol.utils import merge_frames, pack_frames, unpack_frames
from distributed.utils import ensure_bytes


def test_merge_frames():
result = merge_frames({"lengths": [3, 4]}, [b"12", b"34", b"567"])
expected = [b"123", b"4567"]

@pytest.mark.parametrize(
"lengths,frames",
[
([3], [b"123"]),
([3, 3], [b"123", b"456"]),
([2, 3, 2], [b"12345", b"67"]),
([5, 2], [b"123", b"45", b"67"]),
([3, 4], [b"12", b"34", b"567"]),
],
)
def test_merge_frames(lengths, frames):
header = {"lengths": lengths}
result = merge_frames(header, frames)

data = b"".join(frames)
expected = []
for i in lengths:
expected.append(data[:i])
data = data[i:]

assert all(isinstance(f, memoryview) for f in result)
assert all(not f.readonly for f in result)
assert list(map(ensure_bytes, result)) == expected

b = b"123"
assert merge_frames({"lengths": [3]}, [b])[0] is b

L = [b"123", b"456"]
assert merge_frames({"lengths": [3, 3]}, L) is L


def test_pack_frames():
frames = [b"123", b"asdf"]
Expand Down
50 changes: 26 additions & 24 deletions distributed/protocol/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,34 +58,36 @@ def merge_frames(header, frames):
[b'123456']
"""
lengths = list(header["lengths"])
frames = list(map(memoryview, frames))

assert sum(lengths) == sum(map(nbytes, frames))

if all(len(f) == l for f, l in zip(frames, lengths)):
return frames

frames = frames[::-1]
lengths = lengths[::-1]

out = []
while lengths:
l = lengths.pop()
L = []
while l:
frame = frames.pop()
if nbytes(frame) <= l:
L.append(frame)
l -= nbytes(frame)
if not all(len(f) == l for f, l in zip(frames, lengths)):
frames = frames[::-1]
lengths = lengths[::-1]

out = []
while lengths:
l = lengths.pop()
L = []
while l:
frame = frames.pop()
if nbytes(frame) <= l:
L.append(frame)
l -= nbytes(frame)
else:
L.append(frame[:l])
frames.append(frame[l:])
l = 0
if len(L) == 1: # no work necessary
out.append(L[0])
else:
mv = memoryview(frame)
L.append(mv[:l])
frames.append(mv[l:])
l = 0
if len(L) == 1: # no work necessary
out.extend(L)
else:
out.append(b"".join(L))
return out
out.append(memoryview(bytearray().join(L)))
frames = out

frames = [memoryview(bytearray(f)) if f.readonly else f for f in frames]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For fun I decided to micro-benchmark this:

In [1]: L = [b'0' * i for i in [10, 100, 10, 100, 1000] * 2]                                                                                                                                                      

In [2]: %timeit b"".join(L)                                                                                                                                                                                       
203 ns ± 1.6 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

In [3]: %timeit memoryview(bytearray().join(L))                                                                                                                                                                   
463 ns ± 5.9 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

260ns doesn't really matter today I don't think


return frames


def pack_frames_prelude(frames):
Expand Down