diff --git a/distributed/protocol/core.py b/distributed/protocol/core.py
index 6df5931e22a..5818a7f23a9 100644
--- a/distributed/protocol/core.py
+++ b/distributed/protocol/core.py
@@ -6,18 +6,15 @@
 from .compression import compressions, maybe_compress, decompress
 from .serialize import (
-    serialize,
-    deserialize,
     Serialize,
     Serialized,
     extract_serialize,
     msgpack_decode_default,
     msgpack_encode_default,
+    merge_and_deserialize,
+    serialize_and_split,
 )
-from .utils import frame_split_size, merge_frames, msgpack_opts
-from ..utils import is_writeable, nbytes
-
-_deserialize = deserialize
+from .utils import msgpack_opts
 
 logger = logging.getLogger(__name__)
 
@@ -48,7 +45,7 @@ def dumps(msg, serializers=None, on_error="message", context=None):
         }
 
         data = {
-            key: serialize(
+            key: serialize_and_split(
                 value.data, serializers=serializers, on_error=on_error, context=context
             )
             for key, value in data.items()
@@ -60,39 +57,23 @@ def dumps(msg, serializers=None, on_error="message", context=None):
         out_frames = []
 
         for key, (head, frames) in data.items():
-            if "writeable" not in head:
-                head["writeable"] = tuple(map(is_writeable, frames))
-            if "lengths" not in head:
-                head["lengths"] = tuple(map(nbytes, frames))
-
             # Compress frames that are not yet compressed
             out_compression = []
-            _out_frames = []
             for frame, compression in zip(
                 frames, head.get("compression") or [None] * len(frames)
             ):
-                if compression is None:  # default behavior
-                    _frames = frame_split_size(frame)
-                    _compression, _frames = zip(
-                        *[maybe_compress(frame, **compress_opts) for frame in _frames]
-                    )
-                    out_compression.extend(_compression)
-                    _out_frames.extend(_frames)
-                else:  # already specified, so pass
-                    out_compression.append(compression)
-                    _out_frames.append(frame)
+                if compression is None:
+                    compression, frame = maybe_compress(frame, **compress_opts)
+
+                out_compression.append(compression)
+                out_frames.append(frame)
 
             head["compression"] = out_compression
-            head["count"] = len(_out_frames)
+            head["count"] = len(frames)
             header["headers"][key] = head
             header["keys"].append(key)
-            out_frames.extend(_out_frames)
 
         for key, (head, frames) in pre.items():
-            if "writeable" not in head:
-                head["writeable"] = tuple(map(is_writeable, frames))
-            if "lengths" not in head:
-                head["lengths"] = tuple(map(nbytes, frames))
             head["count"] = len(frames)
             header["headers"][key] = head
             header["keys"].append(key)
@@ -146,9 +127,7 @@ def loads(frames, deserialize=True, deserializers=None):
             if deserialize or key in bytestrings:
                 if "compression" in head:
                     fs = decompress(head, fs)
-                if not any(hasattr(f, "__cuda_array_interface__") for f in fs):
-                    fs = merge_frames(head, fs)
-                value = _deserialize(head, fs, deserializers=deserializers)
+                value = merge_and_deserialize(head, fs, deserializers=deserializers)
             else:
                 value = Serialized(head, fs)
diff --git a/distributed/protocol/cupy.py b/distributed/protocol/cupy.py
index 856fc5adf46..eeeae687557 100644
--- a/distributed/protocol/cupy.py
+++ b/distributed/protocol/cupy.py
@@ -22,7 +22,6 @@ def cuda_serialize_cupy_ndarray(x):
     header = x.__cuda_array_interface__.copy()
     header["strides"] = tuple(x.strides)
-    header["lengths"] = [x.nbytes]
     frames = [
         cupy.ndarray(
             shape=(x.nbytes,), dtype=cupy.dtype("u1"), memptr=x.data, strides=(1,)
         )
@@ -47,7 +46,6 @@ def cuda_deserialize_cupy_ndarray(header, frames):
 
 @dask_serialize.register(cupy.ndarray)
 def dask_serialize_cupy_ndarray(x):
     header, frames = cuda_serialize_cupy_ndarray(x)
-    header["writeable"] = (None,) * len(frames)
     frames = [memoryview(cupy.asnumpy(f)) for f in frames]
     return header, frames
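Reviewer note, not part of the patch: a quick end-to-end sanity check of the new `dumps()`/`loads()` path. This is a minimal sketch that assumes the default 64 MiB shard size (the `distributed.comm.shard` config), so the 128 MiB payload below gets split by `serialize_and_split()` and reassembled by `merge_and_deserialize()`:

```python
import numpy as np

from distributed.protocol import dumps, loads, to_serialize

# 128 MiB payload, twice the default shard size, so dumps() will
# emit it as multiple sub-frames on the wire
x = np.arange(2**25, dtype="u4")
frames = dumps({"data": to_serialize(x)})
msg = loads(frames)
assert (msg["data"] == x).all()
```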
diff --git a/distributed/protocol/numba.py b/distributed/protocol/numba.py
index e1915251f6f..668f07e0926 100644
--- a/distributed/protocol/numba.py
+++ b/distributed/protocol/numba.py
@@ -23,7 +23,6 @@ def cuda_serialize_numba_ndarray(x):
     header = x.__cuda_array_interface__.copy()
     header["strides"] = tuple(x.strides)
-    header["lengths"] = [x.nbytes]
     frames = [
         numba.cuda.cudadrv.devicearray.DeviceNDArray(
             shape=(x.nbytes,), strides=(1,), dtype=np.dtype("u1"), gpu_data=x.gpu_data
         )
@@ -51,7 +50,6 @@ def cuda_deserialize_numba_ndarray(header, frames):
 
 @dask_serialize.register(numba.cuda.devicearray.DeviceNDArray)
 def dask_serialize_numba_ndarray(x):
     header, frames = cuda_serialize_numba_ndarray(x)
-    header["writeable"] = (None,) * len(frames)
     frames = [memoryview(f.copy_to_host()) for f in frames]
     return header, frames
diff --git a/distributed/protocol/numpy.py b/distributed/protocol/numpy.py
index 4ae9298f142..65f7e2f4076 100644
--- a/distributed/protocol/numpy.py
+++ b/distributed/protocol/numpy.py
@@ -4,7 +4,7 @@
 from .serialize import dask_serialize, dask_deserialize
 from . import pickle
 
-from ..utils import log_errors, nbytes
+from ..utils import log_errors
 
 
 def itemsize(dt):
@@ -29,7 +29,6 @@ def serialize_numpy_ndarray(x, context=None):
             buffer_callback=buffer_callback,
             protocol=(context or {}).get("pickle-protocol", None),
         )
-        header["lengths"] = tuple(map(nbytes, frames))
         return header, frames
 
     # We cannot blindly pickle the dtype as some may fail pickling,
@@ -93,15 +92,17 @@ def serialize_numpy_ndarray(x, context=None):
         # "ValueError: cannot include dtype 'M' in a buffer"
         data = data.view("u%d" % math.gcd(x.dtype.itemsize, 8)).data
 
-    header = {"dtype": dt, "shape": x.shape, "strides": strides}
+    header = {
+        "dtype": dt,
+        "shape": x.shape,
+        "strides": strides,
+        "writeable": [x.flags.writeable],
+    }
 
     if broadcast_to is not None:
         header["broadcast_to"] = broadcast_to
 
     frames = [data]
-
-    header["lengths"] = [x.nbytes]
-
     return header, frames
 
 
@@ -112,6 +113,7 @@ def deserialize_numpy_ndarray(header, frames):
         return pickle.loads(frames[0], buffers=frames[1:])
 
     (frame,) = frames
+    (writeable,) = header["writeable"]
     is_custom, dt = header["dtype"]
 
     if is_custom:
@@ -125,6 +127,10 @@ def deserialize_numpy_ndarray(header, frames):
         shape = header["shape"]
 
     x = np.ndarray(shape, dtype=dt, buffer=frame, strides=header["strides"])
+    if not writeable:
+        x.flags.writeable = False
+    else:
+        x = np.require(x, requirements=["W"])
 
     return x
diff --git a/distributed/protocol/rmm.py b/distributed/protocol/rmm.py
index 6a56a70ab76..e25919c0fbf 100644
--- a/distributed/protocol/rmm.py
+++ b/distributed/protocol/rmm.py
@@ -13,7 +13,6 @@ def cuda_serialize_rmm_device_buffer(x):
     header = x.__cuda_array_interface__.copy()
     header["strides"] = (1,)
-    header["lengths"] = [x.nbytes]
     frames = [x]
     return header, frames
 
@@ -31,7 +30,6 @@ def cuda_deserialize_rmm_device_buffer(header, frames):
 
 @dask_serialize.register(rmm.DeviceBuffer)
 def dask_serialize_rmm_device_buffer(x):
     header, frames = cuda_serialize_rmm_device_buffer(x)
-    header["writeable"] = (None,) * len(frames)
     frames = [numba.cuda.as_cuda_array(f).copy_to_host().data for f in frames]
     return header, frames
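Reviewer note, not part of the patch: the numpy change moves the read-only bit into a per-array `"writeable"` header entry so it survives the wire. A minimal sketch of the behavior this is meant to guarantee:

```python
import numpy as np

from distributed.protocol import deserialize, serialize

x = np.arange(10)
x.flags.writeable = False

header, frames = serialize(x)
y = deserialize(header, frames)

# header["writeable"] restores the read-only flag on the receiving side
assert y.flags.writeable == x.flags.writeable
```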
diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py
index b99e7692e43..1a447468bf3 100644
--- a/distributed/protocol/serialize.py
+++ b/distributed/protocol/serialize.py
@@ -10,13 +10,12 @@
 import msgpack
 
 from . import pickle
-from ..utils import has_keyword, nbytes, typename, ensure_bytes, is_writeable
+from ..utils import has_keyword, typename, ensure_bytes
 from .compression import maybe_compress, decompress
 from .utils import (
     unpack_frames,
     pack_frames_prelude,
     frame_split_size,
-    merge_frames,
     msgpack_opts,
 )
@@ -30,7 +29,7 @@
 
 
 def dask_dumps(x, context=None):
-    """Serialise object using the class-based registry"""
+    """Serialize object using the class-based registry"""
     type_name = typename(type(x))
     try:
         dumps = dask_serialize.dispatch(type(x))
@@ -54,7 +53,6 @@ def dask_loads(header, frames):
 
 
 def pickle_dumps(x, context=None):
-    header = {"serializer": "pickle"}
     frames = [None]
     buffer_callback = lambda f: frames.append(memoryview(f))
     frames[0] = pickle.dumps(
@@ -62,11 +60,23 @@ def pickle_dumps(x, context=None):
         buffer_callback=buffer_callback,
         protocol=context.get("pickle-protocol", None) if context else None,
     )
+    header = {
+        "serializer": "pickle",
+        "writeable": tuple(not f.readonly for f in frames[1:]),
+    }
     return header, frames
 
 
 def pickle_loads(header, frames):
     x, buffers = frames[0], frames[1:]
+    writeable = header["writeable"]
+    for i in range(len(buffers)):
+        mv = memoryview(buffers[i])
+        if writeable[i] == mv.readonly:
+            if mv.readonly:
+                buffers[i] = memoryview(bytearray(mv)).cast(mv.format, mv.shape)
+            else:
+                buffers[i] = memoryview(bytes(mv)).cast(mv.format, mv.shape)
     return pickle.loads(x, buffers=buffers)
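Reviewer note, not part of the patch: a sketch of what the `pickle_dumps()`/`pickle_loads()` change buys us. It assumes Python >= 3.8, where pickle protocol 5 hands NumPy's data buffer over out-of-band, so `frames[1:]` is non-empty:

```python
import numpy as np

from distributed.protocol.serialize import pickle_dumps, pickle_loads

a = np.arange(5)
header, frames = pickle_dumps(a)
assert header["writeable"] == (True,)  # one out-of-band buffer, writeable

# simulate transport, which typically hands back read-only bytes
frames = [frames[0]] + [bytes(f) for f in frames[1:]]
b = pickle_loads(header, frames)
b[0] = 42  # works: pickle_loads copied the buffer into a writeable bytearray
```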
+ header["split-num-sub-frames"] = tuple(num_sub_frames) + header["split-offsets"] = tuple(offsets) + header["compression"] = tuple(out_compression) + return header, out_frames + + +def merge_and_deserialize(header, frames, deserializers=None): + """Merge and deserialize frames + + This function is a drop-in replacement of `deserialize()` that merges + frames that were split by `serialize_and_split()` + + See Also + -------- + deserialize + serialize_and_split + """ + merged_frames = [] + if "split-num-sub-frames" not in header: + merged_frames = frames + else: + for n, offset in zip(header["split-num-sub-frames"], header["split-offsets"]): + if n == 1: + merged_frames.append(frames[offset]) + else: + merged_frames.append(bytearray().join(frames[offset : offset + n])) + + return deserialize(header, merged_frames, deserializers=deserializers) + + class Serialize: """Mark an object that should be serialized @@ -534,13 +610,8 @@ def replace_inner(x): def serialize_bytelist(x, **kwargs): - header, frames = serialize(x, **kwargs) - if "writeable" not in header: - header["writeable"] = tuple(map(is_writeable, frames)) - if "lengths" not in header: - header["lengths"] = tuple(map(nbytes, frames)) + header, frames = serialize_and_split(x, **kwargs) if frames: - frames = sum(map(frame_split_size, frames), []) compression, frames = zip(*map(maybe_compress, frames)) else: compression = [] @@ -566,8 +637,7 @@ def deserialize_bytes(b): else: header = {} frames = decompress(header, frames) - frames = merge_frames(header, frames) - return deserialize(header, frames) + return merge_and_deserialize(header, frames) ################################ diff --git a/distributed/protocol/tests/test_numpy.py b/distributed/protocol/tests/test_numpy.py index ea349692e70..c52e4f5b402 100644 --- a/distributed/protocol/tests/test_numpy.py +++ b/distributed/protocol/tests/test_numpy.py @@ -278,12 +278,6 @@ def test_compression_takes_advantage_of_itemsize(): assert sum(map(nbytes, aa)) < sum(map(nbytes, bb)) -def test_large_numpy_array(): - x = np.ones((100000000,), dtype="u4") - header, frames = serialize(x) - assert sum(header["lengths"]) == sum(map(nbytes, frames)) - - @pytest.mark.parametrize( "x", [ diff --git a/distributed/protocol/tests/test_protocol_utils.py b/distributed/protocol/tests/test_protocol_utils.py index 847dec1ac3d..aed16dd0146 100644 --- a/distributed/protocol/tests/test_protocol_utils.py +++ b/distributed/protocol/tests/test_protocol_utils.py @@ -1,38 +1,4 @@ -import pytest - -from distributed.protocol.utils import merge_frames, pack_frames, unpack_frames -from distributed.utils import ensure_bytes, is_writeable - - -@pytest.mark.parametrize( - "lengths,writeable,frames", - [ - ([3], [False], [b"123"]), - ([3], [True], [b"123"]), - ([3], [None], [b"123"]), - ([3], [False], [bytearray(b"123")]), - ([3], [True], [bytearray(b"123")]), - ([3], [None], [bytearray(b"123")]), - ([3, 3], [False, False], [b"123", b"456"]), - ([2, 3, 2], [False, True, None], [b"12345", b"67"]), - ([2, 3, 2], [False, True, None], [bytearray(b"12345"), bytearray(b"67")]), - ([5, 2], [False, True], [b"123", b"45", b"67"]), - ([3, 4], [None, False], [b"12", b"34", b"567"]), - ], -) -def test_merge_frames(lengths, writeable, frames): - header = {"lengths": lengths, "writeable": writeable} - result = merge_frames(header, frames) - - data = b"".join(frames) - expected = [] - for i in lengths: - expected.append(data[:i]) - data = data[i:] - - writeables = list(map(is_writeable, result)) - assert (r == e for r, e in zip(writeables, 
header["writeable"]) if e is not None) - assert list(map(ensure_bytes, result)) == expected +from distributed.protocol.utils import pack_frames, unpack_frames def test_pack_frames(): diff --git a/distributed/protocol/utils.py b/distributed/protocol/utils.py index 45cb1466a89..25ccce7c9f6 100644 --- a/distributed/protocol/utils.py +++ b/distributed/protocol/utils.py @@ -34,57 +34,6 @@ def frame_split_size(frame, n=BIG_BYTES_SHARD_SIZE) -> list: return [frame[i : i + items_per_shard] for i in range(0, nitems, items_per_shard)] -def merge_frames(header, frames): - """Merge frames into original lengths - - Examples - -------- - >>> merge_frames({'lengths': [3, 3]}, [b'123456']) - [b'123', b'456'] - >>> merge_frames({'lengths': [6]}, [b'123', b'456']) - [b'123456'] - """ - lengths = list(header["lengths"]) - writeables = list(header["writeable"]) - - assert len(lengths) == len(writeables) - assert sum(lengths) == sum(map(nbytes, frames)) - - if all(len(f) == l for f, l in zip(frames, lengths)): - return [ - (bytearray(f) if w else bytes(f)) if w == memoryview(f).readonly else f - for w, f in zip(header["writeable"], frames) - ] - - frames = frames[::-1] - lengths = lengths[::-1] - writeables = writeables[::-1] - - out = [] - while lengths: - l = lengths.pop() - w = writeables.pop() - L = [] - while l: - frame = frames.pop() - if nbytes(frame) <= l: - L.append(frame) - l -= nbytes(frame) - else: - frame = memoryview(frame) - L.append(frame[:l]) - frames.append(frame[l:]) - l = 0 - if len(L) == 1 and w != memoryview(L[0]).readonly: # no work necessary - out.extend(L) - elif w: - out.append(bytearray().join(L)) - else: - out.append(bytes().join(L)) - - return out - - def pack_frames_prelude(frames): nframes = len(frames) nbytes_frames = map(nbytes, frames)