Skip to content

Commit

Permalink
Remove func packing for TaskSpec (#11496)
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter authored Nov 8, 2024
1 parent bad6b68 commit c912dc3
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 282 deletions.
206 changes: 15 additions & 191 deletions dask/_task_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,23 +78,13 @@
import sys
from collections import defaultdict
from collections.abc import Callable, Container, Iterable, Mapping, MutableMapping
from contextlib import contextmanager
from functools import partial
from typing import Any, TypeVar, cast, overload

from dask.sizeof import sizeof
from dask.typing import Key as KeyType
from dask.utils import funcname, is_namedtuple_instance

try:
from distributed.collections import LRU
except ImportError:

class LRU(dict): # type: ignore[no-redef]
def __init__(self, *args, maxsize=None, **kwargs):
super().__init__(*args, **kwargs)


_T = TypeVar("_T")
_T_GraphNode = TypeVar("_T_GraphNode", bound="GraphNode")
_T_Iterable = TypeVar("_T_Iterable", bound=Iterable)
Expand Down Expand Up @@ -415,10 +405,6 @@ def __hash__(self) -> int:
return hash(self.key)


_func_cache: MutableMapping = LRU(maxsize=1000)
_func_cache_reverse: MutableMapping = LRU(maxsize=1000)


class GraphNode:
key: KeyType
_dependencies: frozenset
Expand Down Expand Up @@ -447,12 +433,21 @@ def inline(self, dsk) -> GraphNode:
def __call__(self, values) -> Any:
raise NotImplementedError("Not implemented")

def __eq__(self, value: object) -> bool:
if type(value) is not type(self):
return False
from dask.tokenize import tokenize

return tokenize(self) == tokenize(value)

@property
def is_coro(self) -> bool:
return False

def __sizeof__(self) -> int:
return sys.getsizeof(type(self)) + sizeof(self.key) + sizeof(self.dependencies)
return sum(
sizeof(getattr(self, sl)) for sl in type(self).__slots__
) + sys.getsizeof(type(self))


_no_deps: frozenset = frozenset()
Expand All @@ -477,9 +472,6 @@ def __init__(self, key: KeyType, target: Alias | TaskRef | KeyType | None = None
def copy(self):
return Alias(self.key, self.target)

def __reduce__(self) -> str | tuple[Any, ...]:
return Alias, (self.key, self.target)

def __call__(self, values=()):
self._verify_values(values)
return values[self.target.key]
Expand All @@ -504,7 +496,7 @@ def __eq__(self, value: object) -> bool:
return False
if self.key != value.key:
return False
if self.key != value.key:
if self.target != value.target:
return False
return True

Expand Down Expand Up @@ -534,24 +526,14 @@ def __call__(self, values=()):
def __repr__(self):
return f"DataNode({self.key}, type={self.typ}, {self.value})"

def __reduce__(self):
return (DataNode, (self.key, self.value))

def __dask_tokenize__(self):
from dask.base import tokenize

return (type(self).__name__, tokenize(self.value))

def __reduce__(self) -> str | tuple[Any, ...]:
return DataNode, (self.key, self.value)

def __eq__(self, value: object) -> bool:
if not isinstance(value, DataNode):
return False
if self.value != value.value:
return False
return True

def __sizeof__(self) -> int:
return super().__sizeof__() + sizeof(self.value) + sizeof(self.typ)

def __iter__(self):
return iter(self.value)

Expand Down Expand Up @@ -595,10 +577,9 @@ def _get_dependencies(obj: object) -> set | frozenset:


class Task(GraphNode):
func: Callable | None
func: Callable
args: tuple
kwargs: dict
packed_func: None | bytes
_token: str | None
_is_coro: bool | None
_repr: str | None
Expand All @@ -615,7 +596,6 @@ def __init__(
):
self.key = key
self.func = func
self.packed_func = None
self.args = parse_input(args)
self.kwargs = parse_input(kwargs)
dependencies: set = set()
Expand All @@ -630,19 +610,14 @@ def __init__(
self._repr = None

def copy(self):
self.unpack()
return Task(self.key, self.func, *self.args, **self.kwargs)

def __hash__(self):
return hash(self._get_token())

def is_packed(self):
return self.packed_func is not None

def _get_token(self) -> str:
if self._token:
return self._token
self.unpack()
from dask.base import tokenize

self._token = tokenize(
Expand All @@ -658,19 +633,10 @@ def _get_token(self) -> str:
def __dask_tokenize__(self):
return self._get_token()

def __sizeof__(self) -> int:
return (
super().__sizeof__()
+ sizeof(self.func or self.packed_func)
+ sizeof(self.args)
+ sizeof(self.kwargs)
)

def __repr__(self) -> str:
# When `Task` is deserialized the constructor will not run and
# `self._repr` is thus undefined.
if not hasattr(self, "_repr") or not self._repr:
self.unpack()
head = funcname(self.func)
tail = ")"
label_size = 40
Expand Down Expand Up @@ -702,147 +668,21 @@ def __repr__(self) -> str:

def __call__(self, values=()):
self._verify_values(values)
try:
self.unpack()
except Exception as exc:
raise RuntimeError(
f"Exception occured during deserialization of function for task {self.key}"
) from exc
assert self.func is not None
new_argspec = _call_recursively(self.args, values)
if self.kwargs:
kwargs = _call_recursively(self.kwargs, values)
return self.func(*new_argspec, **kwargs)
return self.func(*new_argspec)

def pack(self):
# TODO: pack args and kwargs as well. Probably with a sizeof threshold
if self.is_packed():
return self
try:
from distributed.protocol.pickle import dumps
except ImportError:
from cloudpickle import dumps

try:
self.packed_func = _func_cache[self.func]
self.func = None
return self
except (KeyError, TypeError):
# We're not handling the below in this except to simplify traceback
# and cause
pass
try:
self.packed_func = dumps(self.func)
except Exception as exc:
raise RuntimeError(
f"Error during serialization of function of task {self.key}"
) from exc
try:
_func_cache_reverse[self.packed_func], _func_cache[self.func] = (
self.func,
self.packed_func,
)
except TypeError:
pass
self.func = None
return self

def lazy_pack(self):
# TODO: This could dispatch to a TPE and de/serialize the arguments
# lazily such that pack only waits for the result. This way we can use
# spare CPU cycles and parallelize on multiple CPUs.
# Thread safety should not be an issue if we guarantee idempotency and
# that the non-lazy sync also uses this
# This may block the GIL pretty aggressively
raise NotImplementedError("Not implemented")

def lazy_unpack(self):
raise NotImplementedError("Not implemented")

def inline(self, dsk) -> Task:
self.unpack()
new_args = _inline_recursively(self.args, dsk)
new_kwargs = _inline_recursively(self.kwargs, dsk)
assert self.func is not None
return Task(self.key, self.func, *new_args, **new_kwargs)

def unpack(self):
if not self.is_packed():
return
assert self.packed_func is not None

try:
from distributed.protocol.pickle import loads
except ImportError:
from cloudpickle import loads
try:
self.func = _func_cache_reverse[self.packed_func]
self.packed_func = None
return self
except KeyError:
# We're not handling the below in this except to simplify traceback
# and cause
pass
try:
self.func = loads(self.packed_func)
except Exception as exc:
raise RuntimeError(
f"Error during deserialization of function of task {self.key}"
) from exc
try:
_func_cache_reverse[self.packed_func], _func_cache[self.func] = (
self.func,
self.packed_func,
)
except TypeError:
pass
self.packed_func = None
return self

def __getstate__(self):
self.pack()
return {
"key": self.key,
"packed_func": self.packed_func,
"dependencies": self.dependencies,
"kwargs": self.kwargs,
"args": self.args,
"_is_coro": self._is_coro,
"_token": self._token,
}

def __setstate__(self, state):
self.key = state["key"]
self.packed_func = state["packed_func"]
self._dependencies = state["dependencies"]
self.kwargs = state["kwargs"]
self.args = state["args"]
self._is_coro = state["_is_coro"]
self._token = state["_token"]
self.func = None

def __eq__(self, value: object) -> bool:
if not isinstance(value, Task):
return False
if self.key != value.key:
return False
if self.packed_func != value.packed_func:
return False
if self.func != value.func:
return False
if self.dependencies != value.dependencies:
return False
if self.kwargs != value.kwargs:
return False
if self.args != value.args:
return False
return True

@property
def is_coro(self):
if self._is_coro is None:
self.unpack()
# Note: Can't use cached_property on objects without __dict__
try:
from distributed.utils import iscoroutinefunction
Expand Down Expand Up @@ -911,22 +751,6 @@ def __iter__(self):
return iter(())


@contextmanager
def no_function_cache():
"""Everything in this context will ignore the function cache on both
serialization and deserialization.
This is not threadsafe!
"""
global _func_cache, _func_cache_reverse
cache_before = _func_cache, _func_cache_reverse
_func_cache, _func_cache_reverse = _DevNullMapping(), _DevNullMapping()
try:
yield
finally:
_func_cache, _func_cache_reverse = cache_before


def execute_graph(
dsk: Iterable[GraphNode] | Mapping[KeyType, GraphNode],
cache: MutableMapping[KeyType, object] | None = None,
Expand Down
Loading

0 comments on commit c912dc3

Please sign in to comment.