Pickle WorkerState (#6623)
Pickle WorkerState (#6623)
crusaderky authored Jun 25, 2022
1 parent 24c4d44 commit 88e1fe0
Showing 4 changed files with 82 additions and 14 deletions.
16 changes: 16 additions & 0 deletions distributed/collections.py
@@ -53,6 +53,22 @@ def __init__(self, *, key: Callable[[T], Any]):
     def __repr__(self) -> str:
         return f"<{type(self).__name__}: {len(self)} items>"
 
+    def __reduce__(self) -> tuple[Callable, tuple]:
+        heap = [(k, i, v) for k, i, vref in self._heap if (v := vref()) in self._data]
+        return HeapSet._unpickle, (self.key, self._inc, heap)
+
+    @staticmethod
+    def _unpickle(
+        key: Callable[[T], Any], inc: int, heap: list[tuple[Any, int, T]]
+    ) -> HeapSet[T]:
+        self = object.__new__(HeapSet)
+        self.key = key  # type: ignore
+        self._data = {v for _, _, v in heap}
+        self._inc = inc
+        self._heap = [(k, i, weakref.ref(v)) for k, i, v in heap]
+        heapq.heapify(self._heap)
+        return self
+
     def __contains__(self, value: object) -> bool:
         return value in self._data
 
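The new __reduce__/_unpickle pair is what lets a HeapSet survive a round trip through plain pickle even though it stores weak references internally: __reduce__ materializes the entries that are still alive and still members into strong references, and _unpickle rebuilds the weakref-based heap and re-heapifies it on load. A minimal sketch of exercising this in isolation; the Item class and its fields are illustrative only and not part of this commit:

import operator
import pickle

from distributed.collections import HeapSet


class Item:
    # Module-level class so plain pickle can find it by qualified name
    def __init__(self, name, priority):
        self.name = name
        self.priority = priority


heap = HeapSet(key=operator.attrgetter("priority"))
heap.add(Item("low", 10))
heap.add(Item("high", 1))

# Round trip: __reduce__ drops dead or removed entries, _unpickle re-heapifies
heap2 = pickle.loads(pickle.dumps(heap))
assert len(heap2) == 2
assert heap2.pop().name == "high"  # smallest key pops first
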
51 changes: 41 additions & 10 deletions distributed/tests/test_collections.py
@@ -1,5 +1,9 @@
 from __future__ import annotations
 
+import operator
+import pickle
+import random
+
 import pytest
 
 from distributed.collections import LRU, HeapSet
@@ -22,19 +26,20 @@ def test_lru():
     assert list(l.keys()) == ["c", "a", "d"]
 
 
-def test_heapset():
-    class C:
-        def __init__(self, k, i):
-            self.k = k
-            self.i = i
+class C:
+    def __init__(self, k, i):
+        self.k = k
+        self.i = i
 
-        def __hash__(self):
-            return hash(self.k)
+    def __hash__(self):
+        return hash(self.k)
 
-        def __eq__(self, other):
-            return isinstance(other, C) and other.k == self.k
+    def __eq__(self, other):
+        return isinstance(other, C) and other.k == self.k
 
-    heap = HeapSet(key=lambda c: c.i)
+
+def test_heapset():
+    heap = HeapSet(key=operator.attrgetter("i"))
 
     cx = C("x", 2)
     cy = C("y", 1)
@@ -150,3 +155,29 @@ def __init__(self, i):
     heap.add(C("unsortable_key", None))
     assert len(heap) == 1
     assert set(heap) == {cx}
+
+
+def test_heapset_pickle():
+    """Test pickle roundtrip for a HeapSet.
+
+    Note
+    ----
+    To make this test work with plain pickle and not need cloudpickle, we had to avoid
+    lambdas and local classes in our test. Here we're testing that HeapSet doesn't add
+    lambdas etc. of its own.
+    """
+    heap = HeapSet(key=operator.attrgetter("i"))
+
+    # The heap contains broken weakrefs
+    for i in range(200):
+        c = C(f"y{i}", random.random())
+        heap.add(c)
+        if random.random() > 0.7:
+            heap.remove(c)
+
+    heap2 = pickle.loads(pickle.dumps(heap))
+    assert len(heap) == len(heap2)
+    # Test that the heap has been re-heapified upon unpickle
+    assert len(heap2._heap) < len(heap._heap)
+    while heap:
+        assert heap.pop() == heap2.pop()
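A note on the len(heap2._heap) < len(heap._heap) assertion: HeapSet.remove only drops the value from the internal _data set, so the live heap keeps stale (and, once the object is garbage collected, broken) weakref entries around until they are lazily skipped when popping, whereas __reduce__ filters them out, so the unpickled copy starts from a compacted heap. A small sketch of that behaviour under the same assumptions as above (illustrative Item class; the private _heap attribute is inspected purely for demonstration):

import operator

from distributed.collections import HeapSet


class Item:
    def __init__(self, name, priority):
        self.name = name
        self.priority = priority


h = HeapSet(key=operator.attrgetter("priority"))
a, b = Item("a", 1), Item("b", 2)
h.add(a)
h.add(b)
h.remove(a)                # removed from _data ...
assert len(h) == 1
assert len(h._heap) == 2   # ... but its heap entry is expected to linger until popped
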
22 changes: 22 additions & 0 deletions distributed/tests/test_worker_state_machine.py
@@ -2,6 +2,7 @@

 import asyncio
 import gc
+import pickle
 from collections.abc import Iterator
 
 import pytest
@@ -163,6 +164,27 @@ def test_WorkerState__to_dict(ws):
     assert actual == expect
 
+
+def test_WorkerState_pickle(ws):
+    """Test pickle round-trip.
+
+    Big caveat
+    ----------
+    WorkerState, on its own, can be serialized with pickle; it doesn't need cloudpickle.
+    A WorkerState extracted from a Worker might, as data contents may only be
+    serializable with cloudpickle. Some objects created externally and not designed
+    for network transfer - namely, the SpillBuffer - may not be serializable at all.
+    """
+    ws.handle_stimulus(
+        AcquireReplicasEvent(
+            who_has={"x": ["127.0.0.1:1235"]}, nbytes={"x": 123}, stimulus_id="s1"
+        )
+    )
+    ws.handle_stimulus(UpdateDataEvent(data={"y": 123}, report=False, stimulus_id="s"))
+    ws2 = pickle.loads(pickle.dumps(ws))
+    assert ws2.tasks.keys() == {"x", "y"}
+    assert ws2.data == {"y": 123}
+
 
 def traverse_subclasses(cls: type) -> Iterator[type]:
     yield cls
     for subcls in cls.__subclasses__():
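The docstring's caveat is worth spelling out: the bare ws fixture round-trips with plain pickle, but as soon as ws.data holds something that only cloudpickle can serialize (a lambda, a locally defined class, ...), the whole WorkerState inherits that requirement. A hedged sketch of what that looks like, reusing the ws fixture and UpdateDataEvent from the test above; the lambda payload and the helper function name are illustrative only:

import pickle

import cloudpickle

from distributed.worker_state_machine import UpdateDataEvent


def roundtrip_with_lambda_payload(ws):
    # A lambda value in ws.data defeats plain pickle ...
    ws.handle_stimulus(
        UpdateDataEvent(data={"y": lambda: 123}, report=False, stimulus_id="s")
    )
    try:
        pickle.dumps(ws)
    except Exception:
        pass  # plain pickle cannot serialize the lambda held in ws.data
    # ... but cloudpickle should still manage the full WorkerState
    ws2 = cloudpickle.loads(cloudpickle.dumps(ws))
    assert ws2.tasks.keys() == ws.tasks.keys()
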
7 changes: 3 additions & 4 deletions distributed/worker_state_machine.py
@@ -2,7 +2,6 @@

 import abc
 import asyncio
-import functools
 import heapq
 import logging
 import operator
@@ -21,7 +20,7 @@
 )
 from copy import copy
 from dataclasses import dataclass, field
-from functools import lru_cache
+from functools import lru_cache, partial, singledispatchmethod
 from itertools import chain
 from typing import TYPE_CHECKING, Any, ClassVar, Literal, NamedTuple, TypedDict, cast

@@ -1086,7 +1085,7 @@ def __init__(
         self.has_what = defaultdict(set)
         self.data_needed = HeapSet(key=operator.attrgetter("priority"))
         self.data_needed_per_worker = defaultdict(
-            lambda: HeapSet(key=operator.attrgetter("priority"))
+            partial(HeapSet[TaskState], key=operator.attrgetter("priority"))
         )
         self.in_flight_workers = {}
         self.busy_workers = set()
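The lambda-to-partial swap above is the piece that makes WorkerState picklable at all: a defaultdict pickles its default_factory, plain pickle refuses lambdas, and functools.partial serializes by reference to an importable callable plus its arguments. A standalone sketch of the difference (the subscripted HeapSet[TaskState] from the commit is dropped here for brevity; operator.attrgetter instances have been picklable since Python 3.5):

import operator
import pickle
from collections import defaultdict
from functools import partial

from distributed.collections import HeapSet

# The old factory: a lambda, which plain pickle rejects
try:
    pickle.dumps(defaultdict(lambda: HeapSet(key=operator.attrgetter("priority"))))
except Exception as exc:
    print(f"lambda factory: {type(exc).__name__}")

# The new factory: partial pickles fine, since HeapSet pickles by reference
# and the attrgetter argument pickles by value
factory = partial(HeapSet, key=operator.attrgetter("priority"))
d = pickle.loads(pickle.dumps(defaultdict(factory)))
assert isinstance(d["missing"], HeapSet)
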
@@ -2324,7 +2323,7 @@ def _transitions(self, recommendations: Recs, *, stimulus_id: str) -> Instructio
     ##########
     # Events #
     ##########
 
-    @functools.singledispatchmethod
+    @singledispatchmethod
     def _handle_event(self, ev: StateMachineEvent) -> RecsInstrs:
         raise TypeError(ev)  # pragma: nocover

