Skip to content

Commit

Permalink
make all manager methods async (#273)
Browse files Browse the repository at this point in the history
* make all ClusterManager methods async

allows easier, more consistent override

- defer initialization to `await manager` instead of immediately (avoids invoking asyncio at import time)
- instantiate default manager as part of extension loading, not at import time

* fix dask_cluster_manager key
  • Loading branch information
minrk authored Feb 24, 2025
1 parent edac755 commit 19d5dac
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 46 deletions.
6 changes: 4 additions & 2 deletions dask_labextension/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

from jupyter_server.utils import url_path_join

from . import config
from . import config # noqa
from .clusterhandler import DaskClusterHandler
from .dashboardhandler import DaskDashboardCheckHandler, DaskDashboardHandler
from .manager import DaskClusterManager


from ._version import __version__
from ._version import __version__ # noqa


def _jupyter_labextension_paths():
Expand All @@ -33,6 +34,7 @@ def load_jupyter_server_extension(nb_server_app):
cluster_id_regex = r"(?P<cluster_id>[^/]+)"
web_app = nb_server_app.web_app
base_url = web_app.settings["base_url"]
web_app.settings["dask_cluster_manager"] = DaskClusterManager()
get_cluster_path = url_path_join(base_url, "dask/clusters/" + cluster_id_regex)
list_clusters_path = url_path_join(base_url, "dask/clusters/" + "?")
get_dashboard_path = url_path_join(
Expand Down
26 changes: 18 additions & 8 deletions dask_labextension/clusterhandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,34 @@
# Distributed under the terms of the Modified BSD License.

import json
from inspect import isawaitable

from tornado import web
from jupyter_server.base.handlers import APIHandler

from .manager import manager
from .manager import DaskClusterManager


class DaskClusterHandler(APIHandler):
"""
A tornado HTTP handler for managing dask clusters.
"""

manager: DaskClusterManager

async def prepare(self):
    """
    Run the parent handler's preparation, then resolve the cluster manager.

    The parent `prepare` may be synchronous or return an awaitable, so its
    result is awaited only when needed. The object stored in the application
    settings under "dask_cluster_manager" is then awaited and the result is
    kept on `self.manager` for the HTTP verb methods to use.
    """
    parent_result = super().prepare()
    if isawaitable(parent_result):
        await parent_result
    pending_manager = self.settings["dask_cluster_manager"]
    self.manager = await pending_manager

@web.authenticated
async def delete(self, cluster_id: str) -> None:
"""
Delete a cluster by id.
"""
try: # to delete the cluster.
val = await manager.close_cluster(cluster_id)
val = await self.manager.close_cluster(cluster_id)
if val is None:
raise web.HTTPError(404, f"Dask cluster {cluster_id} not found")

Expand All @@ -37,12 +46,13 @@ async def get(self, cluster_id: str = "") -> None:
"""
Get a cluster by id. If no id is given, lists known clusters.
"""
manager = self.manager
if cluster_id == "":
cluster_list = manager.list_clusters()
cluster_list = await manager.list_clusters()
self.set_status(200)
self.finish(json.dumps(cluster_list))
else:
cluster_model = manager.get_cluster(cluster_id)
cluster_model = await manager.get_cluster(cluster_id)
if cluster_model is None:
raise web.HTTPError(404, f"Dask cluster {cluster_id} not found")

Expand All @@ -55,13 +65,13 @@ async def put(self, cluster_id: str = "") -> None:
Create a new cluster with a given id. If no id is given, a random
one is selected.
"""
if manager.get_cluster(cluster_id):
if await self.manager.get_cluster(cluster_id):
raise web.HTTPError(
403, f"A Dask cluster with ID {cluster_id} already exists!"
)

try:
cluster_model = await manager.start_cluster(cluster_id)
cluster_model = await self.manager.start_cluster(cluster_id)
self.set_status(200)
self.finish(json.dumps(cluster_model))
except Exception as e:
Expand All @@ -76,13 +86,13 @@ async def patch(self, cluster_id):
new_model = json.loads(self.request.body)
try:
if new_model.get("adapt") is not None:
cluster_model = manager.adapt_cluster(
cluster_model = await self.manager.adapt_cluster(
cluster_id,
new_model["adapt"]["minimum"],
new_model["adapt"]["maximum"],
)
else:
cluster_model = await manager.scale_cluster(
cluster_model = await self.manager.scale_cluster(
cluster_id, new_model["workers"]
)
self.set_status(200)
Expand Down
22 changes: 16 additions & 6 deletions dask_labextension/dashboardhandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,32 @@
server, preventing CORS issues.
"""
import json
from inspect import isawaitable
from urllib import parse

from tornado import httpclient, web


from jupyter_server.base.handlers import APIHandler
from jupyter_server.utils import url_path_join
from jupyter_server_proxy.handlers import ProxyHandler

from .manager import manager
from .manager import DaskClusterManager


class DaskDashboardCheckHandler(APIHandler):
"""
A handler for checking validity of a dask dashboard.
"""

manager: DaskClusterManager

async def prepare(self):
    """
    Run the parent handler's preparation, then resolve the cluster manager.

    The parent `prepare` may be synchronous or return an awaitable, so its
    result is awaited only when needed. The object stored in the application
    settings under "dask_cluster_manager" is then awaited and the result is
    kept on `self.manager`.
    """
    parent_result = super().prepare()
    if isawaitable(parent_result):
        await parent_result
    pending_manager = self.settings["dask_cluster_manager"]
    self.manager = await pending_manager

@web.authenticated
async def get(self, url) -> None:
"""
Expand Down Expand Up @@ -133,7 +143,7 @@ async def http_get(self, cluster_id, proxied_path):
return await self.proxy(cluster_id, proxied_path)

async def open(self, cluster_id, proxied_path):
host, port = self._get_parsed(cluster_id)
host, port = await self._get_parsed(cluster_id)
return await super().proxy_open(host, port, proxied_path)

# We have to duplicate all these for now, I've no idea why!
Expand All @@ -157,17 +167,17 @@ def patch(self, cluster_id, proxied_path):
def options(self, cluster_id, proxied_path):
return self.proxy(cluster_id, proxied_path)

async def proxy(self, cluster_id, proxied_path):
    """
    Proxy a request to the dashboard of the cluster with the given ID.

    The cluster ID is resolved to a dashboard host and port via
    `_get_parsed` (which raises a 404 `web.HTTPError` when the cluster is
    not found), and the request is then handed to the base class `proxy`.

    NOTE(review): `_get_parsed` reads `self.manager`; confirm that this
    handler class assigns it before a request reaches this method.
    """
    host, port = await self._get_parsed(cluster_id)
    # This method is itself a coroutine, so the base-class result has to be
    # awaited here. Returning an un-awaited awaitable would hand callers that
    # `await self.proxy(...)` a pending object instead of running the proxy.
    # The isawaitable check keeps this working if the base class is synchronous.
    response = super().proxy(host, port, proxied_path)
    if isawaitable(response):
        response = await response
    return response

def _get_parsed(self, cluster_id):
async def _get_parsed(self, cluster_id):
"""
Given a cluster ID, get the hostname and port of its bokeh server.
"""
# Get the cluster by ID. If it is not found,
# raise an error.
cluster_model = manager.get_cluster(cluster_id)
cluster_model = await self.manager.get_cluster(cluster_id)
if not cluster_model:
raise web.HTTPError(404, f"Dask cluster {cluster_id} not found")

Expand Down
43 changes: 26 additions & 17 deletions dask_labextension/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.

import asyncio
import importlib
from inspect import isawaitable
from typing import Any, Dict, List, Union
Expand All @@ -11,8 +12,6 @@
import dask
from dask.utils import format_bytes
from dask.distributed import Adaptive
from tornado.ioloop import IOLoop
from tornado.concurrent import Future

# A type for a dask cluster model: a serializable
# representation of information about the cluster.
Expand Down Expand Up @@ -60,15 +59,28 @@ def __init__(self) -> None:
self._adaptives: Dict[str, Adaptive] = dict()
self._cluster_names: Dict[str, str] = dict()
self._n_clusters = 0
self._initialized = None

self.initialized = Future()
async def _async_init(self):
"""The async part of init
async def start_clusters():
for model in dask.config.get("labextension.initial"):
await self.start_cluster(configuration=model)
self.initialized.set_result(self)
Invoked by `await manager`
"""
for model in dask.config.get("labextension.initial"):
await self.start_cluster(configuration=model)
return self

IOLoop.current().add_callback(start_clusters)
@property
def initialized(self):
    """Lazily created task that runs the manager's async initialization.

    The task wrapping `_async_init()` is only created on first access
    (typically via `await manager`) and is cached in `_initialized`, so
    later accesses return the same task. Deferring creation this way makes
    it easier to ensure nothing happens before we are in the event loop.
    """
    task = self._initialized
    if task is None:
        # First access: schedule the async initialization and remember it.
        task = asyncio.create_task(self._async_init())
        self._initialized = task
    return task

async def start_cluster(
self, cluster_id: str = "", configuration: dict = {}
Expand Down Expand Up @@ -121,7 +133,9 @@ async def close_cluster(self, cluster_id: str) -> Union[ClusterModel, None]:
"""
cluster = self._clusters.get(cluster_id)
if cluster:
await cluster.close()
r = cluster.close()
if isawaitable(r):
await r
self._clusters.pop(cluster_id)
name = self._cluster_names.pop(cluster_id)
adaptive = self._adaptives.pop(cluster_id, None)
Expand All @@ -130,7 +144,7 @@ async def close_cluster(self, cluster_id: str) -> Union[ClusterModel, None]:
else:
return None

def get_cluster(self, cluster_id) -> Union[ClusterModel, None]:
async def get_cluster(self, cluster_id) -> Union[ClusterModel, None]:
"""
Get a Dask cluster model.
Expand All @@ -151,7 +165,7 @@ def get_cluster(self, cluster_id) -> Union[ClusterModel, None]:

return make_cluster_model(cluster_id, name, cluster, adaptive)

def list_clusters(self) -> List[ClusterModel]:
async def list_clusters(self) -> List[ClusterModel]:
"""
List the Dask cluster models known to the manager.
Expand Down Expand Up @@ -188,7 +202,7 @@ async def scale_cluster(self, cluster_id: str, n: int) -> Union[ClusterModel, No
await t
return make_cluster_model(cluster_id, name, cluster, adaptive=None)

def adapt_cluster(
async def adapt_cluster(
self, cluster_id: str, minimum: int, maximum: int
) -> Union[ClusterModel, None]:
cluster = self._clusters.get(cluster_id)
Expand Down Expand Up @@ -290,8 +304,3 @@ def make_cluster_model(
model["adapt"] = {"minimum": adaptive.minimum, "maximum": adaptive.maximum}

return model


# Create a default cluster manager
# to keep track of clusters.
manager = DaskClusterManager()
26 changes: 13 additions & 13 deletions dask_labextension/tests/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ async def test_start():
assert not model.get("adapt")

# close cluster
assert len(manager.list_clusters()) == 1
assert len(await manager.list_clusters()) == 1
await manager.close_cluster(model["id"])

# add cluster with adaptive configuration
Expand All @@ -55,7 +55,7 @@ async def test_close():

# close the cluster
await manager.close_cluster(model["id"])
assert not manager.list_clusters()
assert not await manager.list_clusters()


@gen_test()
Expand All @@ -66,10 +66,10 @@ async def test_get():
model = await manager.start_cluster()

# return None if a nonexistent cluster is requested
assert not manager.get_cluster("fake")
assert not await manager.get_cluster("fake")

# get the cluster by id
assert model == manager.get_cluster(model["id"])
assert model == await manager.get_cluster(model["id"])


@pytest.mark.filterwarnings("ignore")
Expand All @@ -78,12 +78,12 @@ async def test_list():
with dask.config.set(config):
async with DaskClusterManager() as manager:
# start with an empty list
assert not manager.list_clusters()
assert not await manager.list_clusters()
# start clusters
model1 = await manager.start_cluster()
model2 = await manager.start_cluster()

models = manager.list_clusters()
models = await manager.list_clusters()
assert len(models) == 2
assert model1 in models
assert model2 in models
Expand All @@ -98,7 +98,7 @@ async def test_scale():
start = time()
while model["workers"] != 3:
await sleep(0.01)
model = manager.get_cluster(model["id"])
model = await manager.get_cluster(model["id"])
assert time() < start + 10, model["workers"]

await sleep(0.2) # let workers settle # TODO: remove need for this
Expand All @@ -108,7 +108,7 @@ async def test_scale():
start = time()
while model["workers"] != 6:
await sleep(0.01)
model = manager.get_cluster(model["id"])
model = await manager.get_cluster(model["id"])
assert time() < start + 10, model["workers"]


Expand All @@ -119,7 +119,7 @@ async def test_adapt():
# add a new cluster
model = await manager.start_cluster()
assert not model.get("adapt")
model = manager.adapt_cluster(model["id"], 0, 4)
model = await manager.adapt_cluster(model["id"], 0, 4)
adapt = model.get("adapt")
assert adapt
assert adapt["minimum"] == 0
Expand All @@ -144,21 +144,21 @@ async def test_initial():
):
# Test asynchronous starting of clusters via a context
async with DaskClusterManager() as manager:
clusters = manager.list_clusters()
clusters = await manager.list_clusters()
assert len(clusters) == 1
assert clusters[0]["name"] == "foo"

# Test asynchronous starting of clusters outside of a context
manager = DaskClusterManager()
assert len(manager.list_clusters()) == 0
assert len(await manager.list_clusters()) == 0
await manager
clusters = manager.list_clusters()
clusters = await manager.list_clusters()
assert len(clusters) == 1
assert clusters[0]["name"] == "foo"
await manager.close()

manager = await DaskClusterManager()
clusters = manager.list_clusters()
clusters = await manager.list_clusters()
assert len(clusters) == 1
assert clusters[0]["name"] == "foo"
await manager.close()

0 comments on commit 19d5dac

Please sign in to comment.