Skip to content

Commit

Permalink
Backport from dask#5904
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Mar 14, 2022
1 parent 2fffe74 commit 8fcc515
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 30 deletions.
4 changes: 2 additions & 2 deletions distributed/spill.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from distributed.sizeof import safe_sizeof

logger = logging.getLogger(__name__)
has_zict_210 = parse_version(zict.__version__) > parse_version("2.0.0")
has_zict_210 = parse_version(zict.__version__) >= parse_version("2.1.0")


class SpilledSize(NamedTuple):
Expand Down Expand Up @@ -63,7 +63,7 @@ def __init__(
):

if max_spill is not False and not has_zict_210:
raise ValueError("zict > 2.0.0 required to set max_weight")
raise ValueError("zict >= 2.1.0 required to set max-spill")

super().__init__(
fast={},
Expand Down
16 changes: 6 additions & 10 deletions distributed/tests/test_spill.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,18 @@

import pytest

zict = pytest.importorskip("zict")
from packaging.version import parse as parse_version

from dask.sizeof import sizeof

from distributed.compatibility import WINDOWS
from distributed.protocol import serialize_bytelist
from distributed.spill import SpillBuffer
from distributed.spill import SpillBuffer, has_zict_210
from distributed.utils_test import captured_logger

requires_zict_210 = pytest.mark.skipif(
not has_zict_210,
reason="requires zict version >= 2.1.0",
)


def psize(*objs) -> tuple[int, int]:
return (
Expand Down Expand Up @@ -105,12 +107,6 @@ def test_spillbuffer(tmpdir):
assert buf.slow.total_weight == psize(d, e)


requires_zict_210 = pytest.mark.skipif(
parse_version(zict.__version__) <= parse_version("2.0.0"),
reason="requires zict version > 2.0.0",
)


@requires_zict_210
def test_spillbuffer_maxlim(tmpdir):
buf = SpillBuffer(str(tmpdir), target=200, max_spill=600, min_log_interval=0)
Expand Down
23 changes: 5 additions & 18 deletions distributed/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import psutil
import pytest
from packaging.version import parse as parse_version
from tlz import first, pluck, sliding_window

import dask
Expand All @@ -43,6 +42,7 @@
from distributed.metrics import time
from distributed.protocol import pickle
from distributed.scheduler import Scheduler
from distributed.spill import has_zict_210
from distributed.utils import TimeoutError
from distributed.utils_test import (
TaskStateMetadataPlugin,
Expand All @@ -63,15 +63,9 @@

pytestmark = pytest.mark.ci1

try:
import zict
except ImportError:
zict = None # type: ignore

requires_zict = pytest.mark.skipif(not zict, reason="requires zict")
requires_zict_210 = pytest.mark.skipif(
not zict or parse_version(zict.__version__) <= parse_version("2.0.0"),
reason="requires zict version > 2.0.0",
not has_zict_210,
reason="requires zict version >= 2.1.0",
)


Expand Down Expand Up @@ -924,7 +918,6 @@ async def assert_basic_futures(c: Client) -> None:
assert results == list(map(inc, range(10)))


@requires_zict
@gen_cluster(client=True)
async def test_fail_write_to_disk_target_1(c, s, a, b):
"""Test failure to spill triggered by key which is individually larger
Expand All @@ -942,7 +935,6 @@ async def test_fail_write_to_disk_target_1(c, s, a, b):
await assert_basic_futures(c)


@requires_zict
@gen_cluster(
client=True,
nthreads=[("", 1)],
Expand All @@ -965,10 +957,8 @@ async def test_fail_write_to_disk_target_2(c, s, a):

y = c.submit(lambda: "y" * 256, key="y")
await wait(y)
if parse_version(zict.__version__) <= parse_version("2.0.0"):
assert set(a.data.memory) == {"y"}
else:
assert set(a.data.memory) == {"x", "y"}

assert set(a.data.memory) == {"x", "y"} if has_zict_210 else {"y"}
assert not a.data.disk

await assert_basic_futures(c)
Expand Down Expand Up @@ -1187,7 +1177,6 @@ async def test_statistical_profiling_2(c, s, a, b):
break


@requires_zict
@gen_cluster(
client=True,
nthreads=[("", 1)],
Expand Down Expand Up @@ -1277,7 +1266,6 @@ async def test_spill_constrained(c, s, w):
assert set(w.data.disk) == {x.key}


@requires_zict
@gen_cluster(
nthreads=[("", 1)],
client=True,
Expand All @@ -1301,7 +1289,6 @@ async def test_spill_spill_threshold(c, s, a):
assert await x == 1


@requires_zict
@pytest.mark.parametrize(
"memory_target_fraction,managed,expect_spilled",
[
Expand Down

0 comments on commit 8fcc515

Please sign in to comment.