Add capability to log spilling #442

Merged · 12 commits · Apr 5, 2021
124 changes: 120 additions & 4 deletions dask_cuda/device_host_file.py
@@ -1,5 +1,7 @@
import functools
import logging
import os
import time

from zict import Buffer, File, Func
from zict.common import ZictBase
@@ -21,8 +23,84 @@
from .utils import nvtx_annotate


class LoggedBuffer(Buffer):
"""Extends zict.Buffer with logging capabilities

Two arguments `fast_name` and `slow_name` are passed to constructor that
identify a user-friendly name for logging of where spilling is going from/to.
For example, their names can be "Device" and "Host" to identify that spilling
is happening from a CUDA device into system memory.
"""

    def __init__(self, *args, fast_name="Fast", slow_name="Slow", addr=None, **kwargs):
        self.addr = "Unknown Address" if addr is None else addr
        self.fast_name = fast_name
        self.slow_name = slow_name
        self.msg_template = (
            "Worker at <%s>: Spilled key %s with %s bytes from %s to %s in %s seconds"
        )

        # It is a bit hacky to forcefully capture the "distributed.worker" logger;
        # eventually it would be better to have a dedicated logger. For now this
        # is acceptable, as it allows users to read logs with
        # client.get_worker_logs(); a proper solution would require changes to
        # Distributed.
        self.logger = logging.getLogger("distributed.worker")

        super().__init__(*args, **kwargs)

        self.total_time_fast_to_slow = 0.0
        self.total_time_slow_to_fast = 0.0

    def fast_to_slow(self, key, value):
        start = time.time()
        ret = super().fast_to_slow(key, value)
        total = time.time() - start
        self.total_time_fast_to_slow += total

        self.logger.info(
            self.msg_template
            % (
                self.addr,
                key,
                weight(key, value),
                self.fast_name,
                self.slow_name,
                total,
            )
        )

        return ret

    def slow_to_fast(self, key):
        start = time.time()
        ret = super().slow_to_fast(key)
        total = time.time() - start
        self.total_time_slow_to_fast += total

        self.logger.info(
            self.msg_template
            % (self.addr, key, weight(key, ret), self.slow_name, self.fast_name, total)
        )

        return ret

    def set_address(self, addr):
        self.addr = addr

    def get_total_spilling_time(self):
        return {
            (
                "Total spilling time from %s to %s" % (self.fast_name, self.slow_name)
            ): self.total_time_fast_to_slow,
            (
                "Total spilling time from %s to %s" % (self.slow_name, self.fast_name)
            ): self.total_time_slow_to_fast,
        }


class DeviceSerialized:
"""Store device object on the host

This stores a device-side object as
1. A msgpack encodable header
2. A list of `bytes`-like objects (like NumPy arrays)
@@ -106,10 +184,18 @@ class DeviceHostFile(ZictBase):
        implies no spilling to disk.
    local_directory: path
        Path where to store serialized objects on disk
    log_spilling: bool
        If True, all spilling operations will be logged directly to
        distributed.worker with an INFO loglevel. This will eventually be
        replaced by a Dask configuration flag.
    """

    def __init__(
        self, device_memory_limit=None, memory_limit=None, local_directory=None,
        self,
        device_memory_limit=None,
        memory_limit=None,
        local_directory=None,
        log_spilling=False,
    ):
        if local_directory is None:
            local_directory = dask.config.get("temporary-directory") or os.getcwd()
@@ -126,18 +212,35 @@ def __init__(
            deserialize_bytes,
            File(self.disk_func_path),
        )

        host_buffer_kwargs = {}
        device_buffer_kwargs = {}
        buffer_class = Buffer
        if log_spilling is True:
            buffer_class = LoggedBuffer
            host_buffer_kwargs = {"fast_name": "Host", "slow_name": "Disk"}
            device_buffer_kwargs = {"fast_name": "Device", "slow_name": "Host"}

        if memory_limit == 0:
            self.host_buffer = self.host_func
        else:
            self.host_buffer = Buffer(
                self.host_func, self.disk_func, memory_limit, weight=weight
            self.host_buffer = buffer_class(
                self.host_func,
                self.disk_func,
                memory_limit,
                weight=weight,
                **host_buffer_kwargs,
            )

        self.device_keys = set()
        self.device_func = dict()
        self.device_host_func = Func(device_to_host, host_to_device, self.host_buffer)
        self.device_buffer = Buffer(
            self.device_func, self.device_host_func, device_memory_limit, weight=weight
            self.device_func,
            self.device_host_func,
            device_memory_limit,
            weight=weight,
            **device_buffer_kwargs,
        )

        self.device = self.device_buffer.fast.d
@@ -175,3 +278,16 @@ def __iter__(self):
    def __delitem__(self, key):
        self.device_keys.discard(key)
        del self.device_buffer[key]

    def set_address(self, addr):
        if isinstance(self.host_buffer, LoggedBuffer):
            self.host_buffer.set_address(addr)
        self.device_buffer.set_address(addr)

    def get_total_spilling_time(self):
        ret = {}
        if isinstance(self.device_buffer, LoggedBuffer):
            ret = {**ret, **self.device_buffer.get_total_spilling_time()}
        if isinstance(self.host_buffer, LoggedBuffer):
            ret = {**ret, **self.host_buffer.get_total_spilling_time()}
        return ret
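
A minimal sketch of what LoggedBuffer adds on top of zict.Buffer, exercised in isolation (assumes dask-cuda and zict are installed; the plain-dict stores, the 10-byte limit, and the address below are illustrative stand-ins, not part of this diff):

import logging

from dask_cuda.device_host_file import LoggedBuffer

# Make INFO records from the "distributed.worker" logger visible if no handler is set up.
logging.basicConfig(level=logging.INFO)

buf = LoggedBuffer(
    {},                            # fast store, standing in for host memory
    {},                            # slow store, standing in for disk
    n=10,                          # tiny byte limit so eviction happens immediately
    weight=lambda k, v: len(v),
    fast_name="Host",
    slow_name="Disk",
    addr="tcp://127.0.0.1:40000",  # illustrative address shown in the log messages
)

buf["a"] = b"123456"  # fits under the limit, stays in the fast store
buf["b"] = b"789012"  # total weight exceeds n, so "a" is evicted via fast_to_slow(), which logs
_ = buf["a"]          # reading "a" back goes through slow_to_fast(), which logs again

print(buf.get_total_spilling_time())

These per-buffer timings are what DeviceHostFile.get_total_spilling_time() aggregates for a worker when log_spilling=True.
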
23 changes: 22 additions & 1 deletion dask_cuda/local_cuda_cluster.py
@@ -3,7 +3,7 @@
import warnings

import dask
from dask.distributed import LocalCluster
from dask.distributed import LocalCluster, Nanny, Worker
from distributed.utils import parse_bytes
from distributed.worker import parse_memory_limit

@@ -22,6 +22,20 @@
)


class LoggedWorker(Worker):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    async def start(self):
        await super().start()
        self.data.set_address(self.address)


class LoggedNanny(Nanny):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, worker_class=LoggedWorker, **kwargs)


class LocalCUDACluster(LocalCluster):
"""A variant of ``dask.distributed.LocalCluster`` that uses one GPU per process.

@@ -110,6 +124,10 @@ class LocalCUDACluster(LocalCluster):
        If ``True``, enable just-in-time unspilling. This is experimental and doesn't
        support memory spilling to disk. Please see ``proxy_object.ProxyObject`` and
        ``proxify_host_file.ProxifyHostFile``.
    log_spilling: bool
        If True, all spilling operations will be logged directly to
        distributed.worker with an INFO loglevel. This will eventually be
        replaced by a Dask configuration flag.


    Examples
@@ -155,6 +173,7 @@ def __init__(
        rmm_managed_memory=False,
        rmm_log_directory=None,
        jit_unspill=None,
        log_spilling=False,
        **kwargs,
    ):
        # Required by RAPIDS libraries (e.g., cuDF) to ensure no context
@@ -231,6 +250,7 @@ def __init__(
"local_directory": local_directory
or dask.config.get("temporary-directory")
or os.getcwd(),
"log_spilling": log_spilling,
},
)

@@ -271,6 +291,7 @@ def __init__(
            data=data,
            local_directory=local_directory,
            protocol=protocol,
            worker_class=LoggedNanny if log_spilling is True else Nanny,
            config={
                "ucx": get_ucx_config(
                    enable_tcp_over_ucx=enable_tcp_over_ucx,
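
End to end, the new flag is surfaced through LocalCUDACluster, which forwards log_spilling to DeviceHostFile and substitutes LoggedNanny/LoggedWorker so that each worker's address is stamped into the messages. A rough usage sketch, assuming a CUDA-capable machine with dask-cuda and cupy installed (worker count, device memory limit, and array sizes are illustrative):

import cupy
import dask.array as da
from dask.distributed import Client, wait

from dask_cuda import LocalCUDACluster

if __name__ == "__main__":
    cluster = LocalCUDACluster(
        n_workers=1,
        device_memory_limit="100MB",  # small limit so device-to-host spilling kicks in
        log_spilling=True,
    )
    client = Client(cluster)

    # Build more device data than device_memory_limit allows and persist it,
    # forcing the worker to spill and therefore to emit log records.
    x = da.random.random((10_000, 10_000), chunks=(2_500, 2_500)).map_blocks(cupy.asarray)
    x = x.persist()
    wait(x)

    # Spill messages go to the "distributed.worker" logger, so they show up in
    # the ordinary worker logs retrieved through the client.
    for addr, log in client.get_worker_logs().items():
        for _, message in log:
            if "Spilled key" in message:
                print(addr, message)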