diff --git a/dash_slicer/slicer.py b/dash_slicer/slicer.py index 7a898bb..8c418ec 100644 --- a/dash_slicer/slicer.py +++ b/dash_slicer/slicer.py @@ -2,7 +2,7 @@ from plotly.graph_objects import Figure from dash import Dash from dash.dependencies import Input, Output, State, ALL -from dash_core_components import Graph, Slider, Store +from dash_core_components import Graph, Slider, Store, Interval from .utils import img_array_to_uri, get_thumbnail_size, shape3d_to_size2d @@ -26,6 +26,7 @@ class VolumeSlicer: reverse_y (bool): Whether to reverse the y-axis, so that the origin of the slice is in the top-left, rather than bottom-left. Default True. (This sets the figure's yaxes ``autorange`` to "reversed" or True.) + Note: setting this to False affects performance, see #12. scene_id (str): the scene that this slicer is part of. Slicers that have the same scene-id show each-other's positions with line indicators. By default this is derived from ``id(volume)``. @@ -51,9 +52,13 @@ class VolumeSlicer: The value in the store must be an 3-element tuple (x, y, z) in scene coordinates. To apply the position for one position only, use e.g ``(None, None, x)``. - Some notes on performance: for a smooth experience, create the `Dash` - application with `update_title=None`, and when running the server in debug - mode, consider setting `dev_tools_props_check=False`. + Some notes on performance: for a smooth experience, avoid triggering + unnecessary figure updates. When adding a callback that uses the + slicer position, use the (rate limited) `index` and `pos` stores + rather than the slider value. Further, create the `Dash` application + with `update_title=None`, and when running the server in debug mode, + consider setting `dev_tools_props_check=False`. + """ _global_slicer_counter = 0 @@ -154,6 +159,20 @@ def stores(self): """ return self._stores + @property + def index(self): + """A dcc.Store containing the integer slice number. This value + is a rate-limited version of the slider value. + """ + return self._index + + @property + def pos(self): + """A dcc.Store containing the float position in scene coordinates, + along the slice-axis. + """ + return self._pos + @property def overlay_data(self): """A dcc.Store containing the overlay data. The form of this @@ -277,39 +296,68 @@ def _create_dash_components(self): config={"scrollZoom": True}, ) - # Create a slider object that the user can put in the layout (or not) + initial_index = info["size"][2] // 2 + initial_pos = info["origin"][2] + initial_index * info["spacing"][2] + + # Create a slider object that the user can put in the layout (or not). + # Note that the tooltip introduces a measurable performance penalty, + # so maybe we can display it in a different way? self._slider = Slider( id=self._subid("slider"), min=0, max=info["size"][2] - 1, step=1, - value=info["size"][2] // 2, - tooltip={"always_visible": False, "placement": "left"}, + value=initial_index, updatemode="drag", + tooltip={"always_visible": False, "placement": "left"}, ) # Create the stores that we need (these must be present in the layout) + + # A dict of static info for this slicer self._info = Store(id=self._subid("info"), data=info) - self._position = Store( - id=self._subid("position", True, axis=self._axis), data=0 - ) - self._setpos = Store(id=self._subid("setpos", True), data=None) - self._requested_index = Store(id=self._subid("req-index"), data=0) - self._request_data = Store(id=self._subid("req-data"), data="") + + # A list of low-res slices (encoded as base64-png) self._lowres_data = Store(id=self._subid("lowres"), data=thumbnails) + + # A list of mask slices (encoded as base64-png or null) self._overlay_data = Store(id=self._subid("overlay"), data=[]) + + # Slice data provided by the server + self._server_data = Store(id=self._subid("server-data"), data="") + + # Store image traces for the slicer. self._img_traces = Store(id=self._subid("img-traces"), data=[]) + + # Store indicator traces for the slicer. self._indicator_traces = Store(id=self._subid("indicator-traces"), data=[]) + + # A timer to apply a rate-limit between slider.value and index.data + self._timer = Interval(id=self._subid("timer"), interval=100, disabled=True) + + # The (integer) index of the slice to show. This value is rate-limited + self._index = Store(id=self._subid("index"), data=initial_index) + + # The (float) position (in scene coords) of the current slice, + # used to publish our position to slicers with the same scene_id. + self._pos = Store( + id=self._subid("pos", True, axis=self._axis), data=initial_pos + ) + + # Signal to set the position of other slicers with the same scene_id. + self._setpos = Store(id=self._subid("setpos", True), data=None) + self._stores = [ self._info, - self._position, - self._setpos, - self._requested_index, - self._request_data, self._lowres_data, self._overlay_data, + self._server_data, self._img_traces, self._indicator_traces, + self._timer, + self._index, + self._pos, + self._setpos, ] def _create_server_callbacks(self): @@ -317,8 +365,8 @@ def _create_server_callbacks(self): app = self._app @app.callback( - Output(self._request_data.id, "data"), - [Input(self._requested_index.id, "data")], + Output(self._server_data.id, "data"), + [Input(self._index.id, "data")], ) def upload_requested_slice(slice_index): slice = img_array_to_uri(self._slice(slice_index)) @@ -326,14 +374,29 @@ def upload_requested_slice(slice_index): def _create_client_callbacks(self): """Create the callbacks that run client-side.""" + + # setpos (external) + # \ + # slider --[rate limit]--> index --> pos + # \ \ + # \ server_data (a new slice) + # \ \ + # \ --> image_traces + # ----------------------- / \ + # -----> figure + # / + # indicator_traces + # / + # pos (external) + app = self._app # ---------------------------------------------------------------------- - # Callback to trigger fellow slicers to go to a specific position. + # Callback to trigger fellow slicers to go to a specific position on click. app.clientside_callback( """ - function trigger_setpos(data, index, info) { + function update_setpos_from_click(data, index, info) { if (data && data.points && data.points.length) { let point = data["points"][0]; let xyz = [point["x"], point["y"]]; @@ -350,11 +413,11 @@ def _create_client_callbacks(self): ) # ---------------------------------------------------------------------- - # Callback to update index from external setpos signal. + # Callback to update slider based on external setpos signals. app.clientside_callback( """ - function respond_to_setpos(positions, cur_index, info) { + function update_slider_value(positions, cur_index, info) { for (let trigger of dash_clientside.callback_context.triggered) { if (!trigger.value) continue; let pos = trigger.value[2 - info.axis]; @@ -381,64 +444,81 @@ def _create_client_callbacks(self): ) # ---------------------------------------------------------------------- - # Callback to update position (in scene coordinates) from the index. + # Callback to rate-limit the index (using a timer/interval). app.clientside_callback( """ - function update_position(index, info) { - return info.origin[2] + index * info.spacing[2]; - } - """, - Output(self._position.id, "data"), - [Input(self._slider.id, "value")], - [State(self._info.id, "data")], - ) + function update_index_rate_limiting(index, n_intervals, interval) { - # ---------------------------------------------------------------------- - # Callback to request new slices. - # Note: this callback cannot be merged with the one below, because - # it would create a circular dependency. + if (!window._slicer_{{ID}}) window._slicer_{{ID}} = {}; + let slicer_state = window._slicer_{{ID}}; + let now = window.performance.now(); - app.clientside_callback( - """ - function update_request(index) { + // Get whether the slider was moved + let slider_was_moved = false; + for (let trigger of dash_clientside.callback_context.triggered) { + if (trigger.prop_id.indexOf('slider') >= 0) slider_was_moved = true; + } - // Clear the cache? - if (!window.slicecache_for_{{ID}}) { window.slicecache_for_{{ID}} = {}; } - let slice_cache = window.slicecache_for_{{ID}}; + // Initialize return values + let req_index = dash_clientside.no_update; + let disable_timer = false; - // Request a new slice (or not) - let request_index = index; - if (slice_cache[index]) { - return window.dash_clientside.no_update; - } else { - console.log('requesting slice ' + index); - return index; + // If the slider moved, remember the time when this happened + slicer_state.new_time = slicer_state.new_time || 0; + + if (slider_was_moved) { + slicer_state.new_time = now; + } else if (!n_intervals) { + disable_timer = true; // start disabled + } + + // We can either update the rate-limited index interval ms after + // the real index changed, or interval ms after it stopped + // changing. The former makes the indicators come along while + // dragging the slider, the latter is better for a smooth + // experience, and the interval can be set much lower. + if (index != slicer_state.req_index) { + if (now - slicer_state.new_time >= interval) { + req_index = slicer_state.req_index = index; + disable_timer = true; + console.log('requesting slice ' + req_index); + } } + + return [req_index, disable_timer]; } """.replace( "{{ID}}", self._context_id ), - Output(self._requested_index.id, "data"), - [Input(self.slider.id, "value")], + [ + Output(self._index.id, "data"), + Output(self._timer.id, "disabled"), + ], + [Input(self._slider.id, "value"), Input(self._timer.id, "n_intervals")], + [State(self._timer.id, "interval")], ) # ---------------------------------------------------------------------- - # Callback that creates a list of image traces (slice and overlay). + # Callback to update position (in scene coordinates) from the index. app.clientside_callback( """ - function update_image_traces(index, req_data, overlays, lowres, info, current_traces) { + function update_pos(index, info) { + return info.origin[2] + index * info.spacing[2]; + } + """, + Output(self._pos.id, "data"), + [Input(self._index.id, "data")], + [State(self._info.id, "data")], + ) - // Add data to the cache if the data is indeed new - if (!window.slicecache_for_{{ID}}) { window.slicecache_for_{{ID}} = {}; } - let slice_cache = window.slicecache_for_{{ID}}; - for (let trigger of dash_clientside.callback_context.triggered) { - if (trigger.prop_id.indexOf('req-data') >= 0) { - slice_cache[req_data.index] = req_data; - break; - } - } + # ---------------------------------------------------------------------- + # Callback that creates a list of image traces (slice and overlay). + + app.clientside_callback( + """ + function update_image_traces(index, server_data, overlays, lowres, info, current_traces) { // Prepare traces let slice_trace = { @@ -455,14 +535,14 @@ def _create_client_callbacks(self): overlay_trace.hovertemplate = ''; let new_traces = [slice_trace, overlay_trace]; - // Depending on the state of the cache, use full data, or use lowres and request slice - if (slice_cache[index]) { - let cached = slice_cache[index]; - slice_trace.source = cached.slice; + // Use full data, or use lowres + if (index == server_data.index) { + slice_trace.source = server_data.slice; } else { slice_trace.source = lowres[index]; // Scale the image to take the exact same space as the full-res - // version. It's not correct, but it looks better ... + // version. Note that depending on how the low-res data is + // created, the pixel centers may not be correctly aligned. slice_trace.dx *= info.size[0] / info.lowres_size[0]; slice_trace.dy *= info.size[1] / info.lowres_size[1]; slice_trace.x0 += 0.5 * slice_trace.dx - 0.5 * info.spacing[0]; @@ -474,7 +554,7 @@ def _create_client_callbacks(self): if (new_traces[0].source == current_traces[0].source && new_traces[1].source == current_traces[1].source) { - new_traces = window.dash_clientside.no_update; + new_traces = dash_clientside.no_update; } return new_traces; } @@ -483,8 +563,8 @@ def _create_client_callbacks(self): ), Output(self._img_traces.id, "data"), [ - Input(self.slider.id, "value"), - Input(self._request_data.id, "data"), + Input(self._slider.id, "value"), + Input(self._server_data.id, "data"), Input(self._overlay_data.id, "data"), ], [ @@ -497,12 +577,9 @@ def _create_client_callbacks(self): # ---------------------------------------------------------------------- # Callback to create scatter traces from the positions of other slicers. - # Create a callback to create a trace representing all slice-indices that: - # * corresponding to the same volume data - # * match any of the selected axii app.clientside_callback( """ - function handle_indicator(positions1, positions2, info, current) { + function update_indicator_traces(positions1, positions2, info, current) { let x0 = info.origin[0], y0 = info.origin[1]; let x1 = x0 + info.size[0] * info.spacing[0], y1 = y0 + info.size[1] * info.spacing[1]; x0 = x0 - info.spacing[0], y0 = y0 - info.spacing[1]; @@ -536,7 +613,7 @@ def _create_client_callbacks(self): { "scene": self._scene_id, "context": ALL, - "name": "position", + "name": "pos", "axis": axis, }, "data", @@ -562,7 +639,6 @@ def _create_client_callbacks(self): for (let trace of indicators) { traces.push(trace); } // Update figure - console.log("updating figure"); let figure = {...ori_figure}; figure.data = traces; diff --git a/examples/bring_your_own_slider.py b/examples/bring_your_own_slider.py index 1820e09..5049471 100644 --- a/examples/bring_your_own_slider.py +++ b/examples/bring_your_own_slider.py @@ -14,7 +14,7 @@ import imageio -app = dash.Dash(__name__) +app = dash.Dash(__name__, update_title=None) vol = imageio.volread("imageio:stent.npz") slicer = VolumeSlicer(app, vol) diff --git a/examples/slicer_with_1_plus_2_views.py b/examples/slicer_with_1_plus_2_views.py index 3c3d3c5..017b882 100644 --- a/examples/slicer_with_1_plus_2_views.py +++ b/examples/slicer_with_1_plus_2_views.py @@ -19,7 +19,7 @@ import imageio -app = dash.Dash(__name__) +app = dash.Dash(__name__, update_title=None) vol1 = imageio.volread("imageio:stent.npz") @@ -28,14 +28,10 @@ ori = 1000, 2000, 3000 -slicer1 = VolumeSlicer( - app, vol1, axis=1, origin=ori, reverse_y=False, scene_id="scene1" -) -slicer2 = VolumeSlicer( - app, vol1, axis=0, origin=ori, reverse_y=False, scene_id="scene1" -) +slicer1 = VolumeSlicer(app, vol1, axis=1, origin=ori, scene_id="scene1") +slicer2 = VolumeSlicer(app, vol1, axis=0, origin=ori, scene_id="scene1") slicer3 = VolumeSlicer( - app, vol2, axis=0, origin=ori, spacing=spacing, reverse_y=False, scene_id="scene1" + app, vol2, axis=0, origin=ori, spacing=spacing, scene_id="scene1" ) app.layout = html.Div( diff --git a/examples/slicer_with_2_views.py b/examples/slicer_with_2_views.py index 418c7f1..0707dca 100644 --- a/examples/slicer_with_2_views.py +++ b/examples/slicer_with_2_views.py @@ -8,7 +8,7 @@ import imageio -app = dash.Dash(__name__) +app = dash.Dash(__name__, update_title=None) vol = imageio.volread("imageio:stent.npz") slicer1 = VolumeSlicer(app, vol, axis=1) diff --git a/examples/slicer_with_3_views.py b/examples/slicer_with_3_views.py index 04d5ae8..440bb04 100644 --- a/examples/slicer_with_3_views.py +++ b/examples/slicer_with_3_views.py @@ -11,14 +11,14 @@ from skimage.measure import marching_cubes import imageio -app = dash.Dash(__name__) +app = dash.Dash(__name__, update_title=None) server = app.server # Read volumes and create slicer objects vol = imageio.volread("imageio:stent.npz") -slicer1 = VolumeSlicer(app, vol, reverse_y=False, axis=0) -slicer2 = VolumeSlicer(app, vol, reverse_y=False, axis=1) -slicer3 = VolumeSlicer(app, vol, reverse_y=False, axis=2) +slicer1 = VolumeSlicer(app, vol, axis=0) +slicer2 = VolumeSlicer(app, vol, axis=1) +slicer3 = VolumeSlicer(app, vol, axis=2) # Calculate isosurface and create a figure with a mesh object verts, faces, _, _ = marching_cubes(vol, 300, step_size=2) diff --git a/examples/threshold_overlay.py b/examples/threshold_overlay.py index 8edf5a1..4a536ca 100644 --- a/examples/threshold_overlay.py +++ b/examples/threshold_overlay.py @@ -15,7 +15,7 @@ import imageio -app = dash.Dash(__name__) +app = dash.Dash(__name__, update_title=None) server = app.server vol = imageio.volread("imageio:stent.npz") diff --git a/tests/test_slicer.py b/tests/test_slicer.py index e90ec4f..ee4fe55 100644 --- a/tests/test_slicer.py +++ b/tests/test_slicer.py @@ -28,7 +28,7 @@ def test_slicer_init(): assert isinstance(s.graph, dcc.Graph) assert isinstance(s.slider, dcc.Slider) assert isinstance(s.stores, list) - assert all(isinstance(store, dcc.Store) for store in s.stores) + assert all(isinstance(store, (dcc.Store, dcc.Interval)) for store in s.stores) def test_scene_id_and_context_id():