Skip to content

Commit

Permalink
Make Dask CUDA work with the new WorkerMemoryManager abstraction (#870)
Browse files Browse the repository at this point in the history
This PR updates dask-cuda to work with the new `WorkerMemoryManager` abstraction being introduced in  dask/distributed#5904. Once both PRs are merged, and pending the resolution of https://github.com/dask/distributed/pull/5904/files#r822084806, dask-cuda CI should be unblocked.

Authors:
  - Ashwin Srinath (https://github.com/shwina)
  - Peter Andreas Entschev (https://github.com/pentschev)

Approvers:
  - Peter Andreas Entschev (https://github.com/pentschev)

URL: #870
  • Loading branch information
shwina authored Mar 21, 2022
1 parent 8200f2d commit 381ff6d
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 39 deletions.
2 changes: 1 addition & 1 deletion dask_cuda/cuda_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
enable_proctitle_on_children,
enable_proctitle_on_current,
)
from distributed.worker import parse_memory_limit
from distributed.worker_memory import parse_memory_limit

from .device_host_file import DeviceHostFile
from .initialize import initialize
Expand Down
27 changes: 21 additions & 6 deletions dask_cuda/device_host_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ class DeviceHostFile(ZictBase):
spills to host cache once filled.
memory_limit: int
Number of bytes of host memory for host LRU cache, spills to
disk once filled. Setting this to 0 means unlimited host memory,
implies no spilling to disk.
disk once filled. Setting this to `0` or `None` means unlimited
host memory, implies no spilling to disk.
local_directory: path
Path where to store serialized objects on disk
log_spilling: bool
Expand All @@ -182,6 +182,9 @@ def __init__(
)
os.makedirs(self.disk_func_path, exist_ok=True)

if memory_limit == 0:
memory_limit = None

self.host_func = dict()
self.disk_func = Func(
functools.partial(serialize_bytelist, on_error="raise"),
Expand All @@ -197,7 +200,7 @@ def __init__(
host_buffer_kwargs = {"fast_name": "Host", "slow_name": "Disk"}
device_buffer_kwargs = {"fast_name": "Device", "slow_name": "Host"}

if memory_limit == 0:
if memory_limit is None:
self.host_buffer = self.host_func
else:
self.host_buffer = buffer_class(
Expand All @@ -220,11 +223,13 @@ def __init__(
)

self.device = self.device_buffer.fast.d
self.host = self.host_buffer if memory_limit == 0 else self.host_buffer.fast.d
self.disk = None if memory_limit == 0 else self.host_buffer.slow.d
self.host = (
self.host_buffer if memory_limit is None else self.host_buffer.fast.d
)
self.disk = None if memory_limit is None else self.host_buffer.slow.d

# For Worker compatibility only, where `fast` is host memory buffer
self.fast = self.host_buffer if memory_limit == 0 else self.host_buffer.fast
self.fast = self.host_buffer if memory_limit is None else self.host_buffer.fast

def __setitem__(self, key, value):
if key in self.device_buffer:
Expand Down Expand Up @@ -255,6 +260,16 @@ def __delitem__(self, key):
self.device_keys.discard(key)
del self.device_buffer[key]

def evict(self):
    """Evict the least recently used entry from the host buffer (CPU/system memory).

    Implements the ``distributed.spill.ManualEvictProto`` interface.

    Returns
    -------
    int
        The weight (byte size) of the evicted entry, or ``-1`` when nothing
        could be evicted.
    """
    # Mirror zict.LRU's behavior: any Exception is swallowed and reported
    # as a failed eviction (-1).
    try:
        _key, _value, nbytes = self.host_buffer.fast.evict()
    except Exception:
        return -1
    return nbytes

def set_address(self, addr):
if isinstance(self.host_buffer, LoggedBuffer):
self.host_buffer.set_address(addr)
Expand Down
2 changes: 1 addition & 1 deletion dask_cuda/local_cuda_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import dask
from dask.utils import parse_bytes
from distributed import LocalCluster, Nanny, Worker
from distributed.worker import parse_memory_limit
from distributed.worker_memory import parse_memory_limit

from .device_host_file import DeviceHostFile
from .initialize import initialize
Expand Down
4 changes: 3 additions & 1 deletion dask_cuda/tests/test_local_cuda_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ def get_visible_devices():
)

# Use full memory, checked with some buffer to ignore rounding difference
full_mem = sum(w.memory_limit for w in cluster.workers.values())
full_mem = sum(
w.memory_manager.memory_limit for w in cluster.workers.values()
)
assert full_mem >= MEMORY_LIMIT - 1024 and full_mem < MEMORY_LIMIT + 1024

for w, devices in result.items():
Expand Down
13 changes: 7 additions & 6 deletions dask_cuda/tests/test_proxify_host_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ async def test_worker_force_spill_to_disk():
"""Test Dask triggering CPU-to-Disk spilling """
cudf = pytest.importorskip("cudf")

with dask.config.set({"distributed.worker.memory.terminate": 0}):
with dask.config.set({"distributed.worker.memory.terminate": None}):
async with dask_cuda.LocalCUDACluster(
n_workers=1, device_memory_limit="1MB", jit_unspill=True, asynchronous=True
) as cluster:
Expand All @@ -418,14 +418,15 @@ async def f():
"""Trigger a memory_monitor() and reset memory_limit"""
w = get_worker()
# Set a host memory limit that triggers spilling to disk
w.memory_pause_fraction = False
w.memory_manager.memory_pause_fraction = False
memory = w.monitor.proc.memory_info().rss
w.memory_limit = memory - 10 ** 8
w.memory_target_fraction = 1
await w.memory_monitor()
w.memory_manager.memory_limit = memory - 10 ** 8
w.memory_manager.memory_target_fraction = 1
print(w.memory_manager.data)
await w.memory_manager.memory_monitor(w)
# Check that host memory is freed
assert w.monitor.proc.memory_info().rss < memory - 10 ** 7
w.memory_limit = memory * 10 # Un-limit
w.memory_manager.memory_limit = memory * 10 # Un-limit

await client.submit(f)
log = str(await client.get_worker_logs())
Expand Down
68 changes: 44 additions & 24 deletions dask_cuda/tests/test_spill.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,31 +84,41 @@ def delayed_worker_assert(total_size, device_chunk_overhead, serialized_chunk_ov
[
{
"device_memory_limit": int(200e6),
"memory_limit": int(800e6),
"host_target": 0.0,
"host_spill": 0.0,
"host_pause": 0.0,
"memory_limit": int(2000e6),
"host_target": False,
"host_spill": False,
"host_pause": False,
"spills_to_disk": False,
},
{
"device_memory_limit": int(200e6),
"memory_limit": int(200e6),
"host_target": 0.0,
"host_spill": 0.0,
"host_pause": 0.0,
"host_target": False,
"host_spill": False,
"host_pause": False,
"spills_to_disk": True,
},
{
# This test setup differs from the one above as Distributed worker
# pausing is enabled and thus triggers `DeviceHostFile.evict()`
"device_memory_limit": int(200e6),
"memory_limit": 0,
"host_target": 0.0,
"host_spill": 0.0,
"host_pause": 0.0,
"memory_limit": int(200e6),
"host_target": None,
"host_spill": None,
"host_pause": False,
"spills_to_disk": True,
},
{
"device_memory_limit": int(200e6),
"memory_limit": None,
"host_target": False,
"host_spill": False,
"host_pause": False,
"spills_to_disk": False,
},
],
)
@pytest.mark.asyncio
@gen_test(timeout=20)
async def test_cupy_cluster_device_spill(params):
cupy = pytest.importorskip("cupy")
with dask.config.set({"distributed.worker.memory.terminate": False}):
Expand Down Expand Up @@ -159,31 +169,41 @@ async def test_cupy_cluster_device_spill(params):
[
{
"device_memory_limit": int(200e6),
"memory_limit": int(800e6),
"host_target": 0.0,
"host_spill": 0.0,
"host_pause": 0.0,
"memory_limit": int(4000e6),
"host_target": False,
"host_spill": False,
"host_pause": False,
"spills_to_disk": False,
},
{
"device_memory_limit": int(200e6),
"memory_limit": int(200e6),
"host_target": 0.0,
"host_spill": 0.0,
"host_pause": 0.0,
"host_target": False,
"host_spill": False,
"host_pause": False,
"spills_to_disk": True,
},
{
# This test setup differs from the one above as Distributed worker
# pausing is enabled and thus triggers `DeviceHostFile.evict()`
"device_memory_limit": int(200e6),
"memory_limit": int(200e6),
"host_target": None,
"host_spill": None,
"host_pause": False,
"spills_to_disk": True,
},
{
"device_memory_limit": int(200e6),
"memory_limit": 0,
"host_target": 0.0,
"host_spill": 0.0,
"host_pause": 0.0,
"memory_limit": None,
"host_target": False,
"host_spill": False,
"host_pause": False,
"spills_to_disk": False,
},
],
)
@pytest.mark.asyncio
@gen_test(timeout=20)
async def test_cudf_cluster_device_spill(params):
cudf = pytest.importorskip("cudf")

Expand Down

0 comments on commit 381ff6d

Please sign in to comment.