Skip to content

Commit

Permalink
Ensure writable frames (#3967)
Browse files Browse the repository at this point in the history
User code working with NumPy or Pandas objects often expects the objects
to be mutable. However if read-only frames (like `bytes`) objects are
used, this is not true. So add a test to check for this so that we can
make sure this is true and we can catch and fix cases where that may not
be true.
  • Loading branch information
jakirkham authored Jul 17, 2020
1 parent ef168a4 commit c67705f
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 34 deletions.
8 changes: 8 additions & 0 deletions distributed/protocol/tests/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,14 @@ def test_dumps_serialize_numpy(x):
np.testing.assert_equal(x, y)


def test_dumps_numpy_writable():
a1 = np.arange(1000)
a1.flags.writeable = False
(a2,) = loads(dumps([to_serialize(a1)]))
assert (a1 == a2).all()
assert a2.flags.writeable


@pytest.mark.parametrize(
"x",
[
Expand Down
34 changes: 24 additions & 10 deletions distributed/protocol/tests/test_protocol_utils.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,33 @@
import pytest

from distributed.protocol.utils import merge_frames, pack_frames, unpack_frames
from distributed.utils import ensure_bytes


def test_merge_frames():
result = merge_frames({"lengths": [3, 4]}, [b"12", b"34", b"567"])
expected = [b"123", b"4567"]

@pytest.mark.parametrize(
"lengths,frames",
[
([3], [b"123"]),
([3, 3], [b"123", b"456"]),
([2, 3, 2], [b"12345", b"67"]),
([5, 2], [b"123", b"45", b"67"]),
([3, 4], [b"12", b"34", b"567"]),
],
)
def test_merge_frames(lengths, frames):
header = {"lengths": lengths}
result = merge_frames(header, frames)

data = b"".join(frames)
expected = []
for i in lengths:
expected.append(data[:i])
data = data[i:]

assert all(isinstance(f, memoryview) for f in result)
assert all(not f.readonly for f in result)
assert list(map(ensure_bytes, result)) == expected

b = b"123"
assert merge_frames({"lengths": [3]}, [b])[0] is b

L = [b"123", b"456"]
assert merge_frames({"lengths": [3, 3]}, L) is L


def test_pack_frames():
frames = [b"123", b"asdf"]
Expand Down
50 changes: 26 additions & 24 deletions distributed/protocol/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,34 +58,36 @@ def merge_frames(header, frames):
[b'123456']
"""
lengths = list(header["lengths"])
frames = list(map(memoryview, frames))

assert sum(lengths) == sum(map(nbytes, frames))

if all(len(f) == l for f, l in zip(frames, lengths)):
return frames

frames = frames[::-1]
lengths = lengths[::-1]

out = []
while lengths:
l = lengths.pop()
L = []
while l:
frame = frames.pop()
if nbytes(frame) <= l:
L.append(frame)
l -= nbytes(frame)
if not all(len(f) == l for f, l in zip(frames, lengths)):
frames = frames[::-1]
lengths = lengths[::-1]

out = []
while lengths:
l = lengths.pop()
L = []
while l:
frame = frames.pop()
if nbytes(frame) <= l:
L.append(frame)
l -= nbytes(frame)
else:
L.append(frame[:l])
frames.append(frame[l:])
l = 0
if len(L) == 1: # no work necessary
out.append(L[0])
else:
mv = memoryview(frame)
L.append(mv[:l])
frames.append(mv[l:])
l = 0
if len(L) == 1: # no work necessary
out.extend(L)
else:
out.append(b"".join(L))
return out
out.append(memoryview(bytearray().join(L)))
frames = out

frames = [memoryview(bytearray(f)) if f.readonly else f for f in frames]

return frames


def pack_frames_prelude(frames):
Expand Down

0 comments on commit c67705f

Please sign in to comment.