From 45bf8494fe99ea02e300294ac5023d0e9538047d Mon Sep 17 00:00:00 2001 From: Almar Klein Date: Tue, 15 Dec 2020 17:07:14 +0100 Subject: [PATCH 1/8] Add support for contrast limits --- dash_slicer/slicer.py | 64 +++++++++++++++++++++++++++++++++---------- 1 file changed, 50 insertions(+), 14 deletions(-) diff --git a/dash_slicer/slicer.py b/dash_slicer/slicer.py index 72246e9..d23726d 100644 --- a/dash_slicer/slicer.py +++ b/dash_slicer/slicer.py @@ -111,6 +111,8 @@ class VolumeSlicer: the slice is in the top-left, rather than bottom-left. Default True. Note: setting this to False affects performance, see #12. This has been fixed, but the fix has not yet been released with Dash. + * `clim` (tuple of `float`): the (initial) contrast limits. Default the min + and max of the volume. * `scene_id` (`str`): the scene that this slicer is part of. Slicers that have the same scene-id show each-other's positions with line indicators. By default this is derived from `id(volume)`. @@ -137,6 +139,7 @@ def __init__( origin=None, axis=0, reverse_y=True, + clim=None, scene_id=None, color=None, thumbnail=True, @@ -161,6 +164,14 @@ def __init__( self._axis = int(axis) self._reverse_y = bool(reverse_y) + # Check and store contrast limits + if clim is None: + self._initial_clim = self._volume.min(), self._volume.max() + elif isinstance(clim, (tuple, list)) and len(clim) == 2: + self._initial_clim = float(clim[0]), float(clim[1]) + else: + raise ValueError("The clim must be None or a 2-tuple of floats.") + # Check and store thumbnail if not (isinstance(thumbnail, (int, bool))): raise ValueError("thumbnail must be a boolean or an integer.") @@ -271,6 +282,14 @@ def state(self): """ return self._state + @property + def clim(self): + """A `dcc.Store` representing the contrast limits as a 2-element tuple. + This value should probably not be changed too often (e.g. on slider drag) + because the thumbnail data is recreated on each change. + """ + return self._clim + @property def extra_traces(self): """A `dcc.Store` that can be used as an output to define @@ -377,12 +396,18 @@ def _subid(self, name, use_dict=False, **kwargs): assert not kwargs return self._context_id + "-" + name - def _slice(self, index): + def _slice(self, index, clim): """Sample a slice from the volume.""" + # Sample from the volume indices = [slice(None), slice(None), slice(None)] indices[self._axis] = index - im = self._volume[tuple(indices)] - return (im.astype(np.float32) * (255 / im.max())).astype(np.uint8) + im = self._volume[tuple(indices)].astype(np.float32) + # Apply contrast limits + clim = min(clim), max(clim) + im = (im - clim[0]) * (255 / (clim[1] - clim[0])) + im[im < 0] = 0 + im[im > 255] = 255 + return im.astype(np.uint8) def _create_dash_components(self): """Create the graph, slider, figure, etc.""" @@ -391,17 +416,13 @@ def _create_dash_components(self): # Prep low-res slices. The get_thumbnail_size() is a bit like # a simulation to get the low-res size. if not self._thumbnail: - thumbnail_size = None + self._thumbnail_size_param = None info["thumbnail_size"] = info["size"] else: - thumbnail_size = self._thumbnail + self._thumbnail_size_param = self._thumbnail info["thumbnail_size"] = get_thumbnail_size( - info["size"][:2], thumbnail_size + info["size"][:2], self._thumbnail_size_param ) - thumbnails = [ - img_array_to_uri(self._slice(i), thumbnail_size) - for i in range(info["size"][2]) - ] # Create the figure object - can be accessed by user via slicer.graph.figure self._fig = fig = plotly.graph_objects.Figure(data=[]) @@ -451,8 +472,11 @@ def _create_dash_components(self): # A dict of static info for this slicer self._info = Store(id=self._subid("info"), data=info) + # A list of contrast limits + self._clim = Store(id=self._subid("clim"), data=self._initial_clim) + # A list of low-res slices, or the full-res data (encoded as base64-png) - self._thumbs_data = Store(id=self._subid("thumbs"), data=thumbnails) + self._thumbs_data = Store(id=self._subid("thumbs"), data=[]) # A list of mask slices (encoded as base64-png or null) self._overlay_data = Store(id=self._subid("overlay"), data=[]) @@ -482,6 +506,7 @@ def _create_dash_components(self): self._stores = [ self._info, + self._clim, self._thumbs_data, self._overlay_data, self._server_data, @@ -497,15 +522,26 @@ def _create_server_callbacks(self): """Create the callbacks that run server-side.""" app = self._app + @app.callback( + Output(self._thumbs_data.id, "data"), + [Input(self._clim.id, "data")], + ) + def upload_thumbnails(clim): + thumbnail_size = self._thumbnail_size_param + return [ + img_array_to_uri(self._slice(i, clim), thumbnail_size) + for i in range(self.nslices) + ] + @app.callback( Output(self._server_data.id, "data"), - [Input(self._state.id, "data")], + [Input(self._state.id, "data"), Input(self._clim.id, "data")], ) - def upload_requested_slice(state): + def upload_requested_slice(state, clim): if state is None or not state["index_changed"]: return dash.no_update index = state["index"] - slice = img_array_to_uri(self._slice(index)) + slice = img_array_to_uri(self._slice(index, clim)) return {"index": index, "slice": slice} def _create_client_callbacks(self): From af35d6fd892f045fcc7104b67589b4ab41e7c8cb Mon Sep 17 00:00:00 2001 From: Almar Klein Date: Tue, 15 Dec 2020 17:08:02 +0100 Subject: [PATCH 2/8] Add contrast example --- examples/contrast.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 examples/contrast.py diff --git a/examples/contrast.py b/examples/contrast.py new file mode 100644 index 0000000..780f5d2 --- /dev/null +++ b/examples/contrast.py @@ -0,0 +1,31 @@ +""" +A small example demonstrating contrast limits. +""" + +import dash +import dash_html_components as html +import dash_core_components as dcc +from dash.dependencies import Input, Output +from dash_slicer import VolumeSlicer +import imageio + + +app = dash.Dash(__name__, update_title=None) + +vol = imageio.volread("imageio:stent.npz") +slicer = VolumeSlicer(app, vol, clim=(0, 1000)) +clim_slider = dcc.RangeSlider( + id="clim-slider", min=vol.min(), max=vol.max(), value=(0, 1000) +) + +app.layout = html.Div([slicer.graph, slicer.slider, clim_slider, *slicer.stores]) + + +@app.callback(Output(slicer.clim.id, "data"), [Input("clim-slider", "value")]) +def update_clim(value): + return value + + +if __name__ == "__main__": + # Note: dev_tools_props_check negatively affects the performance of VolumeSlicer + app.run_server(debug=True, dev_tools_props_check=False) From a3332cc527aa5d055a1c1c3aa3933cb76f7c1fc7 Mon Sep 17 00:00:00 2001 From: Almar Klein Date: Tue, 15 Dec 2020 17:08:16 +0100 Subject: [PATCH 3/8] Update 3 other examples --- examples/slicer_with_3_views.py | 6 +++--- examples/threshold_contour.py | 2 +- examples/threshold_overlay.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/slicer_with_3_views.py b/examples/slicer_with_3_views.py index c5ff1f1..e5afd60 100644 --- a/examples/slicer_with_3_views.py +++ b/examples/slicer_with_3_views.py @@ -17,9 +17,9 @@ # Read volumes and create slicer objects vol = imageio.volread("imageio:stent.npz") -slicer0 = VolumeSlicer(app, vol, axis=0) -slicer1 = VolumeSlicer(app, vol, axis=1) -slicer2 = VolumeSlicer(app, vol, axis=2) +slicer0 = VolumeSlicer(app, vol, clim=(0, 800), axis=0) +slicer1 = VolumeSlicer(app, vol, clim=(0, 800), axis=1) +slicer2 = VolumeSlicer(app, vol, clim=(0, 800), axis=2) # Calculate isosurface and create a figure with a mesh object verts, faces, _, _ = marching_cubes(vol, 300, step_size=4) diff --git a/examples/threshold_contour.py b/examples/threshold_contour.py index ffda504..4aac28f 100644 --- a/examples/threshold_contour.py +++ b/examples/threshold_contour.py @@ -20,7 +20,7 @@ vol = imageio.volread("imageio:stent.npz") mi, ma = vol.min(), vol.max() -slicer = VolumeSlicer(app, vol) +slicer = VolumeSlicer(app, vol, clim=(0, 800)) app.layout = html.Div( diff --git a/examples/threshold_overlay.py b/examples/threshold_overlay.py index 5938950..fadad47 100644 --- a/examples/threshold_overlay.py +++ b/examples/threshold_overlay.py @@ -20,7 +20,7 @@ vol = imageio.volread("imageio:stent.npz") mi, ma = vol.min(), vol.max() -slicer = VolumeSlicer(app, vol) +slicer = VolumeSlicer(app, vol, clim=(0, 800)) app.layout = html.Div( From 11ca4ed5d29243e48ff750e4ba238c132bf38337 Mon Sep 17 00:00:00 2001 From: Almar Klein Date: Tue, 15 Dec 2020 17:08:22 +0100 Subject: [PATCH 4/8] update readme --- README.md | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index c142e23..d8b9490 100644 --- a/README.md +++ b/README.md @@ -77,7 +77,7 @@ the package. ### The VolumeSlicer class -**class `VolumeSlicer(app, volume, *, spacing=None, origin=None, axis=0, reverse_y=True, scene_id=None, color=None, thumbnail=True)`** +**class `VolumeSlicer(app, volume, *, spacing=None, origin=None, axis=0, reverse_y=True, clim=None, scene_id=None, color=None, thumbnail=True)`** A slicer object to show 3D image data in Dash. Upon instantiation one can provide the following parameters: @@ -95,6 +95,8 @@ instantiation one can provide the following parameters: the slice is in the top-left, rather than bottom-left. Default True. Note: setting this to False affects performance, see #12. This has been fixed, but the fix has not yet been released with Dash. +* `clim` (tuple of `float`): the (initial) contrast limits. Default the min + and max of the volume. * `scene_id` (`str`): the scene that this slicer is part of. Slicers that have the same scene-id show each-other's positions with line indicators. By default this is derived from `id(volume)`. @@ -118,6 +120,10 @@ color can be a list of such colors, defining a colormap. **property `VolumeSlicer.axis`** (`int`): The axis to slice. +**property `VolumeSlicer.clim`**: A `dcc.Store` representing the contrast limits as a 2-element tuple. +This value should probably not be changed too often (e.g. on slider drag) +because the thumbnail data is recreated on each change. + **property `VolumeSlicer.extra_traces`**: A `dcc.Store` that can be used as an output to define additional traces to be shown in this slicer. The data must be a list of dictionaries, with each dict representing a raw trace From d3ddcfbbf788e1e6ed351e4ba0259ecb33506ffd Mon Sep 17 00:00:00 2001 From: Almar Klein Date: Tue, 15 Dec 2020 17:13:01 +0100 Subject: [PATCH 5/8] add note to readme in generated code --- README.md | 2 ++ update_docs_in_readme.py | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index d8b9490..f7d6a0e 100644 --- a/README.md +++ b/README.md @@ -75,6 +75,8 @@ the package. ## Reference + + ### The VolumeSlicer class **class `VolumeSlicer(app, volume, *, spacing=None, origin=None, axis=0, reverse_y=True, clim=None, scene_id=None, color=None, thumbnail=True)`** diff --git a/update_docs_in_readme.py b/update_docs_in_readme.py index 0e8cdb9..8eb795e 100644 --- a/update_docs_in_readme.py +++ b/update_docs_in_readme.py @@ -14,6 +14,7 @@ def write_reference_docs(): """Write the reference docs to the README.""" # Prepare header = "## Reference" + note = "" filename = os.path.join(HERE, "README.md") assert os.path.isfile(filename), "README.md not found" # Load first part of the readme @@ -22,7 +23,7 @@ def write_reference_docs(): text1, _, _ = text.partition(header) text1 = text1.strip() # Create second part of the readme - text2 = "\n\n\n" + header + "\n\n" + get_reference_docs() + text2 = "\n\n\n" + header + "\n\n" + note + "\n\n" + get_reference_docs() if "\r" in text1: text2 = text2.replace("\n", "\r\n") # Wite From 38b4f33e2876a3ddddd8cdba3ed3dc36380b13d5 Mon Sep 17 00:00:00 2001 From: Almar Klein Date: Tue, 15 Dec 2020 20:45:03 +0100 Subject: [PATCH 6/8] tweak doc-build so test is more robust --- dash_slicer/docs.py | 3 +++ tests/test_docs.py | 4 ++-- update_docs_in_readme.py | 8 +++----- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/dash_slicer/docs.py b/dash_slicer/docs.py index 219993e..b5d73bb 100644 --- a/dash_slicer/docs.py +++ b/dash_slicer/docs.py @@ -6,6 +6,9 @@ import dash_slicer +md_seperator = "" # noqa + + def dedent(text): """Dedent a docstring, removing leading whitespace.""" lines = text.lstrip().splitlines() diff --git a/tests/test_docs.py b/tests/test_docs.py index 44c4eaf..05d91c5 100644 --- a/tests/test_docs.py +++ b/tests/test_docs.py @@ -1,6 +1,6 @@ import os -from dash_slicer.docs import get_reference_docs +from dash_slicer.docs import get_reference_docs, md_seperator HERE = os.path.dirname(os.path.abspath(__file__)) @@ -18,7 +18,7 @@ def test_that_reference_docs_in_readme_are_up_to_date(): assert os.path.isfile(filename) with open(filename, "rb") as f: text = f.read().decode() - _, _, ref = text.partition("## Reference") + _, _, ref = text.partition(md_seperator) ref1 = ref.strip().replace("\r\n", "\n") ref2 = get_reference_docs().strip() assert ( diff --git a/update_docs_in_readme.py b/update_docs_in_readme.py index 8eb795e..5c0460b 100644 --- a/update_docs_in_readme.py +++ b/update_docs_in_readme.py @@ -4,7 +4,7 @@ """ import os -from dash_slicer.docs import get_reference_docs +from dash_slicer.docs import get_reference_docs, md_seperator HERE = os.path.dirname(os.path.abspath(__file__)) @@ -13,17 +13,15 @@ def write_reference_docs(): """Write the reference docs to the README.""" # Prepare - header = "## Reference" - note = "" filename = os.path.join(HERE, "README.md") assert os.path.isfile(filename), "README.md not found" # Load first part of the readme with open(filename, "rb") as f: text = f.read().decode() - text1, _, _ = text.partition(header) + text1, _, _ = text.partition(md_seperator) text1 = text1.strip() # Create second part of the readme - text2 = "\n\n\n" + header + "\n\n" + note + "\n\n" + get_reference_docs() + text2 = "\n\n" + md_seperator + "\n\n" + get_reference_docs() if "\r" in text1: text2 = text2.replace("\n", "\r\n") # Wite From 4640f0765c68373114c30fb1ad12e75ceb267857 Mon Sep 17 00:00:00 2001 From: Almar Klein Date: Thu, 17 Dec 2020 10:38:28 +0100 Subject: [PATCH 7/8] fix handling of thumbnails, and a few tweaks --- dash_slicer/slicer.py | 50 ++++++++++++++++----------------- examples/slicer_with_3_views.py | 6 ++-- 2 files changed, 29 insertions(+), 27 deletions(-) diff --git a/dash_slicer/slicer.py b/dash_slicer/slicer.py index d23726d..2effa68 100644 --- a/dash_slicer/slicer.py +++ b/dash_slicer/slicer.py @@ -176,17 +176,17 @@ def __init__( if not (isinstance(thumbnail, (int, bool))): raise ValueError("thumbnail must be a boolean or an integer.") if thumbnail is False: - self._thumbnail = False + self._thumbnail_param = None elif thumbnail is None or thumbnail is True: - self._thumbnail = 32 # default size + self._thumbnail_param = 32 # default size else: thumbnail = int(thumbnail) if thumbnail >= np.max(volume.shape[:3]): - self._thumbnail = False # dont go larger than image size + self._thumbnail_param = None # dont go larger than image size elif thumbnail <= 0: - self._thumbnail = False # consider 0 and -1 the same as False + self._thumbnail_param = None # consider 0 and -1 the same as False else: - self._thumbnail = thumbnail + self._thumbnail_param = thumbnail # Check and store scene id, and generate if scene_id is None: @@ -218,8 +218,7 @@ def __init__( # Build the slicer self._create_dash_components() - if thumbnail: - self._create_server_callbacks() + self._create_server_callbacks() self._create_client_callbacks() # Note(AK): we could make some stores public, but let's do this only when actual use-cases arise? @@ -415,13 +414,11 @@ def _create_dash_components(self): # Prep low-res slices. The get_thumbnail_size() is a bit like # a simulation to get the low-res size. - if not self._thumbnail: - self._thumbnail_size_param = None + if self._thumbnail_param is None: info["thumbnail_size"] = info["size"] else: - self._thumbnail_size_param = self._thumbnail info["thumbnail_size"] = get_thumbnail_size( - info["size"][:2], self._thumbnail_size_param + info["size"][:2], self._thumbnail_param ) # Create the figure object - can be accessed by user via slicer.graph.figure @@ -527,22 +524,25 @@ def _create_server_callbacks(self): [Input(self._clim.id, "data")], ) def upload_thumbnails(clim): - thumbnail_size = self._thumbnail_size_param return [ - img_array_to_uri(self._slice(i, clim), thumbnail_size) + img_array_to_uri(self._slice(i, clim), self._thumbnail_param) for i in range(self.nslices) ] - @app.callback( - Output(self._server_data.id, "data"), - [Input(self._state.id, "data"), Input(self._clim.id, "data")], - ) - def upload_requested_slice(state, clim): - if state is None or not state["index_changed"]: - return dash.no_update - index = state["index"] - slice = img_array_to_uri(self._slice(index, clim)) - return {"index": index, "slice": slice} + if self._thumbnail_param is not None: + # The callback to push full-res slices to the client is only needed + # if the thumbnails are not already full-res. + + @app.callback( + Output(self._server_data.id, "data"), + [Input(self._state.id, "data"), Input(self._clim.id, "data")], + ) + def upload_requested_slice(state, clim): + if state is None or not state["index_changed"]: + return dash.no_update + index = state["index"] + slice = img_array_to_uri(self._slice(index, clim)) + return {"index": index, "slice": slice} def _create_client_callbacks(self): """Create the callbacks that run client-side.""" @@ -750,7 +750,7 @@ def _create_client_callbacks(self): State(self._info.id, "data"), State(self._graph.id, "figure"), ], - # prevent_initial_call=True, + prevent_initial_call=True, ) # ---------------------------------------------------------------------- @@ -806,9 +806,9 @@ def _create_client_callbacks(self): Input(self._slider.id, "value"), Input(self._server_data.id, "data"), Input(self._overlay_data.id, "data"), + Input(self._thumbs_data.id, "data"), ], [ - State(self._thumbs_data.id, "data"), State(self._info.id, "data"), State(self._img_traces.id, "data"), ], diff --git a/examples/slicer_with_3_views.py b/examples/slicer_with_3_views.py index e5afd60..c174069 100644 --- a/examples/slicer_with_3_views.py +++ b/examples/slicer_with_3_views.py @@ -34,7 +34,7 @@ app.layout = html.Div( style={ "display": "grid", - "gridTemplateColumns": "40% 40%", + "gridTemplateColumns": "50% 50%", }, children=[ html.Div( @@ -92,7 +92,9 @@ let s = { type: 'scatter3d', x: xyz[0], y: xyz[1], z: xyz[2], - mode: 'lines', line: {color: state.color} + mode: 'lines', line: {color: state.color}, + hoverinfo: 'skip', + showlegend: false, }; traces.push(s); } From b8f1e1114bcb21f1eebeb0678def537cbeb22540 Mon Sep 17 00:00:00 2001 From: Almar Klein Date: Thu, 17 Dec 2020 10:44:08 +0100 Subject: [PATCH 8/8] add tests for clim --- tests/test_slicer.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/tests/test_slicer.py b/tests/test_slicer.py index c857b74..82b5faf 100644 --- a/tests/test_slicer.py +++ b/tests/test_slicer.py @@ -33,12 +33,14 @@ def test_slicer_init(): assert isinstance(s.slider, dcc.Slider) assert isinstance(s.stores, list) assert all(isinstance(store, (dcc.Store, dcc.Interval)) for store in s.stores) + for store in [s.clim, s.state, s.extra_traces, s.overlay_data]: + assert isinstance(store, dcc.Store) def test_slicer_thumbnail(): - app = dash.Dash() vol = np.random.uniform(0, 255, (100, 100, 100)).astype(np.uint8) + app = dash.Dash() _ = VolumeSlicer(app, vol) # Test for name pattern of server-side callback when thumbnails are used assert any(["server-data.data" in key for key in app.callback_map]) @@ -49,6 +51,21 @@ def test_slicer_thumbnail(): assert not any(["server-data.data" in key for key in app.callback_map]) +def test_clim(): + app = dash.Dash() + vol = np.random.uniform(0, 255, (10, 10, 10)).astype(np.uint8) + mi, ma = vol.min(), vol.max() + + s = VolumeSlicer(app, vol) + assert s._initial_clim == (mi, ma) + + s = VolumeSlicer(app, vol, clim=None) + assert s._initial_clim == (mi, ma) + + s = VolumeSlicer(app, vol, clim=(10, 12)) + assert s._initial_clim == (10, 12) + + def test_scene_id_and_context_id(): app = dash.Dash()