From a434e1fcee1fd9a34952e0e9badfd4cce51e5455 Mon Sep 17 00:00:00 2001
From: Marcel R
Date: Tue, 21 Jun 2022 16:06:03 +0200
Subject: [PATCH 1/4] Add option to cache completion results.

---
 doc/configuration.rst |  9 ++++++
 luigi/worker.py       | 41 +++++++++++++++++++++++----
 test/worker_test.py   | 65 +++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 110 insertions(+), 5 deletions(-)

diff --git a/doc/configuration.rst b/doc/configuration.rst
index 9a9103c9ee..2499fb7567 100644
--- a/doc/configuration.rst
+++ b/doc/configuration.rst
@@ -356,6 +356,15 @@ check_complete_on_run
   missing.
   Defaults to false.
 
+cache_task_completion
+  By default, luigi task processes might check the completion status multiple
+  times per task, which is a safe way to avoid potential inconsistencies. For
+  tasks with many dynamic dependencies, yielded in multiple stages, this can
+  become expensive, e.g. when the per-task completion check queries remote
+  resources. When set to true, completion checks are cached so that tasks
+  declared as complete once are not checked again.
+  Defaults to false.
+
 [elasticsearch]
 ---------------
 
diff --git a/luigi/worker.py b/luigi/worker.py
index 84359e4740..e4f890aee8 100644
--- a/luigi/worker.py
+++ b/luigi/worker.py
@@ -117,7 +117,7 @@ class TaskProcess(multiprocessing.Process):
 
     def __init__(self, task, worker_id, result_queue, status_reporter,
                  use_multiprocessing=False, worker_timeout=0, check_unfulfilled_deps=True,
-                 check_complete_on_run=False):
+                 check_complete_on_run=False, task_completion_cache=None):
         super(TaskProcess, self).__init__()
         self.task = task
         self.worker_id = worker_id
@@ -128,6 +128,7 @@ def __init__(self, task, worker_id, result_queue, status_reporter,
         self.use_multiprocessing = use_multiprocessing or self.timeout_time is not None
         self.check_unfulfilled_deps = check_unfulfilled_deps
         self.check_complete_on_run = check_complete_on_run
+        self.task_completion_cache = task_completion_cache
 
     def _run_get_new_deps(self):
         task_gen = self.task.run()
@@ -146,7 +147,7 @@ def _run_get_new_deps(self):
             return None
 
         new_req = flatten(requires)
-        if all(t.complete() for t in new_req):
+        if all(self._check_complete(t) for t in new_req):
            next_send = getpaths(requires)
         else:
             new_deps = [(t.task_module, t.task_family, t.to_str_params())
@@ -172,7 +173,7 @@ def run(self):
             # checking completeness of self.task so outputs of dependencies are
             # irrelevant.
             if self.check_unfulfilled_deps and not _is_external(self.task):
-                missing = [dep.task_id for dep in self.task.deps() if not dep.complete()]
+                missing = [dep.task_id for dep in self.task.deps() if not self._check_complete(dep)]
                 if missing:
                     deps = 'dependency' if len(missing) == 1 else 'dependencies'
                     raise RuntimeError('Unfulfilled %s at run time: %s' % (deps, ', '.join(missing)))
@@ -182,7 +183,7 @@ def run(self):
 
             if _is_external(self.task):
                 # External task
-                if self.task.complete():
+                if self._check_complete(self.task):
                     status = DONE
                 else:
                     status = FAILED
@@ -192,7 +193,7 @@ def run(self):
                 with self._forward_attributes():
                     new_deps = self._run_get_new_deps()
                 if not new_deps:
-                    if not self.check_complete_on_run or self.task.complete():
+                    if not self.check_complete_on_run or self._check_complete(self.task):
                         status = DONE
                     else:
                         raise TaskException("Task finished running, but complete() is still returning false.")
@@ -266,6 +267,24 @@ def _forward_attributes(self):
                 for reporter_attr, task_attr in self.forward_reporter_attributes.items():
                     setattr(self.task, task_attr, None)
 
+    def _check_complete(self, task):
+        """
+        Checks if a task is complete, optionally using the task_completion_cache.
+        """
+        task_id = task.task_id
+
+        # return True if caching is used and the task was already complete
+        if self.task_completion_cache is not None and self.task_completion_cache.get(task_id):
+            return True
+
+        is_complete = task.complete()
+
+        # update the cache when used
+        if self.task_completion_cache is not None:
+            self.task_completion_cache[task_id] = is_complete
+
+        return is_complete
+
 
 # This code and the task_process_context config key currently feels a bit ad-hoc.
 # Discussion on generalizing it into a plugin system: https://github.com/spotify/luigi/issues/1897
@@ -462,6 +481,11 @@ class worker(Config):
                                             'applied as a context manager around its run() call, so this can be '
                                             'used for obtaining high level customizable monitoring or logging of '
                                             'each individual Task run.')
+    cache_task_completion = BoolParameter(default=False,
+                                          description='If true, cache the response of successful completion checks '
+                                                      'of tasks assigned to a worker. This can especially speed up tasks with '
+                                                      'dynamic dependencies, but assumes that the completion status does not '
+                                                      'change after it has been true once.')
 
 
 class KeepAliveThread(threading.Thread):
@@ -560,6 +584,9 @@ def __init__(self, scheduler=None, worker_id=None, worker_processes=1, assistant
         self._running_tasks = {}
         self._idle_since = None
 
+        # mp-safe dictionary for caching completion checks across task processes (set in run)
+        self._task_completion_cache = None
+
         # Stuff for execution_summary
         self._add_task_history = []
         self._get_work_response_history = []
@@ -1024,6 +1051,7 @@ def _create_task_process(self, task):
             worker_timeout=self._config.timeout,
             check_unfulfilled_deps=self._config.check_unfulfilled_deps,
             check_complete_on_run=self._config.check_complete_on_run,
+            task_completion_cache=self._task_completion_cache,
         )
 
     def _purge_children(self):
@@ -1181,6 +1209,9 @@ def run(self):
 
         self._add_worker()
 
+        if self._config.cache_task_completion:
+            self._task_completion_cache = multiprocessing.Manager().dict()
+
         while True:
             while len(self._running_tasks) >= self.worker_processes > 0:
                 logger.debug('%d running tasks, waiting for next task to finish', len(self._running_tasks))
diff --git a/test/worker_test.py b/test/worker_test.py
index 7f09314b82..0ba103a0c4 100644
--- a/test/worker_test.py
+++ b/test/worker_test.py
@@ -434,6 +434,71 @@ def requires(self):
         self.assertEqual(a2.complete_count, 2)
         self.assertEqual(b2.complete_count, 2)
 
+    def test_cache_task_completion_config(self):
+        class A(Task):
+
+            i = luigi.IntParameter()
+
+            def __init__(self, *args, **kwargs):
+                super(A, self).__init__(*args, **kwargs)
+                self.complete_count = 0
+                self.has_run = False
+
+            def complete(self):
+                self.complete_count += 1
+                return self.has_run
+
+            def run(self):
+                self.has_run = True
+
+        class B(A):
+
+            def run(self):
+                yield A(i=self.i + 0)
+                yield A(i=self.i + 1)
+                yield A(i=self.i + 2)
+                self.has_run = True
+
+        # test the enabled feature
+        with Worker(scheduler=self.sch, worker_id='2') as w:
+            w._config.cache_task_completion = True
+            b0 = B(i=0)
+            a0 = A(i=0)
+            a1 = A(i=1)
+            a2 = A(i=2)
+            self.assertTrue(w.add(b0))
+            # a's are required dynamically, so their counts must be 0
+            self.assertEqual(b0.complete_count, 1)
+            self.assertEqual(a0.complete_count, 0)
+            self.assertEqual(a1.complete_count, 0)
+            self.assertEqual(a2.complete_count, 0)
+            w.run()
+            # the complete methods of a's yielded first in b's run method were called equally often
+            self.assertEqual(b0.complete_count, 1)
+            self.assertEqual(a0.complete_count, 3)
+            self.assertEqual(a1.complete_count, 3)
+            self.assertEqual(a2.complete_count, 3)
+
+        # test the disabled feature
+        with Worker(scheduler=self.sch, worker_id='2') as w:
+            w._config.cache_task_completion = False
+            b10 = B(i=10)
+            a10 = A(i=10)
+            a11 = A(i=11)
+            a12 = A(i=12)
+            self.assertTrue(w.add(b10))
+            # a's are required dynamically, so their counts must be 0
+            self.assertEqual(b10.complete_count, 1)
+            self.assertEqual(a10.complete_count, 0)
+            self.assertEqual(a11.complete_count, 0)
+            self.assertEqual(a12.complete_count, 0)
+            w.run()
+            # the complete methods of a's yielded first in b's run method were called more often
+            self.assertEqual(b10.complete_count, 1)
+            self.assertEqual(a10.complete_count, 5)
+            self.assertEqual(a11.complete_count, 4)
+            self.assertEqual(a12.complete_count, 3)
+
     def test_gets_missed_work(self):
         class A(Task):
             done = False
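
Usage sketch for patch 1 (illustrative, not part of the patch series): the cache is switched on through the existing [worker] config section, matching the cache_task_completion parameter added above. The FetchRemote/Aggregate task names below are hypothetical stand-ins for a pipeline whose complete() checks hit remote storage; each stage of dynamic dependencies restarts run() from the top, so earlier yields are re-checked unless the cache short-circuits them.

    # luigi.cfg
    [worker]
    cache_task_completion = true

    # pipeline sketch (hypothetical task names)
    import luigi

    class FetchRemote(luigi.Task):
        i = luigi.IntParameter()

        def output(self):
            # stand-in for a remote target whose existence check is expensive
            return luigi.LocalTarget('data/fetch_%d.txt' % self.i)

        def run(self):
            with self.output().open('w') as f:
                f.write('done')

    class Aggregate(luigi.Task):
        def output(self):
            return luigi.LocalTarget('data/aggregate.txt')

        def run(self):
            # after each yielded stage completes, run() is restarted from the
            # beginning, re-checking all previously yielded tasks; the cache
            # answers those repeat checks without calling complete() again
            yield FetchRemote(i=0)
            yield FetchRemote(i=1)
            yield FetchRemote(i=2)
            with self.output().open('w') as f:
                f.write('done')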

From 847204591c5122fe8b034a0d7e93e82f7bfd0d52 Mon Sep 17 00:00:00 2001
From: Marcel R
Date: Tue, 21 Jun 2022 21:20:14 +0200
Subject: [PATCH 2/4] Extend completion cache to Worker methods.

---
 luigi/worker.py     | 63 +++++++++++++++++++++++----------------
 test/worker_test.py |  6 ++---
 2 files changed, 34 insertions(+), 35 deletions(-)

diff --git a/luigi/worker.py b/luigi/worker.py
index e4f890aee8..58f7d70668 100644
--- a/luigi/worker.py
+++ b/luigi/worker.py
@@ -39,6 +39,7 @@
 import subprocess
 import sys
 import contextlib
+import functools
 
 import queue as Queue
 import random
@@ -130,6 +131,9 @@ def __init__(self, task, worker_id, result_queue, status_reporter,
         self.check_complete_on_run = check_complete_on_run
         self.task_completion_cache = task_completion_cache
 
+        # completeness check using the cache
+        self.check_complete = functools.partial(check_complete_cached, completion_cache=task_completion_cache)
+
     def _run_get_new_deps(self):
         task_gen = self.task.run()
 
@@ -147,7 +151,7 @@ def _run_get_new_deps(self):
             return None
 
         new_req = flatten(requires)
-        if all(self._check_complete(t) for t in new_req):
+        if all(self.check_complete(t) for t in new_req):
             next_send = getpaths(requires)
         else:
             new_deps = [(t.task_module, t.task_family, t.to_str_params())
@@ -173,7 +177,7 @@ def run(self):
             # checking completeness of self.task so outputs of dependencies are
             # irrelevant.
             if self.check_unfulfilled_deps and not _is_external(self.task):
-                missing = [dep.task_id for dep in self.task.deps() if not self._check_complete(dep)]
+                missing = [dep.task_id for dep in self.task.deps() if not self.check_complete(dep)]
                 if missing:
                     deps = 'dependency' if len(missing) == 1 else 'dependencies'
                     raise RuntimeError('Unfulfilled %s at run time: %s' % (deps, ', '.join(missing)))
@@ -183,7 +187,7 @@ def run(self):
 
             if _is_external(self.task):
                 # External task
-                if self._check_complete(self.task):
+                if self.check_complete(self.task):
                     status = DONE
                 else:
                     status = FAILED
@@ -193,7 +197,7 @@ def run(self):
                 with self._forward_attributes():
                     new_deps = self._run_get_new_deps()
                 if not new_deps:
-                    if not self.check_complete_on_run or self._check_complete(self.task):
+                    if not self.check_complete_on_run or self.check_complete(self.task):
                         status = DONE
                     else:
                         raise TaskException("Task finished running, but complete() is still returning false.")
@@ -267,24 +271,6 @@ def _forward_attributes(self):
                 for reporter_attr, task_attr in self.forward_reporter_attributes.items():
                     setattr(self.task, task_attr, None)
 
-    def _check_complete(self, task):
-        """
-        Checks if a task is complete, optionally using the task_completion_cache.
-        """
-        task_id = task.task_id
-
-        # return True if caching is used and the task was already complete
-        if self.task_completion_cache is not None and self.task_completion_cache.get(task_id):
-            return True
-
-        is_complete = task.complete()
-
-        # update the cache when used
-        if self.task_completion_cache is not None:
-            self.task_completion_cache[task_id] = is_complete
-
-        return is_complete
-
 
 # This code and the task_process_context config key currently feels a bit ad-hoc.
 # Discussion on generalizing it into a plugin system: https://github.com/spotify/luigi/issues/1897
@@ -413,13 +399,29 @@ def __init__(self, trace):
         self.trace = trace
 
 
-def check_complete(task, out_queue):
+def check_complete_cached(task, completion_cache=None):
+    # check if cached and complete
+    cache_key = task.task_id
+    if completion_cache is not None and completion_cache.get(cache_key):
+        return True
+
+    # (re-)check the status
+    is_complete = task.complete()
+
+    # tell the cache when complete
+    if completion_cache is not None and is_complete:
+        completion_cache[cache_key] = is_complete
+
+    return is_complete
+
+
+def check_complete(task, out_queue, completion_cache=None):
     """
-    Checks if task is complete, puts the result to out_queue.
+    Checks if task is complete, puts the result to out_queue, optionally using the completion cache.
     """
     logger.debug("Checking if %s is complete", task)
     try:
-        is_complete = task.complete()
+        is_complete = check_complete_cached(task, completion_cache)
     except Exception:
         is_complete = TracebackWrapper(traceback.format_exc())
     out_queue.put((task, is_complete))
@@ -584,8 +586,10 @@ def __init__(self, scheduler=None, worker_id=None, worker_processes=1, assistant
         self._running_tasks = {}
         self._idle_since = None
 
-        # mp-safe dictionary for caching completion checks across task processes (set in run)
+        # mp-safe dictionary for caching completion checks across task processes
         self._task_completion_cache = None
+        if self._config.cache_task_completion:
+            self._task_completion_cache = multiprocessing.Manager().dict()
 
         # Stuff for execution_summary
         self._add_task_history = []
         self._get_work_response_history = []
@@ -772,7 +776,7 @@ def add(self, task, multiprocess=False, processes=0):
             queue = DequeQueue()
             pool = SingleProcessPool()
         self._validate_task(task)
-        pool.apply_async(check_complete, [task, queue])
+        pool.apply_async(check_complete, [task, queue, self._task_completion_cache])
 
         # we track queue size ourselves because len(queue) won't work for multiprocessing
         queue_size = 1
@@ -786,7 +790,7 @@ def add(self, task, multiprocess=False, processes=0):
                     if next.task_id not in seen:
                         self._validate_task(next)
                         seen.add(next.task_id)
-                        pool.apply_async(check_complete, [next, queue])
+                        pool.apply_async(check_complete, [next, queue, self._task_completion_cache])
                         queue_size += 1
             except (KeyboardInterrupt, TaskException):
                 raise
@@ -1209,9 +1213,6 @@ def run(self):
 
         self._add_worker()
 
-        if self._config.cache_task_completion:
-            self._task_completion_cache = multiprocessing.Manager().dict()
-
         while True:
             while len(self._running_tasks) >= self.worker_processes > 0:
                 logger.debug('%d running tasks, waiting for next task to finish', len(self._running_tasks))
diff --git a/test/worker_test.py b/test/worker_test.py
index 0ba103a0c4..6dd3da0233 100644
--- a/test/worker_test.py
+++ b/test/worker_test.py
@@ -460,8 +460,7 @@ def run(self):
             self.has_run = True
 
         # test the enabled feature
-        with Worker(scheduler=self.sch, worker_id='2') as w:
-            w._config.cache_task_completion = True
+        with Worker(scheduler=self.sch, worker_id='2', cache_task_completion=True) as w:
             b0 = B(i=0)
             a0 = A(i=0)
             a1 = A(i=1)
             a2 = A(i=2)
@@ -480,8 +479,7 @@ def run(self):
             self.assertEqual(a2.complete_count, 3)
 
         # test the disabled feature
-        with Worker(scheduler=self.sch, worker_id='2') as w:
-            w._config.cache_task_completion = False
+        with Worker(scheduler=self.sch, worker_id='2', cache_task_completion=False) as w:
             b10 = B(i=10)
             a10 = A(i=10)
             a11 = A(i=11)
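
Behavior sketch for patch 2 (standalone, not part of the patches): check_complete_cached only stores successful checks, so an incomplete task is re-checked on every call, while a task seen as complete once is never checked again. A plain dict and a counting dummy task stand in for the Manager().dict() and a real luigi task here; the helper itself mirrors the one added to luigi/worker.py above.

    class DummyTask(object):
        # stand-in for a luigi task; counts complete() calls and becomes
        # complete on the second check
        def __init__(self, task_id):
            self.task_id = task_id
            self.calls = 0

        def complete(self):
            self.calls += 1
            return self.calls >= 2

    def check_complete_cached(task, completion_cache=None):
        # mirrors the helper introduced in patch 2
        cache_key = task.task_id
        if completion_cache is not None and completion_cache.get(cache_key):
            return True
        is_complete = task.complete()
        if completion_cache is not None and is_complete:
            completion_cache[cache_key] = is_complete
        return is_complete

    cache = {}  # a plain dict stands in for multiprocessing.Manager().dict()
    task = DummyTask('A_1')
    print(check_complete_cached(task, cache))  # False, nothing cached yet
    print(check_complete_cached(task, cache))  # True, now cached
    print(check_complete_cached(task, cache))  # True, answered from the cache
    print(task.calls)                          # 2, complete() was not called a third time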

From f874aaf2021439f0bbc75348fa8f1f00bdb760ff Mon Sep 17 00:00:00 2001
From: Marcel R
Date: Tue, 21 Jun 2022 21:23:47 +0200
Subject: [PATCH 3/4] Sync completion cache after running task.

---
 luigi/worker.py     | 7 ++++++-
 test/worker_test.py | 6 +++---
 2 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/luigi/worker.py b/luigi/worker.py
index 58f7d70668..20b90b5caa 100644
--- a/luigi/worker.py
+++ b/luigi/worker.py
@@ -197,7 +197,12 @@ def run(self):
                 with self._forward_attributes():
                     new_deps = self._run_get_new_deps()
                 if not new_deps:
-                    if not self.check_complete_on_run or self.check_complete(self.task):
+                    if not self.check_complete_on_run:
+                        # update the cache
+                        if self.task_completion_cache is not None:
+                            self.task_completion_cache[self.task.task_id] = True
+                        status = DONE
+                    elif self.check_complete(self.task):
                         status = DONE
                     else:
                         raise TaskException("Task finished running, but complete() is still returning false.")
diff --git a/test/worker_test.py b/test/worker_test.py
index 6dd3da0233..811408397d 100644
--- a/test/worker_test.py
+++ b/test/worker_test.py
@@ -474,9 +474,9 @@ def run(self):
             w.run()
             # the complete methods of a's yielded first in b's run method were called equally often
             self.assertEqual(b0.complete_count, 1)
-            self.assertEqual(a0.complete_count, 3)
-            self.assertEqual(a1.complete_count, 3)
-            self.assertEqual(a2.complete_count, 3)
+            self.assertEqual(a0.complete_count, 2)
+            self.assertEqual(a1.complete_count, 2)
+            self.assertEqual(a2.complete_count, 2)
 
         # test the disabled feature
         with Worker(scheduler=self.sch, worker_id='2', cache_task_completion=False) as w:
             b10 = B(i=10)
             a10 = A(i=10)
             a11 = A(i=11)

From 42275922749013814906ae57d117ce63355bbd84 Mon Sep 17 00:00:00 2001
From: Marcel R
Date: Wed, 22 Jun 2022 21:19:05 +0200
Subject: [PATCH 4/4] Increase test coverage.

---
 test/worker_test.py | 23 +++++++++++++++++++++--
 1 file changed, 21 insertions(+), 2 deletions(-)

diff --git a/test/worker_test.py b/test/worker_test.py
index 811408397d..ea443b668c 100644
--- a/test/worker_test.py
+++ b/test/worker_test.py
@@ -459,7 +459,7 @@ def run(self):
                 yield A(i=self.i + 2)
                 self.has_run = True
 
-        # test the enabled feature
+        # test with enabled cache_task_completion
         with Worker(scheduler=self.sch, worker_id='2', cache_task_completion=True) as w:
             b0 = B(i=0)
             a0 = A(i=0)
@@ -478,7 +478,7 @@ def run(self):
             self.assertEqual(a1.complete_count, 2)
             self.assertEqual(a2.complete_count, 2)
 
-        # test the disabled feature
+        # test with disabled cache_task_completion
         with Worker(scheduler=self.sch, worker_id='2', cache_task_completion=False) as w:
             b10 = B(i=10)
             a10 = A(i=10)
@@ -497,6 +497,25 @@ def run(self):
             self.assertEqual(a11.complete_count, 4)
             self.assertEqual(a12.complete_count, 3)
 
+        # test with enabled check_complete_on_run
+        with Worker(scheduler=self.sch, worker_id='2', check_complete_on_run=True) as w:
+            b20 = B(i=20)
+            a20 = A(i=20)
+            a21 = A(i=21)
+            a22 = A(i=22)
+            self.assertTrue(w.add(b20))
+            # a's are required dynamically, so their counts must be 0
+            self.assertEqual(b20.complete_count, 1)
+            self.assertEqual(a20.complete_count, 0)
+            self.assertEqual(a21.complete_count, 0)
+            self.assertEqual(a22.complete_count, 0)
+            w.run()
+            # the complete methods of a's yielded first in b's run method were called more often
+            self.assertEqual(b20.complete_count, 2)
+            self.assertEqual(a20.complete_count, 6)
+            self.assertEqual(a21.complete_count, 5)
+            self.assertEqual(a22.complete_count, 4)
+
     def test_gets_missed_work(self):
         class A(Task):
             done = False
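
Closing note on the shared cache (illustrative sketch, separate from the patches): the worker creates the cache as multiprocessing.Manager().dict() rather than a plain dict because each TaskProcess is a separate process; a manager proxy makes updates from one task process visible to the worker and to later task processes. A minimal demonstration of that property:

    import multiprocessing

    def mark_complete(cache, task_id):
        # runs in a child process, like a TaskProcess recording a completion
        cache[task_id] = True

    if __name__ == '__main__':
        cache = multiprocessing.Manager().dict()
        p = multiprocessing.Process(target=mark_complete, args=(cache, 'A_1'))
        p.start()
        p.join()
        # the parent observes the child's update through the proxy
        print(cache.get('A_1'))  # True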