Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add worker config to cache task completion results. #3178

Merged
merged 5 commits into from
Jun 24, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions doc/configuration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,15 @@ check_complete_on_run
missing.
Defaults to false.

cache_task_completion
By default, luigi task processes might check the completion status multiple
times per task which is a safe way to avoid potential inconsistencies. For
tasks with many dynamic dependencies, yielded in multiple stages, this might
become expensive, e.g. in case the per-task completion check entails remote
resources. When set to true, completion checks are cached so that tasks
declared as complete once are not checked again.
Defaults to false.


[elasticsearch]
---------------
Expand Down
57 changes: 47 additions & 10 deletions luigi/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import subprocess
import sys
import contextlib
import functools

import queue as Queue
import random
Expand Down Expand Up @@ -117,7 +118,7 @@ class TaskProcess(multiprocessing.Process):

def __init__(self, task, worker_id, result_queue, status_reporter,
use_multiprocessing=False, worker_timeout=0, check_unfulfilled_deps=True,
check_complete_on_run=False):
check_complete_on_run=False, task_completion_cache=None):
super(TaskProcess, self).__init__()
self.task = task
self.worker_id = worker_id
Expand All @@ -128,6 +129,10 @@ def __init__(self, task, worker_id, result_queue, status_reporter,
self.use_multiprocessing = use_multiprocessing or self.timeout_time is not None
self.check_unfulfilled_deps = check_unfulfilled_deps
self.check_complete_on_run = check_complete_on_run
self.task_completion_cache = task_completion_cache

# completeness check using the cache
self.check_complete = functools.partial(check_complete_cached, completion_cache=task_completion_cache)

def _run_get_new_deps(self):
task_gen = self.task.run()
Expand All @@ -146,7 +151,7 @@ def _run_get_new_deps(self):
return None

new_req = flatten(requires)
if all(t.complete() for t in new_req):
if all(self.check_complete(t) for t in new_req):
next_send = getpaths(requires)
else:
new_deps = [(t.task_module, t.task_family, t.to_str_params())
Expand All @@ -172,7 +177,7 @@ def run(self):
# checking completeness of self.task so outputs of dependencies are
# irrelevant.
if self.check_unfulfilled_deps and not _is_external(self.task):
missing = [dep.task_id for dep in self.task.deps() if not dep.complete()]
missing = [dep.task_id for dep in self.task.deps() if not self.check_complete(dep)]
if missing:
deps = 'dependency' if len(missing) == 1 else 'dependencies'
raise RuntimeError('Unfulfilled %s at run time: %s' % (deps, ', '.join(missing)))
Expand All @@ -182,7 +187,7 @@ def run(self):

if _is_external(self.task):
# External task
if self.task.complete():
if self.check_complete(self.task):
status = DONE
else:
status = FAILED
Expand All @@ -192,7 +197,12 @@ def run(self):
with self._forward_attributes():
new_deps = self._run_get_new_deps()
if not new_deps:
if not self.check_complete_on_run or self.task.complete():
if not self.check_complete_on_run:
# update the cache
if self.task_completion_cache is not None:
self.task_completion_cache[self.task.task_id] = True
status = DONE
elif self.check_complete(self.task):
status = DONE
else:
raise TaskException("Task finished running, but complete() is still returning false.")
Expand Down Expand Up @@ -394,13 +404,29 @@ def __init__(self, trace):
self.trace = trace


def check_complete(task, out_queue):
def check_complete_cached(task, completion_cache=None):
    """
    Return whether *task* is complete, optionally consulting *completion_cache*.

    A cached truthy entry short-circuits the potentially expensive
    ``task.complete()`` call. Only positive results are stored, since an
    incomplete task can become complete later, whereas the reverse is
    assumed not to happen. With ``completion_cache=None`` this is a plain
    ``task.complete()`` call.
    """
    caching = completion_cache is not None
    key = task.task_id

    # fast path: a positive completion check was recorded earlier
    if caching and completion_cache.get(key):
        return True

    is_complete = task.complete()

    # remember only successful checks
    if caching and is_complete:
        completion_cache[key] = is_complete

    return is_complete


def check_complete(task, out_queue, completion_cache=None):
    """
    Determine completeness of *task* and put ``(task, result)`` onto *out_queue*.

    The optional *completion_cache* is forwarded to :func:`check_complete_cached`
    so repeated checks of the same task can be skipped. If the completeness
    check raises, a :class:`TracebackWrapper` holding the formatted traceback
    is queued in place of a boolean so the consumer can report the failure.
    """
    logger.debug("Checking if %s is complete", task)
    try:
        result = check_complete_cached(task, completion_cache)
    except Exception:
        # preserve the traceback for the consuming side instead of crashing here
        result = TracebackWrapper(traceback.format_exc())
    out_queue.put((task, result))
Expand Down Expand Up @@ -462,6 +488,11 @@ class worker(Config):
'applied as a context manager around its run() call, so this can be '
'used for obtaining high level customizable monitoring or logging of '
'each individual Task run.')
cache_task_completion = BoolParameter(default=False,
description='If true, cache the response of successful completion checks '
'of tasks assigned to a worker. This can especially speed up tasks with '
'dynamic dependencies but assumes that the completion status does not change '
'after it was true the first time.')


class KeepAliveThread(threading.Thread):
Expand Down Expand Up @@ -560,6 +591,11 @@ def __init__(self, scheduler=None, worker_id=None, worker_processes=1, assistant
self._running_tasks = {}
self._idle_since = None

# mp-safe dictionary for caching completion checks across task processes
self._task_completion_cache = None
if self._config.cache_task_completion:
self._task_completion_cache = multiprocessing.Manager().dict()

# Stuff for execution_summary
self._add_task_history = []
self._get_work_response_history = []
Expand Down Expand Up @@ -745,7 +781,7 @@ def add(self, task, multiprocess=False, processes=0):
queue = DequeQueue()
pool = SingleProcessPool()
self._validate_task(task)
pool.apply_async(check_complete, [task, queue])
pool.apply_async(check_complete, [task, queue, self._task_completion_cache])

# we track queue size ourselves because len(queue) won't work for multiprocessing
queue_size = 1
Expand All @@ -759,7 +795,7 @@ def add(self, task, multiprocess=False, processes=0):
if next.task_id not in seen:
self._validate_task(next)
seen.add(next.task_id)
pool.apply_async(check_complete, [next, queue])
pool.apply_async(check_complete, [next, queue, self._task_completion_cache])
queue_size += 1
except (KeyboardInterrupt, TaskException):
raise
Expand Down Expand Up @@ -1024,6 +1060,7 @@ def _create_task_process(self, task):
worker_timeout=self._config.timeout,
check_unfulfilled_deps=self._config.check_unfulfilled_deps,
check_complete_on_run=self._config.check_complete_on_run,
task_completion_cache=self._task_completion_cache,
)

def _purge_children(self):
Expand Down
82 changes: 82 additions & 0 deletions test/worker_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,88 @@ def requires(self):
self.assertEqual(a2.complete_count, 2)
self.assertEqual(b2.complete_count, 2)

def test_cache_task_completion_config(self):
class A(Task):

i = luigi.IntParameter()

def __init__(self, *args, **kwargs):
super(A, self).__init__(*args, **kwargs)
self.complete_count = 0
self.has_run = False

def complete(self):
self.complete_count += 1
return self.has_run

def run(self):
self.has_run = True

class B(A):

def run(self):
yield A(i=self.i + 0)
yield A(i=self.i + 1)
yield A(i=self.i + 2)
self.has_run = True

# test with enabled cache_task_completion
with Worker(scheduler=self.sch, worker_id='2', cache_task_completion=True) as w:
b0 = B(i=0)
a0 = A(i=0)
a1 = A(i=1)
a2 = A(i=2)
self.assertTrue(w.add(b0))
# a's are required dynamically, so their counts must be 0
self.assertEqual(b0.complete_count, 1)
self.assertEqual(a0.complete_count, 0)
self.assertEqual(a1.complete_count, 0)
self.assertEqual(a2.complete_count, 0)
w.run()
# the complete methods of a's yielded first in b's run method were called equally often
self.assertEqual(b0.complete_count, 1)
self.assertEqual(a0.complete_count, 2)
self.assertEqual(a1.complete_count, 2)
self.assertEqual(a2.complete_count, 2)

# test with disabled cache_task_completion
with Worker(scheduler=self.sch, worker_id='2', cache_task_completion=False) as w:
b10 = B(i=10)
a10 = A(i=10)
a11 = A(i=11)
a12 = A(i=12)
self.assertTrue(w.add(b10))
# a's are required dynamically, so their counts must be 0
self.assertEqual(b10.complete_count, 1)
self.assertEqual(a10.complete_count, 0)
self.assertEqual(a11.complete_count, 0)
self.assertEqual(a12.complete_count, 0)
w.run()
# the complete methods of a's yielded first in b's run method were called more often
self.assertEqual(b10.complete_count, 1)
self.assertEqual(a10.complete_count, 5)
self.assertEqual(a11.complete_count, 4)
self.assertEqual(a12.complete_count, 3)
Comment on lines +496 to +498
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i'm a little unclear why these complete_count values for each of the tasks differ in quantity. Could you clarify that for me?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly, i'm confused as to why the assertions below (with check_complete_on_run=True) resulted in larger complete_count quantities.

Copy link
Contributor Author

@riga riga Jun 23, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i'm a little unclear why these complete_count values for each of the tasks differ in quantity. Could you clarify that for me?

Sure!

The complete count with cache_task_completion=False of (5, 4, 3) for (a10, a11, a12) is due to luigi's assumption of idempotence of run() methods that yield dynamic dependencies, in that the worker invokes run() and, in case it's a generator, it gets the next result and

  • if it's a bunch of already complete tasks (some of the completeness checks happen here), it gets the next generator result, or
  • if it's a bunch of tasks of which at least one is not complete yet, it adds all of them to the tree and forgets about the state of the generator.

(code here) The yielding task is placed back to the tree in PENDING state, too. And when it's started again later on, the entire procedure is triggered again, leading to a new generator in its initial state, but now with the previously incomplete bunch being complete. Therefore, completion checks of tasks of a certain bunch are always performed at least once more than those in the next bunch.

Similarly, i'm confused as to why the assertions below (with check_complete_on_run=True) resulted in larger complete_count quantities.

With check_complete_on_run=True there is a single, additional call happening here which is increasing the counts. I wanted to check if that's consistent with the proposed changes so I added this block in the same test.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, i see now. This makes much more sense now! Thank you for the thorough response!


# test with enabled check_complete_on_run
with Worker(scheduler=self.sch, worker_id='2', check_complete_on_run=True) as w:
b20 = B(i=20)
a20 = A(i=20)
a21 = A(i=21)
a22 = A(i=22)
self.assertTrue(w.add(b20))
# a's are required dynamically, so their counts must be 0
self.assertEqual(b20.complete_count, 1)
self.assertEqual(a20.complete_count, 0)
self.assertEqual(a21.complete_count, 0)
self.assertEqual(a22.complete_count, 0)
w.run()
# the complete methods of a's yielded first in b's run method were called more often
self.assertEqual(b20.complete_count, 2)
self.assertEqual(a20.complete_count, 6)
self.assertEqual(a21.complete_count, 5)
self.assertEqual(a22.complete_count, 4)

def test_gets_missed_work(self):
class A(Task):
done = False
Expand Down