Serialize and split (#4541)
* Minor cleanup

* serialize numpy handles the writeable flag

* pickle handles the writeable flag

* serialize_and_split() and merge_and_deserialize()

* docstrings

* Use numpy require() to make it writeable

Co-authored-by: jakirkham <[email protected]>

* removed merge_frames()

* Removed obsolete writeable and lengths header

* use tuples to match msgpack's implicit conversion to tuples

* Make sure compression in header is extended when splitting frames

* pickle_loads(): cast shape and type

Co-authored-by: jakirkham <[email protected]>

Co-authored-by: jakirkham <[email protected]>
madsbk and jakirkham authored Feb 26, 2021
1 parent 31119c4 commit 7f8bb81
Showing 9 changed files with 106 additions and 148 deletions.
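The heart of the change: frame splitting moves out of core.dumps() and into serialization itself. serialize_and_split() calls serialize() and then shards compressible frames with frame_split_size(), while merge_and_deserialize() rejoins the shards before deserializing. A minimal round-trip sketch of the new pair (an illustration built on the diff below, not a test from the commit):

import numpy as np
from distributed.protocol.serialize import (
    merge_and_deserialize,
    serialize_and_split,
)

# An array large enough that frame_split_size() may shard its frame
x = np.ones(200_000_000, dtype="u1")

header, frames = serialize_and_split(x)
# The header records how the frames were split so they can be reassembled
assert "split-num-sub-frames" in header and "split-offsets" in header

y = merge_and_deserialize(header, frames)
assert (x == y).all()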
43 changes: 11 additions & 32 deletions distributed/protocol/core.py
@@ -6,18 +6,15 @@

from .compression import compressions, maybe_compress, decompress
from .serialize import (
serialize,
deserialize,
Serialize,
Serialized,
extract_serialize,
msgpack_decode_default,
msgpack_encode_default,
merge_and_deserialize,
serialize_and_split,
)
from .utils import frame_split_size, merge_frames, msgpack_opts
from ..utils import is_writeable, nbytes

_deserialize = deserialize
from .utils import msgpack_opts


logger = logging.getLogger(__name__)
@@ -48,7 +45,7 @@ def dumps(msg, serializers=None, on_error="message", context=None):
}

data = {
key: serialize(
key: serialize_and_split(
value.data, serializers=serializers, on_error=on_error, context=context
)
for key, value in data.items()
@@ -60,39 +57,23 @@
out_frames = []

for key, (head, frames) in data.items():
if "writeable" not in head:
head["writeable"] = tuple(map(is_writeable, frames))
if "lengths" not in head:
head["lengths"] = tuple(map(nbytes, frames))

# Compress frames that are not yet compressed
out_compression = []
_out_frames = []
for frame, compression in zip(
frames, head.get("compression") or [None] * len(frames)
):
if compression is None: # default behavior
_frames = frame_split_size(frame)
_compression, _frames = zip(
*[maybe_compress(frame, **compress_opts) for frame in _frames]
)
out_compression.extend(_compression)
_out_frames.extend(_frames)
else: # already specified, so pass
out_compression.append(compression)
_out_frames.append(frame)
if compression is None:
compression, frame = maybe_compress(frame, **compress_opts)

out_compression.append(compression)
out_frames.append(frame)

head["compression"] = out_compression
head["count"] = len(_out_frames)
head["count"] = len(frames)
header["headers"][key] = head
header["keys"].append(key)
out_frames.extend(_out_frames)

for key, (head, frames) in pre.items():
if "writeable" not in head:
head["writeable"] = tuple(map(is_writeable, frames))
if "lengths" not in head:
head["lengths"] = tuple(map(nbytes, frames))
head["count"] = len(frames)
header["headers"][key] = head
header["keys"].append(key)
@@ -146,9 +127,7 @@ def loads(frames, deserialize=True, deserializers=None):
if deserialize or key in bytestrings:
if "compression" in head:
fs = decompress(head, fs)
if not any(hasattr(f, "__cuda_array_interface__") for f in fs):
fs = merge_frames(head, fs)
value = _deserialize(head, fs, deserializers=deserializers)
value = merge_and_deserialize(head, fs, deserializers=deserializers)
else:
value = Serialized(head, fs)

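With that, dumps() compresses each (possibly pre-split) frame with maybe_compress(), and loads() decompresses and hands everything to merge_and_deserialize(). A hedged end-to-end sketch using the public entry points in distributed.protocol:

import numpy as np
from distributed.protocol import Serialize, dumps, loads

msg = {"status": "OK", "data": Serialize(np.arange(10))}
frames = dumps(msg)   # serialize_and_split() + per-frame maybe_compress()
out = loads(frames)   # decompress() + merge_and_deserialize()
assert out["status"] == "OK"
assert (out["data"] == np.arange(10)).all()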
2 changes: 0 additions & 2 deletions distributed/protocol/cupy.py
@@ -22,7 +22,6 @@ def cuda_serialize_cupy_ndarray(x):

header = x.__cuda_array_interface__.copy()
header["strides"] = tuple(x.strides)
header["lengths"] = [x.nbytes]
frames = [
cupy.ndarray(
shape=(x.nbytes,), dtype=cupy.dtype("u1"), memptr=x.data, strides=(1,)
@@ -47,7 +46,6 @@ def cuda_deserialize_cupy_ndarray(header, frames):
@dask_serialize.register(cupy.ndarray)
def dask_serialize_cupy_ndarray(x):
header, frames = cuda_serialize_cupy_ndarray(x)
header["writeable"] = (None,) * len(frames)
frames = [memoryview(cupy.asnumpy(f)) for f in frames]
return header, frames

2 changes: 0 additions & 2 deletions distributed/protocol/numba.py
@@ -23,7 +23,6 @@ def cuda_serialize_numba_ndarray(x):

header = x.__cuda_array_interface__.copy()
header["strides"] = tuple(x.strides)
header["lengths"] = [x.nbytes]
frames = [
numba.cuda.cudadrv.devicearray.DeviceNDArray(
shape=(x.nbytes,), strides=(1,), dtype=np.dtype("u1"), gpu_data=x.gpu_data
@@ -51,7 +50,6 @@ def cuda_deserialize_numba_ndarray(header, frames):
@dask_serialize.register(numba.cuda.devicearray.DeviceNDArray)
def dask_serialize_numba_ndarray(x):
header, frames = cuda_serialize_numba_ndarray(x)
header["writeable"] = (None,) * len(frames)
frames = [memoryview(f.copy_to_host()) for f in frames]
return header, frames

18 changes: 12 additions & 6 deletions distributed/protocol/numpy.py
@@ -4,7 +4,7 @@
from .serialize import dask_serialize, dask_deserialize
from . import pickle

from ..utils import log_errors, nbytes
from ..utils import log_errors


def itemsize(dt):
@@ -29,7 +29,6 @@ def serialize_numpy_ndarray(x, context=None):
buffer_callback=buffer_callback,
protocol=(context or {}).get("pickle-protocol", None),
)
header["lengths"] = tuple(map(nbytes, frames))
return header, frames

# We cannot blindly pickle the dtype as some may fail pickling,
@@ -93,15 +92,17 @@ def serialize_numpy_ndarray(x, context=None):
# "ValueError: cannot include dtype 'M' in a buffer"
data = data.view("u%d" % math.gcd(x.dtype.itemsize, 8)).data

header = {"dtype": dt, "shape": x.shape, "strides": strides}
header = {
"dtype": dt,
"shape": x.shape,
"strides": strides,
"writeable": [x.flags.writeable],
}

if broadcast_to is not None:
header["broadcast_to"] = broadcast_to

frames = [data]

header["lengths"] = [x.nbytes]

return header, frames


@@ -112,6 +113,7 @@ def deserialize_numpy_ndarray(header, frames):
return pickle.loads(frames[0], buffers=frames[1:])

(frame,) = frames
(writeable,) = header["writeable"]

is_custom, dt = header["dtype"]
if is_custom:
@@ -125,6 +127,10 @@ def deserialize_numpy_ndarray(header, frames):
shape = header["shape"]

x = np.ndarray(shape, dtype=dt, buffer=frame, strides=header["strides"])
if not writeable:
x.flags.writeable = False
else:
x = np.require(x, requirements=["W"])

return x

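The upshot: the dask serializer for NumPy records the array's writeable flag in the header and restores it on deserialization, with np.require(..., requirements=["W"]) forcing a writable copy when the received frame is read-only. A small sketch of the intended round trip (illustration only):

import numpy as np
from distributed.protocol import deserialize, serialize

x = np.arange(5, dtype="u4")
x.flags.writeable = False  # a read-only array

header, frames = serialize(x, serializers=("dask",))
assert header["writeable"] == [False]

y = deserialize(header, frames)
assert not y.flags.writeable  # the read-only flag survived the round trip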
2 changes: 0 additions & 2 deletions distributed/protocol/rmm.py
@@ -13,7 +13,6 @@
def cuda_serialize_rmm_device_buffer(x):
header = x.__cuda_array_interface__.copy()
header["strides"] = (1,)
header["lengths"] = [x.nbytes]
frames = [x]
return header, frames

@@ -31,7 +30,6 @@ def cuda_deserialize_rmm_device_buffer(header, frames):
@dask_serialize.register(rmm.DeviceBuffer)
def dask_serialize_rmm_device_buffer(x):
header, frames = cuda_serialize_rmm_device_buffer(x)
header["writeable"] = (None,) * len(frames)
frames = [numba.cuda.as_cuda_array(f).copy_to_host().data for f in frames]
return header, frames

94 changes: 82 additions & 12 deletions distributed/protocol/serialize.py
@@ -10,13 +10,12 @@
import msgpack

from . import pickle
from ..utils import has_keyword, nbytes, typename, ensure_bytes, is_writeable
from ..utils import has_keyword, typename, ensure_bytes
from .compression import maybe_compress, decompress
from .utils import (
unpack_frames,
pack_frames_prelude,
frame_split_size,
merge_frames,
msgpack_opts,
)

@@ -30,7 +29,7 @@


def dask_dumps(x, context=None):
"""Serialise object using the class-based registry"""
"""Serialize object using the class-based registry"""
type_name = typename(type(x))
try:
dumps = dask_serialize.dispatch(type(x))
@@ -54,19 +53,30 @@ def dask_loads(header, frames):


def pickle_dumps(x, context=None):
header = {"serializer": "pickle"}
frames = [None]
buffer_callback = lambda f: frames.append(memoryview(f))
frames[0] = pickle.dumps(
x,
buffer_callback=buffer_callback,
protocol=context.get("pickle-protocol", None) if context else None,
)
header = {
"serializer": "pickle",
"writeable": tuple(not f.readonly for f in frames[1:]),
}
return header, frames


def pickle_loads(header, frames):
x, buffers = frames[0], frames[1:]
writeable = header["writeable"]
for i in range(len(buffers)):
mv = memoryview(buffers[i])
if writeable[i] == mv.readonly:
if mv.readonly:
buffers[i] = memoryview(bytearray(mv)).cast(mv.format, mv.shape)
else:
buffers[i] = memoryview(bytes(mv)).cast(mv.format, mv.shape)
return pickle.loads(x, buffers=buffers)


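pickle_dumps() records whether each out-of-band pickle-5 buffer was writable, and pickle_loads() copies a buffer only when its read-only status no longer matches that record: into a bytearray to regain writability, or into bytes to regain read-only-ness. A sketch of the writable case, simulating a transport that delivers buffers as read-only bytes (illustration; assumes NumPy pickles the array data as a single out-of-band buffer under protocol 5):

import numpy as np
from distributed.protocol.serialize import pickle_dumps, pickle_loads

a = np.arange(10)                      # a writable array
header, frames = pickle_dumps(a)
assert header["writeable"] == (True,)  # one out-of-band buffer, writable

# Pretend the frames crossed the wire and arrived as read-only bytes
frames = [frames[0]] + [bytes(f) for f in frames[1:]]

b = pickle_loads(header, frames)
assert b.flags.writeable               # writability restored via a bytearray copy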
@@ -374,6 +384,72 @@ def deserialize(header, frames, deserializers=None):
return loads(header, frames)


def serialize_and_split(x, serializers=None, on_error="message", context=None):
"""Serialize and split compressable frames
This function is a drop-in replacement of `serialize()` that calls `serialize()`
followed by `frame_split_size()` on frames that should be compressed.
Use `merge_and_deserialize()` to merge and deserialize the frames back.
See Also
--------
serialize
merge_and_deserialize
"""
header, frames = serialize(x, serializers, on_error, context)
num_sub_frames = []
offsets = []
out_frames = []
out_compression = []
for frame, compression in zip(
frames, header.get("compression") or [None] * len(frames)
):
if compression is None: # default behavior
sub_frames = frame_split_size(frame)
num_sub_frames.append(len(sub_frames))
offsets.append(len(out_frames))
out_frames.extend(sub_frames)
out_compression.extend([None] * len(sub_frames))
else:
num_sub_frames.append(1)
offsets.append(len(out_frames))
out_frames.append(frame)
out_compression.append(compression)
assert len(out_compression) == len(out_frames)

# Note: to match msgpack's implicit conversion to tuples,
# we convert to tuples here as well.
header["split-num-sub-frames"] = tuple(num_sub_frames)
header["split-offsets"] = tuple(offsets)
header["compression"] = tuple(out_compression)
return header, out_frames


def merge_and_deserialize(header, frames, deserializers=None):
"""Merge and deserialize frames
This function is a drop-in replacement of `deserialize()` that merges
frames that were split by `serialize_and_split()`
See Also
--------
deserialize
serialize_and_split
"""
merged_frames = []
if "split-num-sub-frames" not in header:
merged_frames = frames
else:
for n, offset in zip(header["split-num-sub-frames"], header["split-offsets"]):
if n == 1:
merged_frames.append(frames[offset])
else:
merged_frames.append(bytearray().join(frames[offset : offset + n]))

return deserialize(header, merged_frames, deserializers=deserializers)

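A worked illustration of the split bookkeeping (hypothetical frames, not from the commit): if serialize() produced two frames and the first was sharded into three sub-frames, the header and the merge step line up like this:

header = {"split-num-sub-frames": (3, 1), "split-offsets": (0, 3)}
frames = [b"aa", b"bb", b"cc", b"dd"]

# frames[0:3] are rejoined; frames[3] passes through untouched
merged = [bytearray().join(frames[0:3]), frames[3]]
assert merged == [bytearray(b"aabbcc"), b"dd"]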

class Serialize:
"""Mark an object that should be serialized
@@ -534,13 +610,8 @@ def replace_inner(x):


def serialize_bytelist(x, **kwargs):
header, frames = serialize(x, **kwargs)
if "writeable" not in header:
header["writeable"] = tuple(map(is_writeable, frames))
if "lengths" not in header:
header["lengths"] = tuple(map(nbytes, frames))
header, frames = serialize_and_split(x, **kwargs)
if frames:
frames = sum(map(frame_split_size, frames), [])
compression, frames = zip(*map(maybe_compress, frames))
else:
compression = []
@@ -566,8 +637,7 @@ def deserialize_bytes(b):
else:
header = {}
frames = decompress(header, frames)
frames = merge_frames(header, frames)
return deserialize(header, frames)
return merge_and_deserialize(header, frames)

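serialize_bytelist() and deserialize_bytes() now ride on the same pair, so a byte-level round trip stays a one-liner. A hedged sketch using the serialize_bytes() helper from the same module:

import numpy as np
from distributed.protocol.serialize import deserialize_bytes, serialize_bytes

x = np.arange(100)
y = deserialize_bytes(serialize_bytes(x))  # split, compress, pack; then unpack, merge
assert (x == y).all()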

################################
6 changes: 0 additions & 6 deletions distributed/protocol/tests/test_numpy.py
@@ -278,12 +278,6 @@ def test_compression_takes_advantage_of_itemsize():
assert sum(map(nbytes, aa)) < sum(map(nbytes, bb))


def test_large_numpy_array():
x = np.ones((100000000,), dtype="u4")
header, frames = serialize(x)
assert sum(header["lengths"]) == sum(map(nbytes, frames))


@pytest.mark.parametrize(
"x",
[
