Make Dask CUDA work with the new WorkerMemoryManager abstraction #870

Merged 10 commits on Mar 21, 2022
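For context, the core incompatibility this PR addresses is that recent distributed releases moved worker memory handling (including parse_memory_limit) into the new WorkerMemoryManager in distributed.worker_memory. The sketch below is illustrative only: dask-cuda itself simply switches to the new import location, and the try/except fallback is an assumption for older distributed versions, not part of this PR.

    # Version-tolerant import sketch (assumption, not dask-cuda code)
    try:
        from distributed.worker_memory import parse_memory_limit  # newer distributed
    except ImportError:
        from distributed.worker import parse_memory_limit  # older distributed

    # parse_memory_limit turns a human-readable limit into a per-worker byte count.
    limit = parse_memory_limit("4 GiB", nthreads=1)
    print(limit)  # bytes this worker may use (capped by available system memory)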
2 changes: 1 addition & 1 deletion dask_cuda/cuda_worker.py
@@ -17,7 +17,7 @@
enable_proctitle_on_children,
enable_proctitle_on_current,
)
-from distributed.worker import parse_memory_limit
+from distributed.worker_memory import parse_memory_limit

from .device_host_file import DeviceHostFile
from .initialize import initialize
27 changes: 21 additions & 6 deletions dask_cuda/device_host_file.py
@@ -158,8 +158,8 @@ class DeviceHostFile(ZictBase):
spills to host cache once filled.
memory_limit: int
Number of bytes of host memory for host LRU cache, spills to
-disk once filled. Setting this to 0 means unlimited host memory,
-implies no spilling to disk.
+disk once filled. Setting this to `0` or `None` means unlimited
+host memory, implies no spilling to disk.
local_directory: path
Path where to store serialized objects on disk
log_spilling: bool
@@ -182,6 +182,9 @@ def __init__(
)
os.makedirs(self.disk_func_path, exist_ok=True)

+if memory_limit == 0:
+    memory_limit = None

self.host_func = dict()
self.disk_func = Func(
functools.partial(serialize_bytelist, on_error="raise"),
@@ -197,7 +200,7 @@ def __init__(
host_buffer_kwargs = {"fast_name": "Host", "slow_name": "Disk"}
device_buffer_kwargs = {"fast_name": "Device", "slow_name": "Host"}

-if memory_limit == 0:
+if memory_limit is None:
self.host_buffer = self.host_func
else:
self.host_buffer = buffer_class(
@@ -220,11 +223,13 @@ def __init__(
)

self.device = self.device_buffer.fast.d
-self.host = self.host_buffer if memory_limit == 0 else self.host_buffer.fast.d
-self.disk = None if memory_limit == 0 else self.host_buffer.slow.d
+self.host = (
+    self.host_buffer if memory_limit is None else self.host_buffer.fast.d
+)
+self.disk = None if memory_limit is None else self.host_buffer.slow.d

# For Worker compatibility only, where `fast` is host memory buffer
-self.fast = self.host_buffer if memory_limit == 0 else self.host_buffer.fast
+self.fast = self.host_buffer if memory_limit is None else self.host_buffer.fast

def __setitem__(self, key, value):
if key in self.device_buffer:
@@ -255,6 +260,16 @@ def __delitem__(self, key):
self.device_keys.discard(key)
del self.device_buffer[key]

+def evict(self):
+    """Evicts least recently used host buffer (aka, CPU or system memory)
+
+    Implements distributed.spill.ManualEvictProto interface"""
+    try:
+        _, _, weight = self.host_buffer.fast.evict()
+        return weight
+    except Exception:  # We catch all `Exception`s, just like zict.LRU
+        return -1

def set_address(self, addr):
if isinstance(self.host_buffer, LoggedBuffer):
self.host_buffer.set_address(addr)
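To summarize the two behavioural changes in device_host_file.py: a host memory_limit of 0 is now normalized to None (unlimited host memory, no disk tier), and the store gains an evict() method so distributed's manual-eviction hook can free host memory when the worker pauses. The toy class below is only a sketch of that contract as shown in the diff (evict() returns the weight freed, or -1 on failure); it is not the real DeviceHostFile.

    from collections import OrderedDict


    class ToyHostStore:
        """LRU-ish host store illustrating the evict() contract used above."""

        def __init__(self, memory_limit=None):
            # Mirror the normalization in DeviceHostFile: 0 means "no limit".
            self.memory_limit = None if memory_limit == 0 else memory_limit
            self._data = OrderedDict()  # insertion order approximates LRU

        def __setitem__(self, key, value):
            self._data[key] = value
            self._data.move_to_end(key)

        def evict(self):
            """Drop the least recently used entry; return its size or -1."""
            try:
                key, value = next(iter(self._data.items()))
                del self._data[key]
                return len(value)  # "weight" of the evicted entry
            except Exception:
                return -1


    store = ToyHostStore(memory_limit=0)  # treated as unlimited
    store["x"] = b"0" * 1000
    print(store.evict())  # 1000
    print(store.evict())  # -1, nothing left to evict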
2 changes: 1 addition & 1 deletion dask_cuda/local_cuda_cluster.py
@@ -5,7 +5,7 @@
import dask
from dask.utils import parse_bytes
from distributed import LocalCluster, Nanny, Worker
-from distributed.worker import parse_memory_limit
+from distributed.worker_memory import parse_memory_limit

from .device_host_file import DeviceHostFile
from .initialize import initialize
4 changes: 3 additions & 1 deletion dask_cuda/tests/test_local_cuda_cluster.py
@@ -40,7 +40,9 @@ def get_visible_devices():
)

# Use full memory, checked with some buffer to ignore rounding difference
-full_mem = sum(w.memory_limit for w in cluster.workers.values())
+full_mem = sum(
+    w.memory_manager.memory_limit for w in cluster.workers.values()
+)
assert full_mem >= MEMORY_LIMIT - 1024 and full_mem < MEMORY_LIMIT + 1024

for w, devices in result.items():
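The test above now reads the limit from the worker's memory manager rather than from the worker itself. A small hedged helper that copes with both attribute locations; the fallback branch is an assumption for older distributed versions, not something this PR adds.

    def worker_memory_limit(worker):
        # Newer distributed keeps the limit on Worker.memory_manager
        # (the WorkerMemoryManager); older releases kept it on the Worker.
        manager = getattr(worker, "memory_manager", None)
        if manager is not None:
            return manager.memory_limit
        return worker.memory_limit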
13 changes: 7 additions & 6 deletions dask_cuda/tests/test_proxify_host_file.py
@@ -404,7 +404,7 @@ async def test_worker_force_spill_to_disk():
"""Test Dask triggering CPU-to-Disk spilling """
cudf = pytest.importorskip("cudf")

with dask.config.set({"distributed.worker.memory.terminate": 0}):
with dask.config.set({"distributed.worker.memory.terminate": None}):
async with dask_cuda.LocalCUDACluster(
n_workers=1, device_memory_limit="1MB", jit_unspill=True, asynchronous=True
) as cluster:
@@ -418,14 +418,15 @@ async def f():
"""Trigger a memory_monitor() and reset memory_limit"""
w = get_worker()
# Set a host memory limit that triggers spilling to disk
-w.memory_pause_fraction = False
+w.memory_manager.memory_pause_fraction = False
memory = w.monitor.proc.memory_info().rss
-w.memory_limit = memory - 10 ** 8
-w.memory_target_fraction = 1
-await w.memory_monitor()
+w.memory_manager.memory_limit = memory - 10 ** 8
+w.memory_manager.memory_target_fraction = 1

[Review comment] This doesn't work: dask/distributed#5367

+print(w.memory_manager.data)
+await w.memory_manager.memory_monitor(w)
# Check that host memory is freed
assert w.monitor.proc.memory_info().rss < memory - 10 ** 7
-w.memory_limit = memory * 10  # Un-limit
+w.memory_manager.memory_limit = memory * 10  # Un-limit

await client.submit(f)
log = str(await client.get_worker_logs())
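The same memory thresholds that the test above pokes directly on WorkerMemoryManager can also be set up front through dask's configuration keys. A hedged sketch, with cluster construction elided; the keys are standard distributed config, but using them this way is an illustration rather than what the test does.

    import dask

    # Disable all host-memory thresholds before the workers start, instead of
    # mutating w.memory_manager.* afterwards as the test above does.
    with dask.config.set(
        {
            "distributed.worker.memory.target": False,
            "distributed.worker.memory.spill": False,
            "distributed.worker.memory.pause": False,
            "distributed.worker.memory.terminate": False,
        }
    ):
        ...  # start LocalCUDACluster / Client here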
68 changes: 44 additions & 24 deletions dask_cuda/tests/test_spill.py
@@ -84,31 +84,41 @@ def delayed_worker_assert(total_size, device_chunk_overhead, serialized_chunk_ov
[
{
"device_memory_limit": int(200e6),
"memory_limit": int(800e6),
"host_target": 0.0,
"host_spill": 0.0,
"host_pause": 0.0,
"memory_limit": int(2000e6),
"host_target": False,
"host_spill": False,
"host_pause": False,
"spills_to_disk": False,
},
{
"device_memory_limit": int(200e6),
"memory_limit": int(200e6),
"host_target": 0.0,
"host_spill": 0.0,
"host_pause": 0.0,
"host_target": False,
"host_spill": False,
"host_pause": False,
"spills_to_disk": True,
},
{
+# This test setup differs from the one above as Distributed worker
+# pausing is enabled and thus triggers `DeviceHostFile.evict()`
"device_memory_limit": int(200e6),
-"memory_limit": 0,
-"host_target": 0.0,
-"host_spill": 0.0,
-"host_pause": 0.0,
+"memory_limit": int(200e6),
+"host_target": None,
+"host_spill": None,
+"host_pause": False,
"spills_to_disk": True,
},
+{
+"device_memory_limit": int(200e6),
+"memory_limit": None,
+"host_target": False,
+"host_spill": False,
+"host_pause": False,
+"spills_to_disk": False,
+},
],
)
-@pytest.mark.asyncio
+@gen_test(timeout=20)
async def test_cupy_cluster_device_spill(params):
cupy = pytest.importorskip("cupy")
with dask.config.set({"distributed.worker.memory.terminate": False}):
@@ -159,31 +169,41 @@ async def test_cupy_cluster_device_spill(params):
[
{
"device_memory_limit": int(200e6),
"memory_limit": int(800e6),
"host_target": 0.0,
"host_spill": 0.0,
"host_pause": 0.0,
"memory_limit": int(4000e6),
"host_target": False,
"host_spill": False,
"host_pause": False,
"spills_to_disk": False,
},
{
"device_memory_limit": int(200e6),
"memory_limit": int(200e6),
"host_target": 0.0,
"host_spill": 0.0,
"host_pause": 0.0,
"host_target": False,
"host_spill": False,
"host_pause": False,
"spills_to_disk": True,
},
+{
+# This test setup differs from the one above as Distributed worker
+# pausing is enabled and thus triggers `DeviceHostFile.evict()`
+"device_memory_limit": int(200e6),
+"memory_limit": int(200e6),
+"host_target": None,
+"host_spill": None,
+"host_pause": False,
+"spills_to_disk": True,
+},
{
"device_memory_limit": int(200e6),
"memory_limit": 0,
"host_target": 0.0,
"host_spill": 0.0,
"host_pause": 0.0,
"memory_limit": None,
"host_target": False,
"host_spill": False,
"host_pause": False,
"spills_to_disk": False,
},
],
)
-@pytest.mark.asyncio
+@gen_test(timeout=20)
async def test_cudf_cluster_device_spill(params):
cudf = pytest.importorskip("cudf")

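For readers unfamiliar with how these parameter dictionaries are consumed: the sketch below shows one plausible way to turn a `params` dict into a one-worker cluster. It is written against dask-cuda's public LocalCUDACluster API but is not copied from the test body itself, and the exact kwargs and config keys the tests use may differ.

    import dask
    from dask_cuda import LocalCUDACluster


    async def spill_cluster(params):
        """Build a one-worker cluster matching a `params` dict from above."""
        with dask.config.set(
            {
                # Thresholds come straight from the parameter dict; worker
                # termination is disabled, as in the tests above.
                "distributed.worker.memory.target": params["host_target"],
                "distributed.worker.memory.spill": params["host_spill"],
                "distributed.worker.memory.pause": params["host_pause"],
                "distributed.worker.memory.terminate": False,
            }
        ):
            return await LocalCUDACluster(
                n_workers=1,
                device_memory_limit=params["device_memory_limit"],
                memory_limit=params["memory_limit"],
                asynchronous=True,
            )

    # Usage inside an async test: cluster = await spill_cluster(params)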