Skip to content

Commit

Permalink
Merge pull request #2473 from plotly/fix-2221
Browse files Browse the repository at this point in the history
Add callback id to long callback key generation.
  • Loading branch information
T4rk1n authored Mar 29, 2023
2 parents 1bb7a74 + ba38e94 commit 6773440
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 19 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ This project adheres to [Semantic Versioning](https://semver.org/).

- [#2479](https://github.com/plotly/dash/pull/2479) Fix `KeyError` "Callback function not found for output [...], , perhaps you forgot to prepend the '@'?" issue when using duplicate callbacks targeting the same output. This issue would occur when the app is restarted or when running with multiple `gunicorn` workers.
- [#2471](https://github.com/plotly/dash/pull/2471) Fix `allow_duplicate` output with clientside callback, fix [#2467](https://github.com/plotly/dash/issues/2467)
- [#2473](https://github.com/plotly/dash/pull/2473) Fix background callbacks with different outputs but same function, fix [#2221](https://github.com/plotly/dash/issues/2221)

## [2.9.1] - 2023-03-17

Expand Down
4 changes: 3 additions & 1 deletion dash/_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,9 @@ def wrap_func(func):

if long is not None:
long_key = BaseLongCallbackManager.register_func(
func, long.get("progress") is not None
func,
long.get("progress") is not None,
callback_id,
)

@wraps(func)
Expand Down
14 changes: 8 additions & 6 deletions dash/long_callback/managers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def terminate_unhealthy_job(self, job):
def job_running(self, job):
raise NotImplementedError

def make_job_fn(self, fn, progress):
def make_job_fn(self, fn, progress, key=None):
raise NotImplementedError

def call_job_fn(self, key, job_fn, args, context):
Expand Down Expand Up @@ -76,11 +76,11 @@ def build_cache_key(self, fn, args, cache_args_to_ignore):
return hashlib.sha1(str(hash_dict).encode("utf-8")).hexdigest()

def register(self, key, fn, progress):
self.func_registry[key] = self.make_job_fn(fn, progress)
self.func_registry[key] = self.make_job_fn(fn, progress, key)

@staticmethod
def register_func(fn, progress):
key = BaseLongCallbackManager.hash_function(fn)
def register_func(fn, progress, callback_id):
key = BaseLongCallbackManager.hash_function(fn, callback_id)
BaseLongCallbackManager.functions.append(
(
key,
Expand All @@ -99,7 +99,9 @@ def _make_progress_key(key):
return key + "-progress"

@staticmethod
def hash_function(fn):
def hash_function(fn, callback_id=""):
fn_source = inspect.getsource(fn)
fn_str = fn_source
return hashlib.sha1(fn_str.encode("utf-8")).hexdigest()
return hashlib.sha1(
callback_id.encode("utf-8") + fn_str.encode("utf-8")
).hexdigest()
15 changes: 4 additions & 11 deletions dash/long_callback/managers/celery_manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import json
import inspect
import hashlib
import traceback
from contextvars import copy_context

Expand Down Expand Up @@ -78,8 +76,8 @@ def job_running(self, job):
"PROGRESS",
)

def make_job_fn(self, fn, progress):
return _make_job_fn(fn, self.handle, progress)
def make_job_fn(self, fn, progress, key=None):
return _make_job_fn(fn, self.handle, progress, key)

def get_task(self, job):
if job:
Expand Down Expand Up @@ -127,15 +125,10 @@ def get_result(self, key, job):
return result


def _make_job_fn(fn, celery_app, progress):
def _make_job_fn(fn, celery_app, progress, key):
cache = celery_app.backend

# Hash function source and module to create a unique (but stable) celery task name
fn_source = inspect.getsource(fn)
fn_str = fn_source
fn_hash = hashlib.sha1(fn_str.encode("utf-8")).hexdigest()

@celery_app.task(name=f"long_callback_{fn_hash}")
@celery_app.task(name=f"long_callback_{key}")
def job_fn(result_key, progress_key, user_callback_args, context=None):
def _set_progress(progress_value):
if not isinstance(progress_value, (list, tuple)):
Expand Down
2 changes: 1 addition & 1 deletion dash/long_callback/managers/diskcache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def job_running(self, job):
return proc.status() != psutil.STATUS_ZOMBIE
return False

def make_job_fn(self, fn, progress):
def make_job_fn(self, fn, progress, key=None):
return _make_job_fn(fn, self.handle, progress)

def clear_cache_entry(self, key):
Expand Down
36 changes: 36 additions & 0 deletions tests/integration/long_callback/app_diff_outputs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from dash import Dash, Input, Output, html

from tests.integration.long_callback.utils import get_long_callback_manager

long_callback_manager = get_long_callback_manager()
handle = long_callback_manager.handle

app = Dash(__name__, long_callback_manager=long_callback_manager)

app.layout = html.Div(
[
html.Button("click 1", id="button-1"),
html.Button("click 2", id="button-2"),
html.Div(id="output-1"),
html.Div(id="output-2"),
]
)


def gen_callback(index):
@app.callback(
Output(f"output-{index}", "children"),
Input(f"button-{index}", "n_clicks"),
background=True,
prevent_initial_call=True,
)
def callback_name(_):
return f"Clicked on {index}"


for i in range(1, 3):
gen_callback(i)


if __name__ == "__main__":
app.run_server(debug=True)
9 changes: 9 additions & 0 deletions tests/integration/long_callback/test_basic_long_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,3 +547,12 @@ def test_lcbc014_progress_delete(dash_duo, manager):
dash_duo.wait_for_text_to_equal("#output", "done")

assert dash_duo.find_element("#progress-counter").text == "2"


def test_lcbc015_diff_outputs_same_func(dash_duo, manager):
with setup_long_callback_app(manager, "app_diff_outputs") as app:
dash_duo.start_server(app)

for i in range(1, 3):
dash_duo.find_element(f"#button-{i}").click()
dash_duo.wait_for_text_to_equal(f"#output-{i}", f"Clicked on {i}")

0 comments on commit 6773440

Please sign in to comment.