WIP: Zero copy numpy shuffle
crusaderky committed Oct 21, 2023
1 parent b4eee3f commit 98efdc4
Showing 12 changed files with 202 additions and 99 deletions.

distributed/protocol/__init__.py (1 addition, 0 deletions)

@@ -12,6 +12,7 @@
     dask_serialize,
    deserialize,
     deserialize_bytes,
+    deserialize_bytestream,
     nested_deserialize,
     register_generic,
     register_serialization,

distributed/protocol/serialize.py (22 additions, 10 deletions)

@@ -4,6 +4,7 @@
 import importlib
 import traceback
 from array import array
+from collections.abc import Iterator
 from enum import Enum
 from functools import partial
 from types import ModuleType
@@ -680,20 +681,31 @@ def serialize_bytelist(
     return frames2


-def serialize_bytes(x, **kwargs):
+def serialize_bytes(x: object, **kwargs: Any) -> bytes:
     L = serialize_bytelist(x, **kwargs)
     return b"".join(L)


-def deserialize_bytes(b):
-    frames = unpack_frames(b)
-    header, frames = frames[0], frames[1:]
-    if header:
-        header = msgpack.loads(header, raw=False, use_list=False)
-    else:
-        header = {}
-    frames = decompress(header, frames)
-    return merge_and_deserialize(header, frames)
+def deserialize_bytestream(b: bytes | bytearray | memoryview) -> Iterator[Any]:
+    """Deserialize the concatenated output of multiple calls to :func:`serialize_bytes`"""
+    while True:
+        frames = unpack_frames(b, remainder=True)
+        bin_header, frames, remainder = frames[0], frames[1:-1], frames[-1]
+        if bin_header:
+            header = msgpack.loads(bin_header, raw=False, use_list=False)
+        else:
+            header = {}
+        frames2 = decompress(header, frames)
+        yield merge_and_deserialize(header, frames2)
+
+        if remainder.nbytes == 0:
+            break
+        b = remainder
+
+
+def deserialize_bytes(b: bytes | bytearray | memoryview) -> Any:
+    """Deserialize the output of a single call to :func:`serialize_bytes`"""
+    return next(deserialize_bytestream(b))


 ################################
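
Note: the new deserialize_bytestream pairs with serialize_bytes as a streaming reader. A minimal usage sketch (the sample objects and variable names are illustrative): each serialize_bytes blob is self-delimiting, so blobs can be concatenated and replayed one object at a time.

    from distributed.protocol import serialize_bytes, deserialize_bytestream

    # Two independently serialized objects, concatenated into one byte stream
    blob = serialize_bytes({"x": 1}) + serialize_bytes([2, 3])

    # Yields each object in order; the underlying buffer is sliced,
    # not copied, where the frames allow it
    for obj in deserialize_bytestream(blob):
        print(obj)  # {'x': 1}, then [2, 3]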

distributed/protocol/tests/test_protocol_utils.py (14 additions, 1 deletion)

@@ -10,8 +10,21 @@ def test_pack_frames():
     b = pack_frames(frames)
     assert isinstance(b, bytes)
     frames2 = unpack_frames(b)
-    assert frames2 == frames
+
+    assert frames == frames2


+@pytest.mark.parametrize("extra", [b"456", b""])
+def test_unpack_frames_remainder(extra):
+    frames = [b"123", b"asdf"]
+    b = pack_frames(frames)
+    assert isinstance(b, bytes)
+
+    frames2 = unpack_frames(b + extra)
+    assert frames2 == frames
+
+    frames2 = unpack_frames(b + extra, remainder=True)
+    assert isinstance(frames2[-1], memoryview)
+    assert frames2 == frames + [extra]


 class TestMergeMemroyviews:

distributed/protocol/tests/test_serialize.py (53 additions, 3 deletions)

@@ -23,6 +23,7 @@
     dask_serialize,
     deserialize,
     deserialize_bytes,
+    deserialize_bytestream,
     dumps,
     loads,
     nested_deserialize,
@@ -265,21 +266,70 @@ def test_empty_loads_deep():
     assert isinstance(e2[0][0][0], Empty)


 @pytest.mark.skipif(np is None, reason="Test needs numpy")
 @pytest.mark.parametrize("kwargs", [{}, {"serializers": ["pickle"]}])
 def test_serialize_bytes(kwargs):
     for x in [
         1,
         "abc",
-        np.arange(5),
-        b"ab" * int(40e6),
+        int(2**26) * b"ab",
+        (int(2**25) * b"ab", int(2**25) * b"ab"),
     ]:
         b = serialize_bytes(x, **kwargs)
         assert isinstance(b, bytes)
         y = deserialize_bytes(b)
-        assert str(x) == str(y)
+        assert x == y
+
+
+@pytest.mark.skipif(np is None, reason="Test needs numpy")
+@pytest.mark.parametrize("kwargs", [{}, {"serializers": ["pickle"]}])
+def test_serialize_bytes_numpy(kwargs):
+    x = np.arange(5)
+    b = serialize_bytes(x, **kwargs)
+    assert isinstance(b, bytes)
+    y = deserialize_bytes(b)
+    assert (x == y).all()
+
+
+@pytest.mark.skipif(np is None, reason="Test needs numpy")
+def test_deserialize_bytes_zero_copy_read_only():
+    x = np.arange(5)
+    x.setflags(write=False)
+    blob = serialize_bytes(x, compression=False)
+    x2 = deserialize_bytes(blob)
+    x3 = deserialize_bytes(blob)
+    addr2 = x2.__array_interface__["data"][0]
+    addr3 = x3.__array_interface__["data"][0]
+    assert addr2 == addr3
+
+
+@pytest.mark.skipif(np is None, reason="Test needs numpy")
+def test_deserialize_bytes_zero_copy_writeable():
+    x = np.arange(5)
+    blob = bytearray(serialize_bytes(x, compression=False))
+    x2 = deserialize_bytes(blob)
+    x3 = deserialize_bytes(blob)
+    x2[0] = 123
+    assert x3[0] == 123
+
+
+@pytest.mark.parametrize("kwargs", [{}, {"serializers": ["pickle"]}])
+def test_deserialize_bytestream(kwargs):
+    objs = [1, "abc", b"abc"]
+    blob = b"".join(serialize_bytes(obj, **kwargs) for obj in objs)
+    objs2 = list(deserialize_bytestream(blob))
+    assert objs == objs2
+
+
+@pytest.mark.skipif(np is None, reason="Test needs numpy")
+@pytest.mark.parametrize("kwargs", [{}, {"serializers": ["pickle"]}])
+def test_deserialize_bytestream_numpy(kwargs):
+    x = np.arange(5)
+    y = np.arange(3, 8)
+    blob = serialize_bytes(x, **kwargs) + serialize_bytes(y, **kwargs)
+    x2, y2 = deserialize_bytestream(blob)
+    assert (x2 == x).all()
+    assert (y2 == y).all()


 @pytest.mark.skipif(np is None, reason="Test needs numpy")
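
Note: the two zero-copy tests above check aliasing by comparing data pointers and by mutating one view. The same check can be run by hand; a sketch assuming numpy is installed, with compression disabled so the arrays can alias the input buffer:

    import numpy as np
    from distributed.protocol import serialize_bytes, deserialize_bytes

    x = np.arange(5)
    blob = bytearray(serialize_bytes(x, compression=False))  # writeable buffer
    y1 = deserialize_bytes(blob)
    y2 = deserialize_bytes(blob)
    y1[0] = 123
    assert y2[0] == 123  # both arrays are views over the same bytes in blob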

distributed/protocol/utils.py (15 additions, 1 deletion)

@@ -61,12 +61,23 @@ def pack_frames(frames):
     return b"".join([pack_frames_prelude(frames), *frames])


-def unpack_frames(b):
+def unpack_frames(
+    b: bytes | bytearray | memoryview, *, remainder: bool = False
+) -> list[memoryview]:
     """Unpack bytes into a sequence of frames

     This assumes that length information is at the front of the bytestring,
     as performed by pack_frames

+    Parameters
+    ----------
+    b:
+        packed frames, as returned by :func:`pack_frames`
+    remainder:
+        if True, return one extra frame at the end which is the continuation of a
+        stream created by concatenating multiple calls to :func:`pack_frames`.
+        This last frame will be empty at the end of the stream.
+
     See Also
     --------
     pack_frames
@@ -86,6 +97,9 @@ def unpack_frames(b):
         frames.append(b[start:end])
         start = end

+    if remainder:
+        frames.append(b[start:])
+
     return frames
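
Note: a sketch of the remainder semantics on a stream of two concatenated pack_frames messages (frame contents are illustrative):

    from distributed.protocol.utils import pack_frames, unpack_frames

    stream = pack_frames([b"ab", b"cdef"]) + pack_frames([b"xyz"])

    first = unpack_frames(stream, remainder=True)
    # All frames of the first message, then the unconsumed tail of the stream
    assert [bytes(f) for f in first[:-1]] == [b"ab", b"cdef"]

    second = unpack_frames(first[-1], remainder=True)
    assert bytes(second[0]) == b"xyz"
    assert second[-1].nbytes == 0  # empty remainder: end of stream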

distributed/shuffle/_buffer.py (6 additions, 1 deletion)

@@ -45,6 +45,7 @@ class ShardsBuffer(Generic[ShardType]):

     shards: defaultdict[str, _List[ShardType]]
     sizes: defaultdict[str, int]
+    sizes_detail: defaultdict[str, _List[int]]
     concurrency_limit: int
     memory_limiter: ResourceLimiter
     diagnostics: dict[str, float]
@@ -71,6 +72,7 @@ def __init__(
         self._accepts_input = True
         self.shards = defaultdict(_List)
         self.sizes = defaultdict(int)
+        self.sizes_detail = defaultdict(_List)
         self._exception = None
         self.concurrency_limit = concurrency_limit
         self._inputs_done = False
@@ -149,7 +151,7 @@ def _continue() -> bool:
                try:
                    shard = self.shards[part_id].pop()
                    shards.append(shard)
-                    s = sizeof(shard)
+                    s = self.sizes_detail[part_id].pop()
                    size += s
                    self.sizes[part_id] -= s
                except IndexError:
@@ -159,6 +161,8 @@ def _continue() -> bool:
                        del self.shards[part_id]
                        assert not self.sizes[part_id]
                        del self.sizes[part_id]
+                        assert not self.sizes_detail[part_id]
+                        del self.sizes_detail[part_id]
        else:
            shards = self.shards.pop(part_id)
            size = self.sizes.pop(part_id)
@@ -201,6 +205,7 @@ async def write(self, data: dict[str, ShardType]) -> None:
        async with self._shards_available:
            for worker, shard in data.items():
                self.shards[worker].append(shard)
+                self.sizes_detail[worker].append(sizes[worker])
                self.sizes[worker] += sizes[worker]
            self._shards_available.notify()
        await self.memory_limiter.wait_for_available()
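
Note: sizes_detail appears to let the drain path subtract exactly what the write path added. Previously the drain re-measured each shard with sizeof(shard), which can drift from the size recorded at write time once shards are arbitrary objects rather than plain bytes. A standalone sketch of the intended invariant, assumed from the diff and not the class itself:

    from collections import defaultdict

    sizes: defaultdict[str, int] = defaultdict(int)
    sizes_detail: defaultdict[str, list[int]] = defaultdict(list)

    def record(part_id: str, nbytes: int) -> None:
        # write path: remember each shard's measured size
        sizes_detail[part_id].append(nbytes)
        sizes[part_id] += nbytes

    def drain_one(part_id: str) -> int:
        # drain path: pop the recorded size instead of re-measuring the shard,
        # so sizes[part_id] stays equal to sum(sizes_detail[part_id])
        nbytes = sizes_detail[part_id].pop()
        sizes[part_id] -= nbytes
        return nbytes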

distributed/shuffle/_comms.py (2 additions, 2 deletions)

@@ -52,7 +52,7 @@ class CommShardsBuffer(ShardsBuffer):

     def __init__(
         self,
-        send: Callable[[str, list[tuple[Any, bytes]]], Awaitable[None]],
+        send: Callable[[str, list[tuple[Any, Any]]], Awaitable[None]],
         memory_limiter: ResourceLimiter,
         concurrency_limit: int = 10,
     ):
@@ -63,7 +63,7 @@ def __init__(
         )
         self.send = send

-    async def _process(self, address: str, shards: list[tuple[Any, bytes]]) -> None:
+    async def _process(self, address: str, shards: list[tuple[Any, Any]]) -> None:
         """Send one message off to a neighboring worker"""
         with log_errors():
             # Consider boosting total_size a bit here to account for duplication

distributed/shuffle/_core.py (5 additions, 5 deletions)

@@ -175,12 +175,12 @@ def heartbeat(self) -> dict[str, Any]:
         }

     async def _write_to_comm(
-        self, data: dict[str, tuple[_T_partition_id, bytes]]
+        self, data: dict[str, tuple[_T_partition_id, Any]]
     ) -> None:
         self.raise_if_closed()
         await self._comm_buffer.write(data)

-    async def _write_to_disk(self, data: dict[NDIndex, bytes]) -> None:
+    async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None:
         self.raise_if_closed()
         await self._disk_buffer.write(
             {"_".join(str(i) for i in k): v for k, v in data.items()}
@@ -228,7 +228,7 @@ def _read_from_disk(self, id: NDIndex) -> list[Any]:  # TODO: Typing
         self.raise_if_closed()
         return self._disk_buffer.read("_".join(str(i) for i in id))

-    async def receive(self, data: list[tuple[_T_partition_id, bytes]]) -> None:
+    async def receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
         await self._receive(data)

     async def _ensure_output_worker(self, i: _T_partition_id, key: str) -> None:
@@ -248,7 +248,7 @@ def _get_assigned_worker(self, i: _T_partition_id) -> str:
         """Get the address of the worker assigned to the output partition"""

     @abc.abstractmethod
-    async def _receive(self, data: list[tuple[_T_partition_id, bytes]]) -> None:
+    async def _receive(self, data: list[tuple[_T_partition_id, Any]]) -> None:
         """Receive shards belonging to output partitions of this shuffle run"""

     async def add_partition(
@@ -286,7 +286,7 @@ def read(self, path: Path) -> tuple[Any, int]:
         """Read shards from disk"""

     @abc.abstractmethod
-    def deserialize(self, buffer: bytes) -> Any:
+    def deserialize(self, buffer: Any) -> Any:
         """Deserialize shards"""

distributed/shuffle/_disk.py (21 additions, 8 deletions)

@@ -4,10 +4,13 @@
 import pathlib
 import shutil
 import threading
-from collections.abc import Generator
+from collections.abc import Callable, Generator, Iterable
 from contextlib import contextmanager
-from typing import Any, Callable
+from typing import Any
+
+from toolz import concat

+from distributed.protocol import serialize_bytelist
 from distributed.shuffle._buffer import ShardsBuffer
 from distributed.shuffle._limiter import ResourceLimiter
 from distributed.utils import Deadline, log_errors
@@ -135,7 +138,7 @@ def __init__(
         self._read = read
         self._directory_lock = ReadWriteLock()

-    async def _process(self, id: str, shards: list[bytes]) -> None:
+    async def _process(self, id: str, shards: list[Any]) -> None:
         """Write one buffer to file

         This function was built to offload the disk IO, but since then we've
@@ -157,11 +160,21 @@ async def _process(self, id: str, shards: list[bytes]) -> None:
             with self._directory_lock.read():
                 if self._closed:
                     raise RuntimeError("Already closed")
-                with open(
-                    self.directory / str(id), mode="ab", buffering=100_000_000
-                ) as f:
-                    for shard in shards:
-                        f.write(shard)
+
+                frames: Iterable[bytes | bytearray | memoryview]
+
+                if not shards or isinstance(shards[0], bytes):
+                    # Manually serialized dataframes
+                    frames = shards
+                else:
+                    # Unserialized numpy arrays
+                    frames = concat(
+                        serialize_bytelist(shard, compression=False)
+                        for shard in shards
+                    )
+
+                with open(self.directory / str(id), mode="ab") as f:
+                    f.writelines(frames)

     def read(self, id: str) -> Any:
         """Read a complete file back into memory"""
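
Note: with the write path above, a spill file becomes a concatenation of self-delimiting serialized messages whenever the shards were not pre-serialized bytes. A hedged sketch of reading such a file back; the path and the direct use of deserialize_bytestream are illustrative, not necessarily what the read callback wired into this class does:

    from distributed.protocol import deserialize_bytestream

    # Illustrative path; real files live under self.directory / str(id)
    with open("/tmp/shuffle-spill/42_7", "rb") as f:
        blob = f.read()

    # Replay every serialized shard in the file, in write order
    shards = list(deserialize_bytestream(blob))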

distributed/shuffle/_memory.py (5 additions, 7 deletions)

@@ -11,17 +11,15 @@


 class MemoryShardsBuffer(ShardsBuffer):
-    _deserialize: Callable[[bytes], Any]
-    _shards: defaultdict[str, deque[bytes]]
+    _deserialize: Callable[[Any], Any]
+    _shards: defaultdict[str, deque[Any]]

-    def __init__(self, deserialize: Callable[[bytes], Any]) -> None:
-        super().__init__(
-            memory_limiter=ResourceLimiter(None),
-        )
+    def __init__(self, deserialize: Callable[[Any], Any]) -> None:
+        super().__init__(memory_limiter=ResourceLimiter(None))
         self._deserialize = deserialize
         self._shards = defaultdict(deque)

-    async def _process(self, id: str, shards: list[bytes]) -> None:
+    async def _process(self, id: str, shards: list[Any]) -> None:
         # TODO: This can be greatly simplified, there's no need for
         # background threads at all.
         with log_errors():
(Diffs for the remaining 2 changed files were not loaded.)
