Skip to content

Commit

Permalink
Add worker config to cache task completion results. (spotify#3178)
Browse files Browse the repository at this point in the history
* Add option to cache completion results.

* Extend completion cache to Worker methods.

* Sync completion cache after running task.

* Increase test coverage.
  • Loading branch information
riga authored Jun 24, 2022
1 parent 1205644 commit f8254fd
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 10 deletions.
9 changes: 9 additions & 0 deletions doc/configuration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,15 @@ check_complete_on_run
missing.
Defaults to false.

cache_task_completion
By default, luigi task processes might check the completion status multiple
times per task, which is a safe way to avoid potential inconsistencies. For
tasks with many dynamic dependencies, yielded in multiple stages, this might
become expensive, e.g. in case the per-task completion check entails remote
resources. When set to true, completion checks are cached so that tasks
declared as complete once are not checked again.
Defaults to false.


[elasticsearch]
---------------
Expand Down
57 changes: 47 additions & 10 deletions luigi/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import subprocess
import sys
import contextlib
import functools

import queue as Queue
import random
Expand Down Expand Up @@ -117,7 +118,7 @@ class TaskProcess(multiprocessing.Process):

def __init__(self, task, worker_id, result_queue, status_reporter,
use_multiprocessing=False, worker_timeout=0, check_unfulfilled_deps=True,
check_complete_on_run=False):
check_complete_on_run=False, task_completion_cache=None):
super(TaskProcess, self).__init__()
self.task = task
self.worker_id = worker_id
Expand All @@ -128,6 +129,10 @@ def __init__(self, task, worker_id, result_queue, status_reporter,
self.use_multiprocessing = use_multiprocessing or self.timeout_time is not None
self.check_unfulfilled_deps = check_unfulfilled_deps
self.check_complete_on_run = check_complete_on_run
self.task_completion_cache = task_completion_cache

# completeness check using the cache
self.check_complete = functools.partial(check_complete_cached, completion_cache=task_completion_cache)

def _run_get_new_deps(self):
task_gen = self.task.run()
Expand All @@ -146,7 +151,7 @@ def _run_get_new_deps(self):
return None

new_req = flatten(requires)
if all(t.complete() for t in new_req):
if all(self.check_complete(t) for t in new_req):
next_send = getpaths(requires)
else:
new_deps = [(t.task_module, t.task_family, t.to_str_params())
Expand All @@ -172,7 +177,7 @@ def run(self):
# checking completeness of self.task so outputs of dependencies are
# irrelevant.
if self.check_unfulfilled_deps and not _is_external(self.task):
missing = [dep.task_id for dep in self.task.deps() if not dep.complete()]
missing = [dep.task_id for dep in self.task.deps() if not self.check_complete(dep)]
if missing:
deps = 'dependency' if len(missing) == 1 else 'dependencies'
raise RuntimeError('Unfulfilled %s at run time: %s' % (deps, ', '.join(missing)))
Expand All @@ -182,7 +187,7 @@ def run(self):

if _is_external(self.task):
# External task
if self.task.complete():
if self.check_complete(self.task):
status = DONE
else:
status = FAILED
Expand All @@ -192,7 +197,12 @@ def run(self):
with self._forward_attributes():
new_deps = self._run_get_new_deps()
if not new_deps:
if not self.check_complete_on_run or self.task.complete():
if not self.check_complete_on_run:
# update the cache
if self.task_completion_cache is not None:
self.task_completion_cache[self.task.task_id] = True
status = DONE
elif self.check_complete(self.task):
status = DONE
else:
raise TaskException("Task finished running, but complete() is still returning false.")
Expand Down Expand Up @@ -394,13 +404,29 @@ def __init__(self, trace):
self.trace = trace


def check_complete(task, out_queue):
def check_complete_cached(task, completion_cache=None):
    """
    Check whether *task* is complete, consulting *completion_cache* first.

    *completion_cache* is an optional mapping from task ids to truthy values
    for tasks already known to be complete. Only positive results are stored:
    an incomplete task may become complete later, while the reverse is
    assumed not to happen once a task was seen as complete.
    """
    key = task.task_id

    # short-circuit when the task was already cached as complete
    if completion_cache is not None and completion_cache.get(key):
        return True

    # (re-)evaluate the task's own completion check
    complete = task.complete()

    # remember positive results only
    if complete and completion_cache is not None:
        completion_cache[key] = complete

    return complete


def check_complete(task, out_queue, completion_cache=None):
    """
    Check if *task* is complete and put the result onto *out_queue*.

    The completion check goes through :func:`check_complete_cached`, so a
    previously cached positive result in *completion_cache* is reused. If the
    check raises, a ``TracebackWrapper`` holding the formatted traceback is
    enqueued in place of the boolean result.
    """
    logger.debug("Checking if %s is complete", task)
    try:
        result = check_complete_cached(task, completion_cache)
    except Exception:
        # ship the traceback back to the consumer instead of crashing here
        result = TracebackWrapper(traceback.format_exc())
    out_queue.put((task, result))
Expand Down Expand Up @@ -462,6 +488,11 @@ class worker(Config):
'applied as a context manager around its run() call, so this can be '
'used for obtaining high level customizable monitoring or logging of '
'each individual Task run.')
cache_task_completion = BoolParameter(default=False,
description='If true, cache the response of successful completion checks '
'of tasks assigned to a worker. This can especially speed up tasks with '
'dynamic dependencies but assumes that the completion status does not change '
'after it was true the first time.')


class KeepAliveThread(threading.Thread):
Expand Down Expand Up @@ -560,6 +591,11 @@ def __init__(self, scheduler=None, worker_id=None, worker_processes=1, assistant
self._running_tasks = {}
self._idle_since = None

# mp-safe dictionary for caching completion checks across task processes
self._task_completion_cache = None
if self._config.cache_task_completion:
self._task_completion_cache = multiprocessing.Manager().dict()

# Stuff for execution_summary
self._add_task_history = []
self._get_work_response_history = []
Expand Down Expand Up @@ -745,7 +781,7 @@ def add(self, task, multiprocess=False, processes=0):
queue = DequeQueue()
pool = SingleProcessPool()
self._validate_task(task)
pool.apply_async(check_complete, [task, queue])
pool.apply_async(check_complete, [task, queue, self._task_completion_cache])

# we track queue size ourselves because len(queue) won't work for multiprocessing
queue_size = 1
Expand All @@ -759,7 +795,7 @@ def add(self, task, multiprocess=False, processes=0):
if next.task_id not in seen:
self._validate_task(next)
seen.add(next.task_id)
pool.apply_async(check_complete, [next, queue])
pool.apply_async(check_complete, [next, queue, self._task_completion_cache])
queue_size += 1
except (KeyboardInterrupt, TaskException):
raise
Expand Down Expand Up @@ -1024,6 +1060,7 @@ def _create_task_process(self, task):
worker_timeout=self._config.timeout,
check_unfulfilled_deps=self._config.check_unfulfilled_deps,
check_complete_on_run=self._config.check_complete_on_run,
task_completion_cache=self._task_completion_cache,
)

def _purge_children(self):
Expand Down
82 changes: 82 additions & 0 deletions test/worker_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,88 @@ def requires(self):
self.assertEqual(a2.complete_count, 2)
self.assertEqual(b2.complete_count, 2)

def test_cache_task_completion_config(self):
    """
    Verify how often ``complete()`` is invoked per task under three worker
    configurations: with ``cache_task_completion`` enabled, with it disabled,
    and with ``check_complete_on_run`` enabled. The counts assert that the
    completion cache suppresses repeated checks for dynamically yielded
    dependencies.
    """
    class A(Task):
        # parameter only serves to create distinct task instances per sub-test
        i = luigi.IntParameter()

        def __init__(self, *args, **kwargs):
            super(A, self).__init__(*args, **kwargs)
            # number of times complete() was called on this instance
            self.complete_count = 0
            # completion flag flipped by run()
            self.has_run = False

        def complete(self):
            self.complete_count += 1
            return self.has_run

        def run(self):
            self.has_run = True

    class B(A):
        # yields three dynamic dependencies in separate stages before
        # completing itself, so A's completion is checked between stages
        def run(self):
            yield A(i=self.i + 0)
            yield A(i=self.i + 1)
            yield A(i=self.i + 2)
            self.has_run = True

    # test with enabled cache_task_completion
    with Worker(scheduler=self.sch, worker_id='2', cache_task_completion=True) as w:
        b0 = B(i=0)
        a0 = A(i=0)
        a1 = A(i=1)
        a2 = A(i=2)
        self.assertTrue(w.add(b0))
        # a's are required dynamically, so their counts must be 0
        self.assertEqual(b0.complete_count, 1)
        self.assertEqual(a0.complete_count, 0)
        self.assertEqual(a1.complete_count, 0)
        self.assertEqual(a2.complete_count, 0)
        w.run()
        # the complete methods of a's yielded first in b's run method were called equally often
        self.assertEqual(b0.complete_count, 1)
        self.assertEqual(a0.complete_count, 2)
        self.assertEqual(a1.complete_count, 2)
        self.assertEqual(a2.complete_count, 2)

    # test with disabled cache_task_completion
    with Worker(scheduler=self.sch, worker_id='2', cache_task_completion=False) as w:
        b10 = B(i=10)
        a10 = A(i=10)
        a11 = A(i=11)
        a12 = A(i=12)
        self.assertTrue(w.add(b10))
        # a's are required dynamically, so their counts must be 0
        self.assertEqual(b10.complete_count, 1)
        self.assertEqual(a10.complete_count, 0)
        self.assertEqual(a11.complete_count, 0)
        self.assertEqual(a12.complete_count, 0)
        w.run()
        # without the cache, earlier-yielded a's are re-checked at every
        # subsequent stage, so their counts grow with yield position
        self.assertEqual(b10.complete_count, 1)
        self.assertEqual(a10.complete_count, 5)
        self.assertEqual(a11.complete_count, 4)
        self.assertEqual(a12.complete_count, 3)

    # test with enabled check_complete_on_run
    with Worker(scheduler=self.sch, worker_id='2', check_complete_on_run=True) as w:
        b20 = B(i=20)
        a20 = A(i=20)
        a21 = A(i=21)
        a22 = A(i=22)
        self.assertTrue(w.add(b20))
        # a's are required dynamically, so their counts must be 0
        self.assertEqual(b20.complete_count, 1)
        self.assertEqual(a20.complete_count, 0)
        self.assertEqual(a21.complete_count, 0)
        self.assertEqual(a22.complete_count, 0)
        w.run()
        # check_complete_on_run adds one extra post-run check per task on top
        # of the uncached counts above
        self.assertEqual(b20.complete_count, 2)
        self.assertEqual(a20.complete_count, 6)
        self.assertEqual(a21.complete_count, 5)
        self.assertEqual(a22.complete_count, 4)

def test_gets_missed_work(self):
class A(Task):
done = False
Expand Down

0 comments on commit f8254fd

Please sign in to comment.