diff --git a/distributed/spill.py b/distributed/spill.py index 7cba8161a7d..cbf03817a4b 100644 --- a/distributed/spill.py +++ b/distributed/spill.py @@ -14,7 +14,7 @@ from distributed.sizeof import safe_sizeof logger = logging.getLogger(__name__) -has_zict_210 = parse_version(zict.__version__) > parse_version("2.0.0") +has_zict_210 = parse_version(zict.__version__) >= parse_version("2.1.0") class SpilledSize(NamedTuple): @@ -63,7 +63,7 @@ def __init__( ): if max_spill is not False and not has_zict_210: - raise ValueError("zict > 2.0.0 required to set max_weight") + raise ValueError("zict >= 2.1.0 required to set max-spill") super().__init__( fast={}, diff --git a/distributed/tests/test_spill.py b/distributed/tests/test_spill.py index c30aa6cefc6..55bbb6ad8ad 100644 --- a/distributed/tests/test_spill.py +++ b/distributed/tests/test_spill.py @@ -5,16 +5,18 @@ import pytest -zict = pytest.importorskip("zict") -from packaging.version import parse as parse_version - from dask.sizeof import sizeof from distributed.compatibility import WINDOWS from distributed.protocol import serialize_bytelist -from distributed.spill import SpillBuffer +from distributed.spill import SpillBuffer, has_zict_210 from distributed.utils_test import captured_logger +requires_zict_210 = pytest.mark.skipif( + not has_zict_210, + reason="requires zict version >= 2.1.0", +) + def psize(*objs) -> tuple[int, int]: return ( @@ -105,12 +107,6 @@ def test_spillbuffer(tmpdir): assert buf.slow.total_weight == psize(d, e) -requires_zict_210 = pytest.mark.skipif( - parse_version(zict.__version__) <= parse_version("2.0.0"), - reason="requires zict version > 2.0.0", -) - - @requires_zict_210 def test_spillbuffer_maxlim(tmpdir): buf = SpillBuffer(str(tmpdir), target=200, max_spill=600, min_log_interval=0)