Multiple worker executors (#4869)
madsbk authored Jun 4, 2021
1 parent 2bdec05 commit 0a54d95
Showing 3 changed files with 89 additions and 25 deletions.
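For context, a minimal usage sketch of the feature added here, modeled on the new test in distributed/tests/test_worker.py below. The script layout (the main coroutine, the get_thread_name helper, and the in-process Scheduler/Worker/Client setup) is illustrative rather than part of the commit: a worker is created with an extra named executor, and individual tasks are routed to it via dask.annotate(executor=...).

import asyncio
import threading
from concurrent.futures import ThreadPoolExecutor

import dask
from distributed import Client, Scheduler, Worker


def get_thread_name():
    return threading.current_thread().name


async def main():
    async with Scheduler() as s:
        # Named executors are passed as a dict; a "default" thread pool is
        # still created from ``nthreads`` because the dict has no "default" key.
        async with Worker(
            s.address,
            nthreads=2,
            executor={"GPU": ThreadPoolExecutor(1, thread_name_prefix="Dask-GPU-Threads")},
        ):
            async with Client(s.address, asynchronous=True) as c:
                with dask.annotate(executor="GPU"):
                    future = c.submit(get_thread_name, pure=False)
                print(await future)  # e.g. "Dask-GPU-Threads_0"


asyncio.run(main())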
2 changes: 2 additions & 0 deletions distributed/scheduler.py
@@ -7366,6 +7366,8 @@ def _task_to_msg(state: SchedulerState, ts: TaskState, duration: double = -1) ->
else:
msg["task"] = task

if ts._annotations:
msg["annotations"] = ts._annotations
return msg


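The scheduler change above is the only non-worker piece: task annotations now ride along in the compute-task message so the worker can pick an executor. A rough standalone sketch of that message shape, with a hypothetical helper name and a pared-down set of fields:

def task_to_msg_sketch(key, runspec, annotations=None):
    # Simplified stand-in for the scheduler's _task_to_msg: attach annotations
    # only when the task actually has any, mirroring the hunk above.
    msg = {"op": "compute-task", "key": key, "task": runspec}
    if annotations:
        msg["annotations"] = annotations
    return msg


msg = task_to_msg_sketch("inc-1", b"<runspec>", annotations={"executor": "GPU"})
assert msg["annotations"] == {"executor": "GPU"}
assert "annotations" not in task_to_msg_sketch("inc-2", b"<runspec>")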
24 changes: 24 additions & 0 deletions distributed/tests/test_worker.py
@@ -1997,3 +1997,27 @@ def get_worker_client_id():

with pytest.raises(ValueError):
default_client()


@pytest.mark.asyncio
async def test_multiple_executors(cleanup):
def get_thread_name():
return threading.current_thread().name

async with Scheduler() as s:
async with Worker(
s.address,
nthreads=2,
executor={
"GPU": ThreadPoolExecutor(1, thread_name_prefix="Dask-GPU-Threads")
},
) as w:
async with Client(s.address, asynchronous=True) as c:
futures = []
with dask.annotate(executor="default"):
futures.append(c.submit(get_thread_name, pure=False))
with dask.annotate(executor="GPU"):
futures.append(c.submit(get_thread_name, pure=False))
default_result, gpu_result = await c.gather(futures)
assert "Dask-Default-Threads" in default_result
assert "Dask-GPU-Threads" in gpu_result
88 changes: 63 additions & 25 deletions distributed/worker.py
@@ -1,5 +1,6 @@
import asyncio
import bisect
import concurrent.futures
import errno
import heapq
import logging
@@ -16,7 +17,7 @@
from functools import partial
from inspect import isawaitable
from pickle import PicklingError
from typing import Iterable
from typing import Dict, Iterable

from tlz import first, keymap, merge, pluck # noqa: F401
from tornado import gen
Expand All @@ -28,7 +29,7 @@
from dask.system import CPU_COUNT
from dask.utils import format_bytes, funcname

from . import comm, preloading, profile, system
from . import comm, preloading, profile, system, utils
from .batched import BatchedSend
from .comm import connect, get_address_host
from .comm.addressing import address_from_user_args
@@ -150,6 +151,8 @@ class TaskState:
serializable (e.g. int, string, list, dict).
* **nbytes**: ``int``
The size of a particular piece of data
* **annotations**: ``dict``
Task annotations
Parameters
----------
@@ -184,6 +187,7 @@ def __init__(self, key, runspec=None):
self.stop_time = None
self.metadata = {}
self.nbytes = None
self.annotations = None

def __repr__(self):
return "<Task %r %s>" % (self.key, self.state)
@@ -223,11 +227,9 @@ class Worker(ServerNode):
* **nthreads:** ``int``:
Number of nthreads used by this worker process
* **executor:** ``concurrent.futures.ThreadPoolExecutor``:
Executor used to perform computation
This can also be the string "offload" in which case this uses the same
thread pool used for offloading communications. This results in the
same thread being used for deserialization and computation.
* **executors:** ``Dict[str, concurrent.futures.Executor]``:
Executors used to perform computation. Always contains the default
executor.
* **local_directory:** ``path``:
Path on local machine to store temporary files
* **scheduler:** ``rpc``:
@@ -324,7 +326,15 @@ class Worker(ServerNode):
Fraction of memory at which we start spilling to disk
memory_pause_fraction: float
Fraction of memory at which we stop running new tasks
executor: concurrent.futures.Executor
executor: concurrent.futures.Executor, dict[str, concurrent.futures.Executor], str
The executor(s) to use. Depending on the type, it has the following meanings:
- Executor instance: The default executor.
- Dict[str, Executor]: mapping names to Executor instances. If the
"default" key isn't in the dict, a "default" executor will be created
using ``ThreadPoolExecutor(nthreads)``.
- Str: The string "offload", which refers to the same thread pool used for
offloading communications. This results in the same thread being used
for deserialization and computation.
resources: dict
Resources that this worker has like ``{'GPU': 2}``
nanny: str
@@ -626,14 +636,25 @@ def __init__(
self.actors = {}
self.loop = loop or IOLoop.current()
self.reconnect = reconnect

# Common executors always available
self.executors: Dict[str, concurrent.futures.Executor] = {
"offload": utils._offload_executor,
"actor": ThreadPoolExecutor(1, thread_name_prefix="Dask-Actor-Threads"),
}

# Find the default executor
if executor == "offload":
from distributed.utils import _offload_executor as executor
self.executor = executor or ThreadPoolExecutor(
self.nthreads, thread_name_prefix="Dask-Worker-Threads'"
)
self.actor_executor = ThreadPoolExecutor(
1, thread_name_prefix="Dask-Actor-Threads"
)
self.executors["default"] = self.executors["offload"]
elif isinstance(executor, dict):
self.executors.update(executor)
elif executor is not None:
self.executors["default"] = executor
if "default" not in self.executors:
self.executors["default"] = ThreadPoolExecutor(
self.nthreads, thread_name_prefix="Dask-Default-Threads"
)

self.batched_stream = BatchedSend(interval="2ms", loop=self.loop)
self.name = name
self.scheduler_delay = 0
@@ -808,6 +829,10 @@ def local_dir(self):
)
return self.local_directory

@property
def executor(self):
return self.executors["default"]

async def get_metrics(self):
out = dict(
executing=self.executing_count,
@@ -1280,13 +1305,14 @@ async def close(
with suppress(TimeoutError):
await self.batched_stream.close(timedelta(seconds=timeout))

self.actor_executor._work_queue.queue.clear()
if isinstance(self.executor, ThreadPoolExecutor):
self.executor._work_queue.queue.clear()
self.executor.shutdown(wait=executor_wait, timeout=timeout)
else:
self.executor.shutdown(wait=False)
self.actor_executor.shutdown(wait=executor_wait, timeout=timeout)
for executor in self.executors.values():
if executor is utils._offload_executor:
continue # Never shutdown the offload executor
if isinstance(executor, ThreadPoolExecutor):
executor._work_queue.queue.clear()
executor.shutdown(wait=executor_wait, timeout=timeout)
else:
executor.shutdown(wait=executor_wait)

self.stop()
await self.rpc.close()
@@ -1499,6 +1525,7 @@ def add_task(
duration=None,
resource_restrictions=None,
actor=False,
annotations=None,
**kwargs2,
):
try:
@@ -1545,6 +1572,7 @@ def add_task(
ts.duration = duration
if resource_restrictions:
ts.resource_restrictions = resource_restrictions
ts.annotations = annotations

who_has = who_has or {}

@@ -2560,7 +2588,7 @@ def executor_submit(self, key, function, args=(), kwargs=None, executor=None):
callbacks to ensure things run smoothly. This can get tricky, so we
pull it off into a separate method.
"""
executor = executor or self.executor
executor = executor or self.executors["default"]
job_counter[0] += 1
# logger.info("%s:%d Starts job %d, %s", self.ip, self.port, i, key)
kwargs = kwargs or {}
@@ -2656,7 +2684,7 @@ async def actor_execute(
self.active_threads,
self.active_threads_lock,
),
executor=self.actor_executor,
executor=self.executors["actor"],
)
else:
result = func(*args, **kwargs)
@@ -2790,8 +2818,17 @@ async def execute(self, key, report=False):
if self.digests is not None:
self.digests["disk-load-duration"].add(stop - start)

if ts.annotations is not None and "executor" in ts.annotations:
executor = ts.annotations["executor"]
else:
executor = "default"
assert executor in self.executors

logger.debug(
"Execute key: %s worker: %s", ts.key, self.address
"Execute key: %s worker: %s, executor: %s",
ts.key,
self.address,
executor,
) # TODO: comment out?
assert key == ts.key
try:
@@ -2808,6 +2845,7 @@ async def execute(self, key, report=False):
self.active_threads_lock,
self.scheduler_delay,
),
executor=self.executors[executor],
)
except RuntimeError as e:
executor_error = e
