From 98efdc40a49dab2d9dc63569e291b215263d467d Mon Sep 17 00:00:00 2001 From: crusaderky Date: Sat, 21 Oct 2023 02:09:10 +0100 Subject: [PATCH] WIP: Zero copy numpy shuffle --- distributed/protocol/__init__.py | 1 + distributed/protocol/serialize.py | 32 +++-- .../protocol/tests/test_protocol_utils.py | 15 ++- distributed/protocol/tests/test_serialize.py | 56 ++++++++- distributed/protocol/utils.py | 16 ++- distributed/shuffle/_buffer.py | 7 +- distributed/shuffle/_comms.py | 4 +- distributed/shuffle/_core.py | 10 +- distributed/shuffle/_disk.py | 29 +++-- distributed/shuffle/_memory.py | 12 +- distributed/shuffle/_rechunk.py | 117 +++++++++--------- distributed/shuffle/_shuffle.py | 2 +- 12 files changed, 202 insertions(+), 99 deletions(-) diff --git a/distributed/protocol/__init__.py b/distributed/protocol/__init__.py index ad43feb4937..9e6361e00d0 100644 --- a/distributed/protocol/__init__.py +++ b/distributed/protocol/__init__.py @@ -12,6 +12,7 @@ dask_serialize, deserialize, deserialize_bytes, + deserialize_bytestream, nested_deserialize, register_generic, register_serialization, diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index ee41264a27b..0f8150551d1 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -4,6 +4,7 @@ import importlib import traceback from array import array +from collections.abc import Iterator from enum import Enum from functools import partial from types import ModuleType @@ -680,20 +681,31 @@ def serialize_bytelist( return frames2 -def serialize_bytes(x, **kwargs): +def serialize_bytes(x: object, **kwargs: Any) -> bytes: L = serialize_bytelist(x, **kwargs) return b"".join(L) -def deserialize_bytes(b): - frames = unpack_frames(b) - header, frames = frames[0], frames[1:] - if header: - header = msgpack.loads(header, raw=False, use_list=False) - else: - header = {} - frames = decompress(header, frames) - return merge_and_deserialize(header, frames) +def deserialize_bytestream(b: bytes | bytearray | memoryview) -> Iterator[Any]: + """Deserialize the concatenated output of multiple calls to :func:`serialize_bytes`""" + while True: + frames = unpack_frames(b, remainder=True) + bin_header, frames, remainder = frames[0], frames[1:-1], frames[-1] + if bin_header: + header = msgpack.loads(bin_header, raw=False, use_list=False) + else: + header = {} + frames2 = decompress(header, frames) + yield merge_and_deserialize(header, frames2) + + if remainder.nbytes == 0: + break + b = remainder + + +def deserialize_bytes(b: bytes | bytearray | memoryview) -> Any: + """Deserialize the output of a single call to :func:`serialize_bytes`""" + return next(deserialize_bytestream(b)) ################################ diff --git a/distributed/protocol/tests/test_protocol_utils.py b/distributed/protocol/tests/test_protocol_utils.py index 3d9b1b51df8..a64362b3872 100644 --- a/distributed/protocol/tests/test_protocol_utils.py +++ b/distributed/protocol/tests/test_protocol_utils.py @@ -10,8 +10,21 @@ def test_pack_frames(): b = pack_frames(frames) assert isinstance(b, bytes) frames2 = unpack_frames(b) + assert frames2 == frames - assert frames == frames2 + +@pytest.mark.parametrize("extra", [b"456", b""]) +def test_unpack_frames_remainder(extra): + frames = [b"123", b"asdf"] + b = pack_frames(frames) + assert isinstance(b, bytes) + + frames2 = unpack_frames(b + extra) + assert frames2 == frames + + frames2 = unpack_frames(b + extra, remainder=True) + assert isinstance(frames2[-1], memoryview) + assert frames2 == frames + [extra] class TestMergeMemroyviews: diff --git a/distributed/protocol/tests/test_serialize.py b/distributed/protocol/tests/test_serialize.py index 975a1d55c85..1b855264fa7 100644 --- a/distributed/protocol/tests/test_serialize.py +++ b/distributed/protocol/tests/test_serialize.py @@ -23,6 +23,7 @@ dask_serialize, deserialize, deserialize_bytes, + deserialize_bytestream, dumps, loads, nested_deserialize, @@ -265,13 +266,11 @@ def test_empty_loads_deep(): assert isinstance(e2[0][0][0], Empty) -@pytest.mark.skipif(np is None, reason="Test needs numpy") @pytest.mark.parametrize("kwargs", [{}, {"serializers": ["pickle"]}]) def test_serialize_bytes(kwargs): for x in [ 1, "abc", - np.arange(5), b"ab" * int(40e6), int(2**26) * b"ab", (int(2**25) * b"ab", int(2**25) * b"ab"), @@ -279,7 +278,58 @@ def test_serialize_bytes(kwargs): b = serialize_bytes(x, **kwargs) assert isinstance(b, bytes) y = deserialize_bytes(b) - assert str(x) == str(y) + assert x == y + + +@pytest.mark.skipif(np is None, reason="Test needs numpy") +@pytest.mark.parametrize("kwargs", [{}, {"serializers": ["pickle"]}]) +def test_serialize_bytes_numpy(kwargs): + x = np.arange(5) + b = serialize_bytes(x, **kwargs) + assert isinstance(b, bytes) + y = deserialize_bytes(b) + assert (x == y).all() + + +@pytest.mark.skipif(np is None, reason="Test needs numpy") +def test_deserialize_bytes_zero_copy_read_only(): + x = np.arange(5) + x.setflags(write=False) + blob = serialize_bytes(x, compression=False) + x2 = deserialize_bytes(blob) + x3 = deserialize_bytes(blob) + addr2 = x2.__array_interface__["data"][0] + addr3 = x3.__array_interface__["data"][0] + assert addr2 == addr3 + + +@pytest.mark.skipif(np is None, reason="Test needs numpy") +def test_deserialize_bytes_zero_copy_writeable(): + x = np.arange(5) + blob = bytearray(serialize_bytes(x, compression=False)) + x2 = deserialize_bytes(blob) + x3 = deserialize_bytes(blob) + x2[0] = 123 + assert x3[0] == 123 + + +@pytest.mark.parametrize("kwargs", [{}, {"serializers": ["pickle"]}]) +def test_deserialize_bytestream(kwargs): + objs = [1, "abc", b"abc"] + blob = b"".join(serialize_bytes(obj, **kwargs) for obj in objs) + objs2 = list(deserialize_bytestream(blob)) + assert objs == objs2 + + +@pytest.mark.skipif(np is None, reason="Test needs numpy") +@pytest.mark.parametrize("kwargs", [{}, {"serializers": ["pickle"]}]) +def test_deserialize_bytestream_numpy(kwargs): + x = np.arange(5) + y = np.arange(3, 8) + blob = serialize_bytes(x, **kwargs) + serialize_bytes(y, **kwargs) + x2, y2 = deserialize_bytestream(blob) + assert (x2 == x).all() + assert (y2 == y).all() @pytest.mark.skipif(np is None, reason="Test needs numpy") diff --git a/distributed/protocol/utils.py b/distributed/protocol/utils.py index e7d4b0f75c0..ff3f0e1befe 100644 --- a/distributed/protocol/utils.py +++ b/distributed/protocol/utils.py @@ -61,12 +61,23 @@ def pack_frames(frames): return b"".join([pack_frames_prelude(frames), *frames]) -def unpack_frames(b): +def unpack_frames( + b: bytes | bytearray | memoryview, *, remainder: bool = False +) -> list[memoryview]: """Unpack bytes into a sequence of frames This assumes that length information is at the front of the bytestring, as performed by pack_frames + Parameters + ---------- + b: + packed frames, as returned by :func:`pack_frames` + remainder: + if True, return one extra frame at the end which is the continuation of a + stream created by concatenating multiple calls to :func:`pack_frames`. + This last frame will be empty at the end of the stream. + See Also -------- pack_frames @@ -86,6 +97,9 @@ def unpack_frames(b): frames.append(b[start:end]) start = end + if remainder: + frames.append(b[start:]) + return frames diff --git a/distributed/shuffle/_buffer.py b/distributed/shuffle/_buffer.py index b0d24ace022..43679f50dc8 100644 --- a/distributed/shuffle/_buffer.py +++ b/distributed/shuffle/_buffer.py @@ -45,6 +45,7 @@ class ShardsBuffer(Generic[ShardType]): shards: defaultdict[str, _List[ShardType]] sizes: defaultdict[str, int] + sizes_detail: defaultdict[str, _List[int]] concurrency_limit: int memory_limiter: ResourceLimiter diagnostics: dict[str, float] @@ -71,6 +72,7 @@ def __init__( self._accepts_input = True self.shards = defaultdict(_List) self.sizes = defaultdict(int) + self.sizes_detail = defaultdict(_List) self._exception = None self.concurrency_limit = concurrency_limit self._inputs_done = False @@ -149,7 +151,7 @@ def _continue() -> bool: try: shard = self.shards[part_id].pop() shards.append(shard) - s = sizeof(shard) + s = self.sizes_detail[part_id].pop() size += s self.sizes[part_id] -= s except IndexError: @@ -159,6 +161,8 @@ def _continue() -> bool: del self.shards[part_id] assert not self.sizes[part_id] del self.sizes[part_id] + assert not self.sizes_detail[part_id] + del self.sizes_detail[part_id] else: shards = self.shards.pop(part_id) size = self.sizes.pop(part_id) @@ -201,6 +205,7 @@ async def write(self, data: dict[str, ShardType]) -> None: async with self._shards_available: for worker, shard in data.items(): self.shards[worker].append(shard) + self.sizes_detail[worker].append(sizes[worker]) self.sizes[worker] += sizes[worker] self._shards_available.notify() await self.memory_limiter.wait_for_available() diff --git a/distributed/shuffle/_comms.py b/distributed/shuffle/_comms.py index 020313debe6..6dac1a3d30a 100644 --- a/distributed/shuffle/_comms.py +++ b/distributed/shuffle/_comms.py @@ -52,7 +52,7 @@ class CommShardsBuffer(ShardsBuffer): def __init__( self, - send: Callable[[str, list[tuple[Any, bytes]]], Awaitable[None]], + send: Callable[[str, list[tuple[Any, Any]]], Awaitable[None]], memory_limiter: ResourceLimiter, concurrency_limit: int = 10, ): @@ -63,7 +63,7 @@ def __init__( ) self.send = send - async def _process(self, address: str, shards: list[tuple[Any, bytes]]) -> None: + async def _process(self, address: str, shards: list[tuple[Any, Any]]) -> None: """Send one message off to a neighboring worker""" with log_errors(): # Consider boosting total_size a bit here to account for duplication diff --git a/distributed/shuffle/_core.py b/distributed/shuffle/_core.py index c09d5f04363..0e0dda3dcf6 100644 --- a/distributed/shuffle/_core.py +++ b/distributed/shuffle/_core.py @@ -175,12 +175,12 @@ def heartbeat(self) -> dict[str, Any]: } async def _write_to_comm( - self, data: dict[str, tuple[_T_partition_id, bytes]] + self, data: dict[str, tuple[_T_partition_id, Any]] ) -> None: self.raise_if_closed() await self._comm_buffer.write(data) - async def _write_to_disk(self, data: dict[NDIndex, bytes]) -> None: + async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None: self.raise_if_closed() await self._disk_buffer.write( {"_".join(str(i) for i in k): v for k, v in data.items()} @@ -228,7 +228,7 @@ def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing self.raise_if_closed() return self._disk_buffer.read("_".join(str(i) for i in id)) - async def receive(self, data: list[tuple[_T_partition_id, bytes]]) -> None: + async def receive(self, data: list[tuple[_T_partition_id, Any]]) -> None: await self._receive(data) async def _ensure_output_worker(self, i: _T_partition_id, key: str) -> None: @@ -248,7 +248,7 @@ def _get_assigned_worker(self, i: _T_partition_id) -> str: """Get the address of the worker assigned to the output partition""" @abc.abstractmethod - async def _receive(self, data: list[tuple[_T_partition_id, bytes]]) -> None: + async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None: """Receive shards belonging to output partitions of this shuffle run""" async def add_partition( @@ -286,7 +286,7 @@ def read(self, path: Path) -> tuple[Any, int]: """Read shards from disk""" @abc.abstractmethod - def deserialize(self, buffer: bytes) -> Any: + def deserialize(self, buffer: Any) -> Any: """Deserialize shards""" diff --git a/distributed/shuffle/_disk.py b/distributed/shuffle/_disk.py index 87fea2cb99f..e9f0bd8c891 100644 --- a/distributed/shuffle/_disk.py +++ b/distributed/shuffle/_disk.py @@ -4,10 +4,13 @@ import pathlib import shutil import threading -from collections.abc import Generator +from collections.abc import Callable, Generator, Iterable from contextlib import contextmanager -from typing import Any, Callable +from typing import Any +from toolz import concat + +from distributed.protocol import serialize_bytelist from distributed.shuffle._buffer import ShardsBuffer from distributed.shuffle._limiter import ResourceLimiter from distributed.utils import Deadline, log_errors @@ -135,7 +138,7 @@ def __init__( self._read = read self._directory_lock = ReadWriteLock() - async def _process(self, id: str, shards: list[bytes]) -> None: + async def _process(self, id: str, shards: list[Any]) -> None: """Write one buffer to file This function was built to offload the disk IO, but since then we've @@ -157,11 +160,21 @@ async def _process(self, id: str, shards: list[bytes]) -> None: with self._directory_lock.read(): if self._closed: raise RuntimeError("Already closed") - with open( - self.directory / str(id), mode="ab", buffering=100_000_000 - ) as f: - for shard in shards: - f.write(shard) + + frames: Iterable[bytes | bytearray | memoryview] + + if not shards or isinstance(shards[0], bytes): + # Manually serialized dataframes + frames = shards + else: + # Unserialized numpy arrays + frames = concat( + serialize_bytelist(shard, compression=False) + for shard in shards + ) + + with open(self.directory / str(id), mode="ab") as f: + f.writelines(frames) def read(self, id: str) -> Any: """Read a complete file back into memory""" diff --git a/distributed/shuffle/_memory.py b/distributed/shuffle/_memory.py index 27c00dac909..ac523b8fe4c 100644 --- a/distributed/shuffle/_memory.py +++ b/distributed/shuffle/_memory.py @@ -11,17 +11,15 @@ class MemoryShardsBuffer(ShardsBuffer): - _deserialize: Callable[[bytes], Any] - _shards: defaultdict[str, deque[bytes]] + _deserialize: Callable[[Any], Any] + _shards: defaultdict[str, deque[Any]] - def __init__(self, deserialize: Callable[[bytes], Any]) -> None: - super().__init__( - memory_limiter=ResourceLimiter(None), - ) + def __init__(self, deserialize: Callable[[Any], Any]) -> None: + super().__init__(memory_limiter=ResourceLimiter(None)) self._deserialize = deserialize self._shards = defaultdict(deque) - async def _process(self, id: str, shards: list[bytes]) -> None: + async def _process(self, id: str, shards: list[Any]) -> None: # TODO: This can be greatly simplified, there's no need for # background threads at all. with log_errors(): diff --git a/distributed/shuffle/_rechunk.py b/distributed/shuffle/_rechunk.py index 8fb8172508d..cfddb6ddeee 100644 --- a/distributed/shuffle/_rechunk.py +++ b/distributed/shuffle/_rechunk.py @@ -96,8 +96,8 @@ from __future__ import annotations +import mmap import os -import pickle from collections import defaultdict from collections.abc import Callable, Sequence from concurrent.futures import ThreadPoolExecutor @@ -111,6 +111,7 @@ from dask.highlevelgraph import HighLevelGraph, MaterializedLayer from distributed.core import PooledRPCCall +from distributed.protocol import deserialize_bytestream from distributed.shuffle._core import ( NDIndex, ShuffleId, @@ -279,6 +280,9 @@ def convert_chunk(shards: list[list[tuple[NDIndex, np.ndarray]]]) -> np.ndarray: for index, shard in indexed.items(): rec_cat_arg[tuple(index)] = shard arrs = rec_cat_arg.tolist() + + # This may block for several seconds, as it physically reads the memory-mapped + # buffers from disk return concatenate3(arrs) @@ -360,93 +364,86 @@ def __init__( self.worker_for = worker_for self.split_axes = split_axes(old, new) - async def _receive(self, data: list[tuple[NDIndex, bytes]]) -> None: + async def _receive( + self, + data: list[tuple[NDIndex, list[tuple[NDIndex, tuple[NDIndex, np.ndarray]]]]], + ) -> None: self.raise_if_closed() - filtered = [] + # Repartition shards and filter out already received ones + shards = defaultdict(list) for d in data: - id, payload = d - if id in self.received: + id1, payload = d + if id1 in self.received: continue - filtered.append(payload) - self.received.add(id) + self.received.add(id1) + for id2, shard in payload: + shards[id2].append(shard) self.total_recvd += sizeof(d) del data - if not filtered: + if not shards: return + try: - shards = await self.offload(self._repartition_shards, filtered) - del filtered await self._write_to_disk(shards) except Exception as e: self._exception = e raise - def _repartition_shards(self, data: list[bytes]) -> dict[NDIndex, bytes]: - repartitioned: defaultdict[ - NDIndex, list[tuple[NDIndex, np.ndarray]] - ] = defaultdict(list) - for buffer in data: - for id, shard in pickle.loads(buffer): - repartitioned[id].append(shard) - return {k: pickle.dumps(v) for k, v in repartitioned.items()} - async def _add_partition( self, data: np.ndarray, partition_id: NDIndex, **kwargs: Any ) -> int: - def _() -> dict[str, tuple[NDIndex, bytes]]: - """Return a mapping of worker addresses to a tuple of input partition - IDs and shard data. - - - TODO: Overhaul! - As shard data, we serialize the payload together with the sub-index of the - slice within the new chunk. To assemble the new chunk from its shards, it - needs the sub-index to know where each shard belongs within the chunk. - Adding the sub-index into the serialized payload on the sender allows us to - write the serialized payload directly to disk on the receiver. - """ - out: dict[ - str, list[tuple[NDIndex, tuple[NDIndex, np.ndarray]]] - ] = defaultdict(list) - from itertools import product - - ndsplits = product( - *(axis[i] for axis, i in zip(self.split_axes, partition_id)) + out = defaultdict(list) + ndsplits = product(*(axis[i] for axis, i in zip(self.split_axes, partition_id))) + for ndsplit in ndsplits: + chunk_index, shard_index, ndslice = zip(*ndsplit) + out[self.worker_for[chunk_index]].append( + (chunk_index, (shard_index, data[ndslice])) ) + out2 = {k: (partition_id, v) for k, v in out.items()} - for ndsplit in ndsplits: - chunk_index, shard_index, ndslice = zip(*ndsplit) - out[self.worker_for[chunk_index]].append( - (chunk_index, (shard_index, data[ndslice])) - ) - return {k: (partition_id, pickle.dumps(v)) for k, v in out.items()} - - out = await self.offload(_) - await self._write_to_comm(out) + await self._write_to_comm(out2) return self.run_id async def _get_output_partition( self, partition_id: NDIndex, key: str, **kwargs: Any ) -> np.ndarray: def _(partition_id: NDIndex) -> np.ndarray: + # Quickly read metadata from disk. + # This is a bunch of seek()'s interleaved with short reads. data = self._read_from_disk(partition_id) - return convert_chunk(data) + # Copy the memory-mapped buffers on disk onto memory. + # This is where we'll spend most time. + with self._disk_buffer.time("read"): + return convert_chunk(data) return await self.offload(_, partition_id) - def deserialize(self, buffer: bytes) -> Any: - result = pickle.loads(buffer) - return result - - def read(self, path: Path) -> tuple[Any, int]: - shards: list[list[tuple[NDIndex, np.ndarray]]] = [] - with path.open(mode="rb") as f: - size = f.seek(0, os.SEEK_END) - f.seek(0) - while f.tell() < size: - shards.append(pickle.load(f)) - return shards, size + def deserialize(self, buffer: Any) -> Any: + return buffer + + def read(self, path: Path) -> tuple[list[list[tuple[NDIndex, np.ndarray]]], int]: + """Open a memory-mapped file descriptor to disk, read all metadata, and unpickle + all arrays. This is a fast sequence of short reads interleaved with seeks. + Do not read in memory the actual data; the arrays' buffers will point to the + memory-mapped area. + + The file descriptor will be automatically closed by the kernel when all the + returned arrays are dereferenced, which will happen after the call to + concatenate3. + """ + # distributed.protocol.numpy.deserialize_numpy_ndarray makes sure that, if a + # numpy array was writeable before serialization, remains writeable afterwards. + # If it receives a read-only buffer, it performs an expensive deep-copy. + # Note that this is a dask-specific feature; vanilla pickle.loads will instead + # return an array with flags.writeable=False. + # + # See also: zict.file.File.__getitem__ + with path.open(mode="r+b") as fh: + buffer = memoryview(mmap.mmap(fh.fileno(), 0)) + + shards = list(deserialize_bytestream(buffer)) + return shards, buffer.nbytes def _get_assigned_worker(self, id: NDIndex) -> str: return self.worker_for[id] diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py index e9d76dadf05..aaa3d62084d 100644 --- a/distributed/shuffle/_shuffle.py +++ b/distributed/shuffle/_shuffle.py @@ -527,7 +527,7 @@ def _get_assigned_worker(self, id: int) -> str: def read(self, path: Path) -> tuple[pa.Table, int]: return read_from_disk(path) - def deserialize(self, buffer: bytes) -> Any: + def deserialize(self, buffer: Any) -> Any: return deserialize_table(buffer)