
Replace long callback interval with request polling handled in renderer.
T4rk1n committed Jun 13, 2022
1 parent 3a207ce commit fd9ee13
Showing 12 changed files with 571 additions and 319 deletions.
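
For context, a minimal usage sketch of the new long=* keyword arguments this commit adds to dash.callback (not part of the diff). The component ids, the time.sleep job, and the diskcache-backed manager are illustrative assumptions; the manager wiring follows Dash's existing long_callback support.

import time

import diskcache
from dash import Dash, Input, Output, callback, html
from dash.long_callback import DiskcacheLongCallbackManager

# Diskcache-backed manager, as in Dash's existing long_callback support.
manager = DiskcacheLongCallbackManager(diskcache.Cache("./cache"))

app = Dash(__name__)
app.layout = html.Div(
    [
        html.Button("Run", id="run-button"),
        html.Button("Cancel", id="cancel-button"),
        html.Div(id="result"),
    ]
)


@callback(
    Output("result", "children"),
    Input("run-button", "n_clicks"),
    long=True,
    long_manager=manager,
    # Disable the Run button while the job runs, re-enable it when it finishes.
    long_running=[(Output("run-button", "disabled"), True, False)],
    # A click on Cancel terminates the running job.
    long_cancel=[Input("cancel-button", "n_clicks")],
    prevent_initial_call=True,
)
def slow_job(n_clicks):
    time.sleep(5)  # stand-in for a long-running task
    return f"Finished run #{n_clicks}"


if __name__ == "__main__":
    app.run_server(debug=True)
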
202 changes: 194 additions & 8 deletions dash/_callback.py
@@ -1,12 +1,14 @@
import collections
from functools import wraps

import flask

from .dependencies import (
handle_callback_args,
handle_grouped_callback_args,
Output,
)
from .exceptions import PreventUpdate
from .exceptions import PreventUpdate, WildcardInLongCallback, DuplicateCallback

from ._grouping import (
flatten_grouping,
@@ -17,9 +19,11 @@
create_callback_id,
stringify_id,
to_json,
coerce_to_list,
)

from . import _validate
from .long_callback.managers import BaseLongCallbackManager


class NoUpdate:
@@ -30,15 +34,28 @@ def to_plotly_json(self): # pylint: disable=no-self-use

@staticmethod
def is_no_update(obj):
return obj == {"_dash_no_update": "_dash_no_update"}
return isinstance(obj, NoUpdate) or obj == {
"_dash_no_update": "_dash_no_update"
}


GLOBAL_CALLBACK_LIST = []
GLOBAL_CALLBACK_MAP = {}
GLOBAL_INLINE_SCRIPTS = []


def callback(*_args, **_kwargs):
def callback(
*_args,
long=False,
long_interval=1000,
long_progress=None,
long_progress_default=None,
long_running=None,
long_cancel=None,
long_manager=None,
long_cache_args_to_ignore=None,
**_kwargs,
):
"""
Normally used as a decorator, `@dash.callback` provides a server-side
callback relating the values of one or more `Output` items to one or
@@ -56,15 +73,79 @@ def callback(*_args, **_kwargs):
not to fire when its outputs are first added to the page. Defaults to
`False` and unlike `app.callback` is not configurable at the app level.
"""

long_spec = None

if long:
long_spec = {
"interval": long_interval,
}

if long_manager:
long_spec["manager"] = long_manager

if long_progress:
long_spec["progress"] = coerce_to_list(long_progress)
validate_long_inputs(long_spec["progress"])

if long_progress_default:
long_spec["progressDefault"] = coerce_to_list(long_progress_default)

if not len(long_spec["progress"]) == len(long_spec["progressDefault"]):
raise Exception(
"Progress and progress default needs to be of same length"
)

if long_running:
long_spec["running"] = coerce_to_list(long_running)
validate_long_inputs(x[0] for x in long_spec["running"])

if long_cancel:
cancel_inputs = coerce_to_list(long_cancel)
validate_long_inputs(cancel_inputs)

cancels_output = [Output(c.component_id, "id") for c in cancel_inputs]

try:

@callback(cancels_output, cancel_inputs, prevent_initial_call=True)
def cancel_call(*_):
job_ids = flask.request.args.getlist("cancelJob")
manager = long_manager or flask.g.long_callback_manager
if job_ids:
for job_id in job_ids:
manager.terminate_job(int(job_id))
return NoUpdate()

except DuplicateCallback:
pass # Already a callback to cancel, will get the proper jobs from the store.

long_spec["cancel"] = [c.to_dict() for c in cancel_inputs]

if long_cache_args_to_ignore:
long_spec["cache_args_to_ignore"] = long_cache_args_to_ignore

return register_callback(
GLOBAL_CALLBACK_LIST,
GLOBAL_CALLBACK_MAP,
False,
*_args,
**_kwargs,
long=long_spec,
)


def validate_long_inputs(deps):
for dep in deps:
if dep.has_wildcard():
raise WildcardInLongCallback(
f"""
long callbacks does not support dependencies with
pattern-matching ids
Received: {repr(dep)}\n"""
)


def clientside_callback(clientside_function, *args, **kwargs):
return register_clientside_callback(
GLOBAL_CALLBACK_LIST,
@@ -87,6 +168,7 @@ def insert_callback(
state,
inputs_state_indices,
prevent_initial_call,
long=None,
):
if prevent_initial_call is None:
prevent_initial_call = config_prevent_initial_callbacks
@@ -98,19 +180,26 @@ def insert_callback(
"state": [c.to_dict() for c in state],
"clientside_function": None,
"prevent_initial_call": prevent_initial_call,
"long": long
and {
"interval": long["interval"],
},
}

callback_map[callback_id] = {
"inputs": callback_spec["inputs"],
"state": callback_spec["state"],
"outputs_indices": outputs_indices,
"inputs_state_indices": inputs_state_indices,
"long": long,
}
callback_list.append(callback_spec)

return callback_id


def register_callback(
# pylint: disable=R0912
def register_callback( # pylint: disable=R0914
callback_list, callback_map, config_prevent_initial_callbacks, *_args, **_kwargs
):
(
@@ -129,6 +218,8 @@ def register_callback(
insert_output = flatten_grouping(output)
multi = True

long = _kwargs.get("long")

output_indices = make_grouping_by_index(output, list(range(grouping_len(output))))
callback_id = insert_callback(
callback_list,
@@ -140,23 +231,118 @@ def register_callback(
flat_state,
inputs_state_indices,
prevent_initial_call,
long=long,
)

# pylint: disable=too-many-locals
def wrap_func(func):

if long is not None:
long_key = BaseLongCallbackManager.register_func(
func, long.get("progress") is not None
)

@wraps(func)
def add_context(*args, **kwargs):
output_spec = kwargs.pop("outputs_list")
callback_manager = long.get(
"manager", kwargs.pop("long_callback_manager", None)
)
_validate.validate_output_spec(insert_output, output_spec, Output)

func_args, func_kwargs = _validate.validate_and_group_input_args(
args, inputs_state_indices
)

# don't touch the comment on the next line - used by debugger
output_value = func(*func_args, **func_kwargs) # %% callback invoked %%
response = {"multi": True}

if long is not None:
progress_outputs = long.get("progress")
cache_key = flask.request.args.get("cacheKey")
job_id = flask.request.args.get("job")

current_key = callback_manager.build_cache_key(
func,
# Inputs provided as dict is kwargs.
func_args if func_args else func_kwargs,
long.get("cache_args_to_ignore", []),
)

if not cache_key:
cache_key = current_key

job_fn = callback_manager.func_registry.get(long_key)

job = callback_manager.call_job_fn(
cache_key,
job_fn,
args,
)

data = {
"cacheKey": cache_key,
"job": job,
}

running = long.get("running")

if running:
data["running"] = {str(r[0]): r[1] for r in running}
data["runningOff"] = {str(r[0]): r[2] for r in running}
cancel = long.get("cancel")
if cancel:
data["cancel"] = cancel

progress_default = long.get("progressDefault")
if progress_default:
data["progressDefault"] = {
str(o): x
for o, x in zip(progress_outputs, progress_default)
}
return to_json(data)
else:
if progress_outputs:
# Get the progress before the result as it would be erased after the results.
progress = callback_manager.get_progress(cache_key)
if progress:
response["progress"] = {
str(x): progress[i]
for i, x in enumerate(progress_outputs)
}

output_value = callback_manager.get_result(cache_key, job_id)
# Must get job_running after get_result since get_results terminates it.
job_running = callback_manager.job_running(job_id)
if not job_running and output_value is callback_manager.UNDEFINED:
# Job canceled -> no output to close the loop.
output_value = NoUpdate()

elif (
isinstance(output_value, dict)
and "long_callback_error" in output_value
):
error = output_value.get("long_callback_error")
raise Exception(
f"An error occurred inside a long callback: {error['msg']}\n{error['tb']}"
)

if job_running and output_value is not callback_manager.UNDEFINED:
# cached results.
callback_manager.terminate_job(job_id)

if multi and isinstance(output_value, (list, tuple)):
output_value = [
NoUpdate() if NoUpdate.is_no_update(r) else r
for r in output_value
]

if output_value is callback_manager.UNDEFINED:
return to_json(response)
else:
# don't touch the comment on the next line - used by debugger
output_value = func(*func_args, **func_kwargs) # %% callback invoked %%

if isinstance(output_value, NoUpdate):
if NoUpdate.is_no_update(output_value):
raise PreventUpdate

if not multi:
Expand Down Expand Up @@ -191,7 +377,7 @@ def add_context(*args, **kwargs):
if not has_update:
raise PreventUpdate

response = {"response": component_ids, "multi": True}
response["response"] = component_ids

try:
jsonResponse = to_json(response)
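
The renderer half of the polling is not part of this file, but the server-side branches above imply the request cycle below. This is a hedged sketch only: post_callback is a hypothetical stand-in for the renderer's request to the callback endpoint, and the query-arg names (cacheKey, job) are the ones read via flask.request.args in this diff.

import time


def poll_long_callback(post_callback, payload, interval=1.0):
    # First request: no cacheKey/job query args yet, so the server dispatches
    # the job and replies with {"cacheKey": ..., "job": ...} plus any
    # running/cancel/progressDefault data.
    meta = post_callback(payload)
    cache_key, job = meta["cacheKey"], meta["job"]

    while True:
        time.sleep(interval)
        # Follow-up requests carry the job handle; the server replies with an
        # optional "progress" payload until the result is ready, then returns
        # the final multi-output "response".
        reply = post_callback(payload, params={"cacheKey": cache_key, "job": job})
        if "response" in reply:
            return reply
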
6 changes: 6 additions & 0 deletions dash/_utils.py
@@ -217,3 +217,9 @@ def gen_salt(chars):
return "".join(
secrets.choice(string.ascii_letters + string.digits) for _ in range(chars)
)


def coerce_to_list(obj):
if not isinstance(obj, (list, tuple)):
return [obj]
return obj
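
Illustrative only: this small helper is what lets the new long_* keyword arguments accept either a single dependency or a list of them.

from dash._utils import coerce_to_list

assert coerce_to_list("interval") == ["interval"]  # single value wrapped in a list
assert coerce_to_list(["a", "b"]) == ["a", "b"]    # lists pass through unchanged
assert coerce_to_list((1, 2)) == (1, 2)            # tuples also pass through as-is
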
