From 0a54d95cb26f70427b80b180b5ba7a67b2553a26 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Fri, 4 Jun 2021 18:33:04 +0200 Subject: [PATCH] Multiple worker executors (#4869) --- distributed/scheduler.py | 2 + distributed/tests/test_worker.py | 24 +++++++++ distributed/worker.py | 88 +++++++++++++++++++++++--------- 3 files changed, 89 insertions(+), 25 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index de3843c17f4..863e05bc1c1 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -7366,6 +7366,8 @@ def _task_to_msg(state: SchedulerState, ts: TaskState, duration: double = -1) -> else: msg["task"] = task + if ts._annotations: + msg["annotations"] = ts._annotations return msg diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index f50c73e990a..05672364ea8 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -1997,3 +1997,27 @@ def get_worker_client_id(): with pytest.raises(ValueError): default_client() + + +@pytest.mark.asyncio +async def test_multiple_executors(cleanup): + def get_thread_name(): + return threading.current_thread().name + + async with Scheduler() as s: + async with Worker( + s.address, + nthreads=2, + executor={ + "GPU": ThreadPoolExecutor(1, thread_name_prefix="Dask-GPU-Threads") + }, + ) as w: + async with Client(s.address, asynchronous=True) as c: + futures = [] + with dask.annotate(executor="default"): + futures.append(c.submit(get_thread_name, pure=False)) + with dask.annotate(executor="GPU"): + futures.append(c.submit(get_thread_name, pure=False)) + default_result, gpu_result = await c.gather(futures) + assert "Dask-Default-Threads" in default_result + assert "Dask-GPU-Threads" in gpu_result diff --git a/distributed/worker.py b/distributed/worker.py index 15e76544a8a..8e803c594a6 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1,5 +1,6 @@ import asyncio import bisect +import concurrent.futures import errno import heapq import logging @@ -16,7 +17,7 @@ from functools import partial from inspect import isawaitable from pickle import PicklingError -from typing import Iterable +from typing import Dict, Iterable from tlz import first, keymap, merge, pluck # noqa: F401 from tornado import gen @@ -28,7 +29,7 @@ from dask.system import CPU_COUNT from dask.utils import format_bytes, funcname -from . import comm, preloading, profile, system +from . import comm, preloading, profile, system, utils from .batched import BatchedSend from .comm import connect, get_address_host from .comm.addressing import address_from_user_args @@ -150,6 +151,8 @@ class TaskState: serializable (e.g. int, string, list, dict). * **nbytes**: ``int`` The size of a particular piece of data + * **annotations**: ``dict`` + Task annotations Parameters ---------- @@ -184,6 +187,7 @@ def __init__(self, key, runspec=None): self.stop_time = None self.metadata = {} self.nbytes = None + self.annotations = None def __repr__(self): return "" % (self.key, self.state) @@ -223,11 +227,9 @@ class Worker(ServerNode): * **nthreads:** ``int``: Number of nthreads used by this worker process - * **executor:** ``concurrent.futures.ThreadPoolExecutor``: - Executor used to perform computation - This can also be the string "offload" in which case this uses the same - thread pool used for offloading communications. This results in the - same thread being used for deserialization and computation. + * **executors:** ``Dict[str, concurrent.futures.Executor]``: + Executors used to perform computation. Always contains the default + executor. * **local_directory:** ``path``: Path on local machine to store temporary files * **scheduler:** ``rpc``: @@ -324,7 +326,15 @@ class Worker(ServerNode): Fraction of memory at which we start spilling to disk memory_pause_fraction: float Fraction of memory at which we stop running new tasks - executor: concurrent.futures.Executor + executor: concurrent.futures.Executor, dict[str, concurrent.futures.Executor], str + The executor(s) to use. Depending on the type, it has the following meanings: + - Executor instance: The default executor. + - Dict[str, Executor]: mapping names to Executor instances. If the + "default" key isn't in the dict, a "default" executor will be created + using ``ThreadPoolExecutor(nthreads)``. + - Str: The string "offload", which refer to the same thread pool used for + offloading communications. This results in the same thread being used + for deserialization and computation. resources: dict Resources that this worker has like ``{'GPU': 2}`` nanny: str @@ -626,14 +636,25 @@ def __init__( self.actors = {} self.loop = loop or IOLoop.current() self.reconnect = reconnect + + # Common executors always available + self.executors: Dict[str, concurrent.futures.Executor] = { + "offload": utils._offload_executor, + "actor": ThreadPoolExecutor(1, thread_name_prefix="Dask-Actor-Threads"), + } + + # Find the default executor if executor == "offload": - from distributed.utils import _offload_executor as executor - self.executor = executor or ThreadPoolExecutor( - self.nthreads, thread_name_prefix="Dask-Worker-Threads'" - ) - self.actor_executor = ThreadPoolExecutor( - 1, thread_name_prefix="Dask-Actor-Threads" - ) + self.executors["default"] = self.executors["offload"] + elif isinstance(executor, dict): + self.executors.update(executor) + elif executor is not None: + self.executors["default"] = executor + if "default" not in self.executors: + self.executors["default"] = ThreadPoolExecutor( + self.nthreads, thread_name_prefix="Dask-Default-Threads" + ) + self.batched_stream = BatchedSend(interval="2ms", loop=self.loop) self.name = name self.scheduler_delay = 0 @@ -808,6 +829,10 @@ def local_dir(self): ) return self.local_directory + @property + def executor(self): + return self.executors["default"] + async def get_metrics(self): out = dict( executing=self.executing_count, @@ -1280,13 +1305,14 @@ async def close( with suppress(TimeoutError): await self.batched_stream.close(timedelta(seconds=timeout)) - self.actor_executor._work_queue.queue.clear() - if isinstance(self.executor, ThreadPoolExecutor): - self.executor._work_queue.queue.clear() - self.executor.shutdown(wait=executor_wait, timeout=timeout) - else: - self.executor.shutdown(wait=False) - self.actor_executor.shutdown(wait=executor_wait, timeout=timeout) + for executor in self.executors.values(): + if executor is utils._offload_executor: + continue # Never shutdown the offload executor + if isinstance(executor, ThreadPoolExecutor): + executor._work_queue.queue.clear() + executor.shutdown(wait=executor_wait, timeout=timeout) + else: + executor.shutdown(wait=executor_wait) self.stop() await self.rpc.close() @@ -1499,6 +1525,7 @@ def add_task( duration=None, resource_restrictions=None, actor=False, + annotations=None, **kwargs2, ): try: @@ -1545,6 +1572,7 @@ def add_task( ts.duration = duration if resource_restrictions: ts.resource_restrictions = resource_restrictions + ts.annotations = annotations who_has = who_has or {} @@ -2560,7 +2588,7 @@ def executor_submit(self, key, function, args=(), kwargs=None, executor=None): callbacks to ensure things run smoothly. This can get tricky, so we pull it off into an separate method. """ - executor = executor or self.executor + executor = executor or self.executors["default"] job_counter[0] += 1 # logger.info("%s:%d Starts job %d, %s", self.ip, self.port, i, key) kwargs = kwargs or {} @@ -2656,7 +2684,7 @@ async def actor_execute( self.active_threads, self.active_threads_lock, ), - executor=self.actor_executor, + executor=self.executors["actor"], ) else: result = func(*args, **kwargs) @@ -2790,8 +2818,17 @@ async def execute(self, key, report=False): if self.digests is not None: self.digests["disk-load-duration"].add(stop - start) + if ts.annotations is not None and "executor" in ts.annotations: + executor = ts.annotations["executor"] + else: + executor = "default" + assert executor in self.executors + logger.debug( - "Execute key: %s worker: %s", ts.key, self.address + "Execute key: %s worker: %s, executor: %s", + ts.key, + self.address, + executor, ) # TODO: comment out? assert key == ts.key try: @@ -2808,6 +2845,7 @@ async def execute(self, key, report=False): self.active_threads_lock, self.scheduler_delay, ), + executor=self.executors[executor], ) except RuntimeError as e: executor_error = e