Skip to content

Commit

Permalink
Merge pull request #1 from mrocklin/test-fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
jrbourbeau authored May 6, 2020
2 parents a254927 + 72398cd commit 91afdc1
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 38 deletions.
12 changes: 3 additions & 9 deletions distributed/cli/tests/test_dask_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,11 +305,6 @@ def test_preload_remote_module(loop, tmp_path):
f.write(PRELOAD_TEXT)

with popen([sys.executable, "-m", "http.server", "9382"], cwd=tmp_path):
import requests

data = requests.get("http://localhost:9382/scheduler_info.py").content
assert b"scheduler.foo" in data

with popen(
[
"dask-scheduler",
Expand All @@ -322,10 +317,9 @@ def test_preload_remote_module(loop, tmp_path):
with Client(
scheduler_file=tmp_path / "scheduler-file.json", loop=loop
) as c:
assert (
c.run_on_scheduler(lambda dask_scheduler: dask_scheduler.foo)
== "bar"
)
assert c.run_on_scheduler(
lambda dask_scheduler: getattr(dask_scheduler, "foo", None)
) == "bar"


PRELOAD_COMMAND_TEXT = """
Expand Down
1 change: 1 addition & 0 deletions distributed/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,7 @@ async def handle_comm(self, comm, shutting_down=shutting_down):

logger.debug("Connection from %r to %s", address, type(self).__name__)
self._comms[comm] = op
await self
try:
while True:
try:
Expand Down
23 changes: 13 additions & 10 deletions distributed/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def __init__(
deserialize=deserialize,
io_loop=self.io_loop,
)
self._startup_lock = asyncio.Lock()

def versions(self, comm=None, packages=None):
return get_versions(packages=packages)
Expand Down Expand Up @@ -169,26 +170,28 @@ async def __aexit__(self, typ, value, traceback):
await self.close()

def __await__(self):
if self.status == "running":
return gen.sleep(0).__await__()
else:
future = self.start()
async def _():
timeout = getattr(self, "death_timeout", 0)
if timeout:

async def wait_for(future, timeout=None):
async with self._startup_lock:
if self.status == "running":
return self
if timeout:
try:
await asyncio.wait_for(future, timeout=timeout)
await asyncio.wait_for(self.start(), timeout=timeout)
self.status = "running"
except Exception:
await self.close(timeout=1)
raise TimeoutError(
"{} failed to start in {} seconds".format(
type(self).__name__, timeout
)
)
else:
await self.start()
self.status = "running"
return self

future = wait_for(future, timeout=timeout)
return future.__await__()
return _().__await__()

async def start(self):
# subclasses should implement their own start method whichs calls super().start()
Expand Down
37 changes: 18 additions & 19 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1426,8 +1426,8 @@ def get_worker_service_addr(self, worker, service_name, protocol=False):

async def start(self):
""" Clear out old state and restart all running coroutines """

await super().start()
assert self.status != "running"

enable_gc_diagnosis()

Expand All @@ -1437,28 +1437,26 @@ async def start(self):
for c in self._worker_coroutines:
c.cancel()

if self.status != "running":
for addr in self._start_address:
await self.listen(addr, **self.security.get_listen_args("scheduler"))
self.ip = get_address_host(self.listen_address)
listen_ip = self.ip
for addr in self._start_address:
await self.listen(addr, **self.security.get_listen_args("scheduler"))
self.ip = get_address_host(self.listen_address)
listen_ip = self.ip

if listen_ip == "0.0.0.0":
listen_ip = ""
if listen_ip == "0.0.0.0":
listen_ip = ""

if self.address.startswith("inproc://"):
listen_ip = "localhost"
if self.address.startswith("inproc://"):
listen_ip = "localhost"

# Services listen on all addresses
self.start_services(listen_ip)
# Services listen on all addresses
self.start_services(listen_ip)

self.status = "running"
for listener in self.listeners:
logger.info(" Scheduler at: %25s", listener.contact_address)
for k, v in self.services.items():
logger.info("%11s at: %25s", k, "%s:%d" % (listen_ip, v.port))
for listener in self.listeners:
logger.info(" Scheduler at: %25s", listener.contact_address)
for k, v in self.services.items():
logger.info("%11s at: %25s", k, "%s:%d" % (listen_ip, v.port))

self.loop.add_callback(self.reevaluate_occupancy)
self.loop.add_callback(self.reevaluate_occupancy)

if self.scheduler_file:
with open(self.scheduler_file, "w") as f:
Expand Down Expand Up @@ -2937,7 +2935,8 @@ async def restart(self, client=None, timeout=3):
finally:
await asyncio.gather(*[nanny.close_rpc() for nanny in nannies])

await self.start()
self.status = None
await self

self.log_event([client, "all"], {"action": "restart", "client": client})
start = time()
Expand Down

0 comments on commit 91afdc1

Please sign in to comment.