From 3e0006a06a82b39fe35480722efb55ce0d965edb Mon Sep 17 00:00:00 2001 From: Almar Klein Date: Tue, 27 Oct 2020 15:56:59 +0100 Subject: [PATCH 01/11] first code comitted --- .gitignore | 5 + README.md | 5 +- dash_3d_viewer/__init__.py | 10 ++ dash_3d_viewer/slicer.py | 219 +++++++++++++++++++++++++++++++++ dash_3d_viewer/utils.py | 0 examples/slicer_with_1_view.py | 31 +++++ requirements.txt | 5 + 7 files changed, 274 insertions(+), 1 deletion(-) create mode 100644 .gitignore create mode 100644 dash_3d_viewer/__init__.py create mode 100644 dash_3d_viewer/slicer.py create mode 100644 dash_3d_viewer/utils.py create mode 100644 examples/slicer_with_1_view.py create mode 100644 requirements.txt diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..022d29b --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +__pycache__ +*.pyc +*.pyo +dist/ +build/ diff --git a/README.md b/README.md index 3b4dbc8..e62a39d 100644 --- a/README.md +++ b/README.md @@ -1 +1,4 @@ -# dash-3d-viewer \ No newline at end of file +# dash-3d-viewer + +A tool to make it easy to build slice-views on 3D image data. + diff --git a/dash_3d_viewer/__init__.py b/dash_3d_viewer/__init__.py new file mode 100644 index 0000000..e41ae10 --- /dev/null +++ b/dash_3d_viewer/__init__.py @@ -0,0 +1,10 @@ +""" +Dash 3d viewer - a tool to make it easy to build slice-views on 3D image data. +""" + + +from .slicer import DashVolumeSlicer + + +__version__ = "0.0.1" +version_info = tuple(map(int, __version__.split("."))) diff --git a/dash_3d_viewer/slicer.py b/dash_3d_viewer/slicer.py new file mode 100644 index 0000000..9af64f2 --- /dev/null +++ b/dash_3d_viewer/slicer.py @@ -0,0 +1,219 @@ +import base64 + +import skimage +import PIL.Image +import numpy as np +import plotly.graph_objects as go +from plotly.utils import ImageUriValidator +import dash +from dash.dependencies import Input, Output, State +import dash_core_components as dcc + +# todo: id's defined here must be made unique +# todo: anisotropy +# todo: clim +# todo: maybe ... a plane instead of an axis? +# todo: callbacks are now defined before the layout, which is not supposed to work? +# todo: request neighbouring slices too +# todo: remove slices from the cache if the cache becomes too big +# todo: should we put "slicer" in the name to make clear this tool applies to image data? + + +# %%%%% From plot_common + + +def dummy_fig(): + fig = go.Figure(go.Scatter(x=[], y=[])) # todo: why a scatter plot here? + fig.update_layout(template=None) + fig.update_xaxes(showgrid=False, showticklabels=False, zeroline=False) + fig.update_yaxes( + showgrid=False, scaleanchor="x", showticklabels=False, zeroline=False + ) + return fig + + +def img_array_to_pil_image(ia): + ia = skimage.util.img_as_ubyte(ia) + img = PIL.Image.fromarray(ia) + return img + + +def pil_image_to_uri(img): + return ImageUriValidator.pil_image_to_uri(img) + + +def img_array_to_uri(img_array): + imgf = img_array_to_pil_image(img_array) + uri = pil_image_to_uri(imgf) + return uri + + +# %%%%%% Utils + + +# %%%%% + + +class DashVolumeSlicer: + """A slicer to show 3D image data in Dash.""" + + def __init__(self, app, volume, axis=0): + + assert isinstance(app, dash.Dash) + if not isinstance(volume, np.ndarray) and image.ndim == 3: + raise TypeError("DashVolumeSlicer expects a 3D numpy array") + + self._id = "thereisonlyoneslicerfornow" + self._volume = volume + self._axis = int(axis) + self._max_slice = self._volume.shape[self._axis] + assert 0 <= self._axis <= 2 + + slice_shape = list(volume.shape) + slice_shape.pop(self._axis) + + # Create the figure object + fig = dummy_fig() + # Add an empty layout image that we can populate from JS. + fig.add_layout_image( + dict( + source="", + xref="x", + yref="y", + x=0, + y=0, + sizex=slice_shape[0], + sizey=slice_shape[1], + sizing="contain", + layer="below", + ) + ) + fig.update_xaxes( + showgrid=False, + range=(0, slice_shape[0]), + showticklabels=False, + zeroline=False, + ) + fig.update_yaxes( + showgrid=False, + scaleanchor="x", + range=(slice_shape[1], 0), + showticklabels=False, + zeroline=False, + ) + fig.update_layout( + { + "margin": dict(l=0, r=0, b=0, t=0, pad=4), + } + ) + + self.graph = dcc.Graph( + id="graph", + figure=fig, + config={"scrollZoom": True}, + ) + + self.slider = dcc.Slider( + id="slider", + min=0, + max=self._max_slice - 1, + step=1, + value=self._max_slice // 2, + updatemode="drag", + ) + + self.stores = [ + dcc.Store(id="slice-index", data=volume.shape[self._axis] // 2), + dcc.Store(id="_requested-slice-index", data=0), + dcc.Store(id="_slice-data", data=""), + ] + + self._create_server_handlers(app) + self._create_client_handlers(app) + + def _slice(self, index): + indices = [slice(None), slice(None), slice(None)] + indices[self._axis] = index + return self._volume[tuple(indices)] + + def _create_server_handlers(self, app): + @app.callback( + Output("_slice-data", "data"), + [Input("_requested-slice-index", "data")], + ) + def upload_requested_slice(slice_index): + slice = self._slice(slice_index) + slice = (slice.astype(np.float32) / 4).astype(np.uint8) + return [slice_index, img_array_to_uri(slice)] + + def _create_client_handlers(self, app): + + app.clientside_callback( + """ + function handle_slider_move(index) { + return index; + } + """, + Output("slice-index", "data"), + [Input("slider", "value")], + ) + + app.clientside_callback( + """ + function handle_slice_index(index) { + if (!window.slicecache_for_{{ID}}) { window.slicecache_for_{{ID}} = {}; } + let slice_cache = window.slicecache_for_{{ID}}; + if (slice_cache[index]) { + return window.dash_clientside.no_update; + } else { + console.log('requesting slice ' + index) + return index; + } + } + """.replace( + "{{ID}}", self._id + ), + Output("_requested-slice-index", "data"), + [Input("slice-index", "data")], + ) + + # app.clientside_callback(""" + # function update_slider_pos(index) { + # return index; + # } + # """, + # [Output("slice-index", "data")], + # [State("slider", "value")], + # ) + + app.clientside_callback( + """ + function handle_incoming_slice(index, index_and_data, ori_figure) { + let new_index = index_and_data[0]; + let new_data = index_and_data[1]; + // Store data in cache + if (!window.slicecache_for_{{ID}}) { window.slicecache_for_{{ID}} = {}; } + let slice_cache = window.slicecache_for_{{ID}}; + slice_cache[new_index] = new_data; + // Get the data we need *now* + let data = slice_cache[index]; + // Maybe we do not need an update + if (!data) { + return window.dash_clientside.no_update; + } + if (data == ori_figure.layout.images[0].source) { + return window.dash_clientside.no_update; + } + // Otherwise, perform update + console.log("updating figure"); + let figure = {...ori_figure}; + figure.layout.images[0].source = data; + return figure; + } + """.replace( + "{{ID}}", self._id + ), + Output("graph", "figure"), + [Input("slice-index", "data"), Input("_slice-data", "data")], + [State("graph", "figure")], + ) diff --git a/dash_3d_viewer/utils.py b/dash_3d_viewer/utils.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/slicer_with_1_view.py b/examples/slicer_with_1_view.py new file mode 100644 index 0000000..166c5fa --- /dev/null +++ b/examples/slicer_with_1_view.py @@ -0,0 +1,31 @@ +import dash +import dash_core_components as dcc +import dash_html_components as html +from dash.dependencies import Input, Output + +import imageio +from dash_3d_viewer import DashVolumeSlicer + + +app = dash.Dash(__name__) + + +vol = imageio.volread("imageio:stent.npz") +slicer = DashVolumeSlicer(app, vol) + + +app.layout = html.Div( + [html.H6("Blabla"), slicer.graph, html.Br(), slicer.slider, *slicer.stores] +) + + +# @app.callback( +# Output('my-output', 'children'), +# [Input('my-input', 'value')] +# ) +# def update_output_div(input_value): +# return 'Output bla: {}'.format(input_value) + + +if __name__ == "__main__": + app.run_server(debug=True) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..2e416aa --- /dev/null +++ b/requirements.txt @@ -0,0 +1,5 @@ +pillow +numpy +plotly +dash +dash_core_components From dcb59f7523ca2417ef6962b031f15abc74fc7d5e Mon Sep 17 00:00:00 2001 From: Almar Klein Date: Tue, 27 Oct 2020 16:16:13 +0100 Subject: [PATCH 02/11] flakeify --- README.md | 11 +++++++++++ dash_3d_viewer/__init__.py | 2 +- dash_3d_viewer/slicer.py | 6 ++---- examples/slicer_with_1_view.py | 2 -- setup.cfg | 4 ++++ 5 files changed, 18 insertions(+), 7 deletions(-) create mode 100644 setup.cfg diff --git a/README.md b/README.md index e62a39d..7902d49 100644 --- a/README.md +++ b/README.md @@ -2,3 +2,14 @@ A tool to make it easy to build slice-views on 3D image data. + +## Installation + +Eventually, this would be pip-installable. For now, use the developer workflow. + + +## Developers + +To run the examples: +with an env that has the appropriate requirements, from the repo's root directory, run `python example/xxxx.py`. + diff --git a/dash_3d_viewer/__init__.py b/dash_3d_viewer/__init__.py index e41ae10..1f56b2a 100644 --- a/dash_3d_viewer/__init__.py +++ b/dash_3d_viewer/__init__.py @@ -3,7 +3,7 @@ """ -from .slicer import DashVolumeSlicer +from .slicer import DashVolumeSlicer # noqa: F401 __version__ = "0.0.1" diff --git a/dash_3d_viewer/slicer.py b/dash_3d_viewer/slicer.py index 9af64f2..2884c1f 100644 --- a/dash_3d_viewer/slicer.py +++ b/dash_3d_viewer/slicer.py @@ -1,7 +1,5 @@ -import base64 - -import skimage import PIL.Image +import skimage import numpy as np import plotly.graph_objects as go from plotly.utils import ImageUriValidator @@ -60,7 +58,7 @@ class DashVolumeSlicer: def __init__(self, app, volume, axis=0): assert isinstance(app, dash.Dash) - if not isinstance(volume, np.ndarray) and image.ndim == 3: + if not (isinstance(volume, np.ndarray) and volume.ndim == 3): raise TypeError("DashVolumeSlicer expects a 3D numpy array") self._id = "thereisonlyoneslicerfornow" diff --git a/examples/slicer_with_1_view.py b/examples/slicer_with_1_view.py index 166c5fa..f64134e 100644 --- a/examples/slicer_with_1_view.py +++ b/examples/slicer_with_1_view.py @@ -1,7 +1,5 @@ import dash -import dash_core_components as dcc import dash_html_components as html -from dash.dependencies import Input, Output import imageio from dash_3d_viewer import DashVolumeSlicer diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..39e30a6 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,4 @@ +[flake8] +max_line_length = 89 +extend-ignore = E501 +exclude = build,dist,*.egg-info From 789d7f55eac5c58b970e3547ef0ae2fcd8232ba5 Mon Sep 17 00:00:00 2001 From: Almar Klein Date: Tue, 27 Oct 2020 16:45:00 +0100 Subject: [PATCH 03/11] some cleaning and tweaking --- dash_3d_viewer/slicer.py | 74 ++++++++-------------------------- dash_3d_viewer/utils.py | 10 +++++ examples/slicer_with_1_view.py | 21 +++------- 3 files changed, 33 insertions(+), 72 deletions(-) diff --git a/dash_3d_viewer/slicer.py b/dash_3d_viewer/slicer.py index 2884c1f..6cbe119 100644 --- a/dash_3d_viewer/slicer.py +++ b/dash_3d_viewer/slicer.py @@ -1,55 +1,10 @@ -import PIL.Image -import skimage import numpy as np -import plotly.graph_objects as go -from plotly.utils import ImageUriValidator -import dash +from plotly.graph_objects import Figure +from dash import Dash from dash.dependencies import Input, Output, State -import dash_core_components as dcc +from dash_core_components import Graph, Slider, Store -# todo: id's defined here must be made unique -# todo: anisotropy -# todo: clim -# todo: maybe ... a plane instead of an axis? -# todo: callbacks are now defined before the layout, which is not supposed to work? -# todo: request neighbouring slices too -# todo: remove slices from the cache if the cache becomes too big -# todo: should we put "slicer" in the name to make clear this tool applies to image data? - - -# %%%%% From plot_common - - -def dummy_fig(): - fig = go.Figure(go.Scatter(x=[], y=[])) # todo: why a scatter plot here? - fig.update_layout(template=None) - fig.update_xaxes(showgrid=False, showticklabels=False, zeroline=False) - fig.update_yaxes( - showgrid=False, scaleanchor="x", showticklabels=False, zeroline=False - ) - return fig - - -def img_array_to_pil_image(ia): - ia = skimage.util.img_as_ubyte(ia) - img = PIL.Image.fromarray(ia) - return img - - -def pil_image_to_uri(img): - return ImageUriValidator.pil_image_to_uri(img) - - -def img_array_to_uri(img_array): - imgf = img_array_to_pil_image(img_array) - uri = pil_image_to_uri(imgf) - return uri - - -# %%%%%% Utils - - -# %%%%% +from .utils import img_array_to_uri class DashVolumeSlicer: @@ -57,7 +12,7 @@ class DashVolumeSlicer: def __init__(self, app, volume, axis=0): - assert isinstance(app, dash.Dash) + assert isinstance(app, Dash) if not (isinstance(volume, np.ndarray) and volume.ndim == 3): raise TypeError("DashVolumeSlicer expects a 3D numpy array") @@ -71,7 +26,12 @@ def __init__(self, app, volume, axis=0): slice_shape.pop(self._axis) # Create the figure object - fig = dummy_fig() + fig = Figure() + fig.update_layout(template=None) + fig.update_xaxes(showgrid=False, showticklabels=False, zeroline=False) + fig.update_yaxes( + showgrid=False, scaleanchor="x", showticklabels=False, zeroline=False + ) # Add an empty layout image that we can populate from JS. fig.add_layout_image( dict( @@ -105,13 +65,13 @@ def __init__(self, app, volume, axis=0): } ) - self.graph = dcc.Graph( + self.graph = Graph( id="graph", figure=fig, config={"scrollZoom": True}, ) - self.slider = dcc.Slider( + self.slider = Slider( id="slider", min=0, max=self._max_slice - 1, @@ -121,9 +81,9 @@ def __init__(self, app, volume, axis=0): ) self.stores = [ - dcc.Store(id="slice-index", data=volume.shape[self._axis] // 2), - dcc.Store(id="_requested-slice-index", data=0), - dcc.Store(id="_slice-data", data=""), + Store(id="slice-index", data=volume.shape[self._axis] // 2), + Store(id="_requested-slice-index", data=0), + Store(id="_slice-data", data=""), ] self._create_server_handlers(app) @@ -141,7 +101,7 @@ def _create_server_handlers(self, app): ) def upload_requested_slice(slice_index): slice = self._slice(slice_index) - slice = (slice.astype(np.float32) / 4).astype(np.uint8) + slice = (slice.astype(np.float32) * (255 / slice.max())).astype(np.uint8) return [slice_index, img_array_to_uri(slice)] def _create_client_handlers(self, app): diff --git a/dash_3d_viewer/utils.py b/dash_3d_viewer/utils.py index e69de29..7abd195 100644 --- a/dash_3d_viewer/utils.py +++ b/dash_3d_viewer/utils.py @@ -0,0 +1,10 @@ +import PIL.Image +import skimage +from plotly.utils import ImageUriValidator + + +def img_array_to_uri(img_array): + img_array = skimage.util.img_as_ubyte(img_array) + img_pil = PIL.Image.fromarray(img_array) + uri = ImageUriValidator.pil_image_to_uri(img_pil) + return uri diff --git a/examples/slicer_with_1_view.py b/examples/slicer_with_1_view.py index f64134e..e79824e 100644 --- a/examples/slicer_with_1_view.py +++ b/examples/slicer_with_1_view.py @@ -1,28 +1,19 @@ +""" +A truly minimal example. +""" + import dash import dash_html_components as html - -import imageio from dash_3d_viewer import DashVolumeSlicer +import imageio app = dash.Dash(__name__) - vol = imageio.volread("imageio:stent.npz") slicer = DashVolumeSlicer(app, vol) - -app.layout = html.Div( - [html.H6("Blabla"), slicer.graph, html.Br(), slicer.slider, *slicer.stores] -) - - -# @app.callback( -# Output('my-output', 'children'), -# [Input('my-input', 'value')] -# ) -# def update_output_div(input_value): -# return 'Output bla: {}'.format(input_value) +app.layout = html.Div([slicer.graph, slicer.slider, *slicer.stores]) if __name__ == "__main__": From fb3dfa908a883c67be9ccc8bd979924b6ad5d744 Mon Sep 17 00:00:00 2001 From: Almar Klein Date: Wed, 28 Oct 2020 16:08:11 +0100 Subject: [PATCH 04/11] Implement ids, to support multiple slicers. Add examples --- dash_3d_viewer/slicer.py | 100 +++++++++++++++++--------------- dash_3d_viewer/utils.py | 6 ++ examples/slicer_with_1_view.py | 2 +- examples/slicer_with_2_views.py | 46 +++++++++++++++ examples/slicer_with_3_views.py | 58 ++++++++++++++++++ 5 files changed, 165 insertions(+), 47 deletions(-) create mode 100644 examples/slicer_with_2_views.py create mode 100644 examples/slicer_with_3_views.py diff --git a/dash_3d_viewer/slicer.py b/dash_3d_viewer/slicer.py index 6cbe119..2123ead 100644 --- a/dash_3d_viewer/slicer.py +++ b/dash_3d_viewer/slicer.py @@ -4,33 +4,53 @@ from dash.dependencies import Input, Output, State from dash_core_components import Graph, Slider, Store -from .utils import img_array_to_uri +from .utils import gen_random_id, img_array_to_uri class DashVolumeSlicer: """A slicer to show 3D image data in Dash.""" - def __init__(self, app, volume, axis=0): + def __init__(self, app, volume, axis=0, id=None): assert isinstance(app, Dash) + if not (isinstance(volume, np.ndarray) and volume.ndim == 3): raise TypeError("DashVolumeSlicer expects a 3D numpy array") - - self._id = "thereisonlyoneslicerfornow" self._volume = volume + + if id is None: + id = gen_random_id() + elif not isinstance(id, str): + raise TypeError("Id must be a string") + self._id = id + self._axis = int(axis) self._max_slice = self._volume.shape[self._axis] assert 0 <= self._axis <= 2 - slice_shape = list(volume.shape) - slice_shape.pop(self._axis) + # Get the slice size (width, height) + arr_shape = list(volume.shape) + arr_shape.pop(self._axis) + slice_size = list(reversed(arr_shape)) # Create the figure object fig = Figure() - fig.update_layout(template=None) - fig.update_xaxes(showgrid=False, showticklabels=False, zeroline=False) + fig.update_layout( + template=None, + margin=dict(l=0, r=0, b=0, t=0, pad=4), + ) + fig.update_xaxes( + showgrid=False, + range=(0, slice_size[0]), + showticklabels=False, + zeroline=False, + ) fig.update_yaxes( - showgrid=False, scaleanchor="x", showticklabels=False, zeroline=False + showgrid=False, + scaleanchor="x", + range=(slice_size[1], 0), # todo: allow flipping x or y + showticklabels=False, + zeroline=False, ) # Add an empty layout image that we can populate from JS. fig.add_layout_image( @@ -40,55 +60,40 @@ def __init__(self, app, volume, axis=0): yref="y", x=0, y=0, - sizex=slice_shape[0], - sizey=slice_shape[1], + sizex=slice_size[0], + sizey=slice_size[1], sizing="contain", layer="below", ) ) - fig.update_xaxes( - showgrid=False, - range=(0, slice_shape[0]), - showticklabels=False, - zeroline=False, - ) - fig.update_yaxes( - showgrid=False, - scaleanchor="x", - range=(slice_shape[1], 0), - showticklabels=False, - zeroline=False, - ) - fig.update_layout( - { - "margin": dict(l=0, r=0, b=0, t=0, pad=4), - } - ) - + # Wrap the figure in a graph self.graph = Graph( - id="graph", + id=self._subid("graph"), figure=fig, config={"scrollZoom": True}, ) - + # Create a slider object that the user can put in the layout (or not) self.slider = Slider( - id="slider", + id=self._subid("slider"), min=0, max=self._max_slice - 1, step=1, value=self._max_slice // 2, updatemode="drag", ) - + # Create the stores that we need (these must be present in the layout) self.stores = [ - Store(id="slice-index", data=volume.shape[self._axis] // 2), - Store(id="_requested-slice-index", data=0), - Store(id="_slice-data", data=""), + Store(id=self._subid("slice-index"), data=volume.shape[self._axis] // 2), + Store(id=self._subid("_requested-slice-index"), data=0), + Store(id=self._subid("_slice-data"), data=""), ] self._create_server_handlers(app) self._create_client_handlers(app) + def _subid(self, subid): + return self._id + "-" + subid + def _slice(self, index): indices = [slice(None), slice(None), slice(None)] indices[self._axis] = index @@ -96,8 +101,8 @@ def _slice(self, index): def _create_server_handlers(self, app): @app.callback( - Output("_slice-data", "data"), - [Input("_requested-slice-index", "data")], + Output(self._subid("_slice-data"), "data"), + [Input(self._subid("_requested-slice-index"), "data")], ) def upload_requested_slice(slice_index): slice = self._slice(slice_index) @@ -112,8 +117,8 @@ def _create_client_handlers(self, app): return index; } """, - Output("slice-index", "data"), - [Input("slider", "value")], + Output(self._subid("slice-index"), "data"), + [Input(self._subid("slider"), "value")], ) app.clientside_callback( @@ -131,8 +136,8 @@ def _create_client_handlers(self, app): """.replace( "{{ID}}", self._id ), - Output("_requested-slice-index", "data"), - [Input("slice-index", "data")], + Output(self._subid("_requested-slice-index"), "data"), + [Input(self._subid("slice-index"), "data")], ) # app.clientside_callback(""" @@ -171,7 +176,10 @@ def _create_client_handlers(self, app): """.replace( "{{ID}}", self._id ), - Output("graph", "figure"), - [Input("slice-index", "data"), Input("_slice-data", "data")], - [State("graph", "figure")], + Output(self._subid("graph"), "figure"), + [ + Input(self._subid("slice-index"), "data"), + Input(self._subid("_slice-data"), "data"), + ], + [State(self._subid("graph"), "figure")], ) diff --git a/dash_3d_viewer/utils.py b/dash_3d_viewer/utils.py index 7abd195..5dd2e8b 100644 --- a/dash_3d_viewer/utils.py +++ b/dash_3d_viewer/utils.py @@ -1,8 +1,14 @@ +import random + import PIL.Image import skimage from plotly.utils import ImageUriValidator +def gen_random_id(n=6): + return "".join(random.choice("abcdefghijklmnopqrtsuvwxyz") for i in range(n)) + + def img_array_to_uri(img_array): img_array = skimage.util.img_as_ubyte(img_array) img_pil = PIL.Image.fromarray(img_array) diff --git a/examples/slicer_with_1_view.py b/examples/slicer_with_1_view.py index e79824e..61ac309 100644 --- a/examples/slicer_with_1_view.py +++ b/examples/slicer_with_1_view.py @@ -17,4 +17,4 @@ if __name__ == "__main__": - app.run_server(debug=True) + app.run_server(debug=False) diff --git a/examples/slicer_with_2_views.py b/examples/slicer_with_2_views.py new file mode 100644 index 0000000..7913e2d --- /dev/null +++ b/examples/slicer_with_2_views.py @@ -0,0 +1,46 @@ +""" +An example with two slicers on the same volume. +""" + +import dash +import dash_html_components as html +from dash_3d_viewer import DashVolumeSlicer +import imageio + + +app = dash.Dash(__name__) + +vol = imageio.volread("imageio:stent.npz") +slicer1 = DashVolumeSlicer(app, vol, axis=1, id="slicer1") +slicer2 = DashVolumeSlicer(app, vol, axis=2, id="slicer2") + +app.layout = html.Div( + style={ + "display": "grid", + "grid-template-columns": "40% 40%", + }, + children=[ + html.Div( + [ + html.H1("Coronal"), + slicer1.graph, + html.Br(), + slicer1.slider, + *slicer1.stores, + ] + ), + html.Div( + [ + html.H1("Sagittal"), + slicer2.graph, + html.Br(), + slicer2.slider, + *slicer2.stores, + ] + ), + ], +) + + +if __name__ == "__main__": + app.run_server(debug=True) diff --git a/examples/slicer_with_3_views.py b/examples/slicer_with_3_views.py new file mode 100644 index 0000000..7c4a799 --- /dev/null +++ b/examples/slicer_with_3_views.py @@ -0,0 +1,58 @@ +""" +An example creating three slice-views through a volume, as is common +in medical applications. In the fourth quadrant you could place an isosurface mesh. +""" + +import dash +import dash_html_components as html +from dash_3d_viewer import DashVolumeSlicer +import imageio + + +app = dash.Dash(__name__) + +vol = imageio.volread("imageio:stent.npz") +slicer1 = DashVolumeSlicer(app, vol, axis=0, id="slicer1") +slicer2 = DashVolumeSlicer(app, vol, axis=1, id="slicer2") +slicer3 = DashVolumeSlicer(app, vol, axis=2, id="slicer3") + + +app.layout = html.Div( + style={ + "display": "grid", + "grid-template-columns": "40% 40%", + }, + children=[ + html.Div( + [ + html.Center(html.H1("Transversal")), + slicer1.graph, + html.Br(), + slicer1.slider, + *slicer1.stores, + ] + ), + html.Div( + [ + html.Center(html.H1("Coronal")), + slicer2.graph, + html.Br(), + slicer2.slider, + *slicer2.stores, + ] + ), + html.Div( + [ + html.Center(html.H1("Sagittal")), + slicer3.graph, + html.Br(), + slicer3.slider, + *slicer3.stores, + ] + ), + ], +) + + +if __name__ == "__main__": + app.run_server(debug=False) From cc13851605727282e44598f43d399592bc7914b8 Mon Sep 17 00:00:00 2001 From: Almar Klein Date: Wed, 28 Oct 2020 16:15:14 +0100 Subject: [PATCH 05/11] add 3d view to example --- dash_3d_viewer/slicer.py | 1 + examples/slicer_with_3_views.py | 12 +++++++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/dash_3d_viewer/slicer.py b/dash_3d_viewer/slicer.py index 2123ead..59eb4e7 100644 --- a/dash_3d_viewer/slicer.py +++ b/dash_3d_viewer/slicer.py @@ -67,6 +67,7 @@ def __init__(self, app, volume, axis=0, id=None): ) ) # Wrap the figure in a graph + # todo: or should the user provide this? self.graph = Graph( id=self._subid("graph"), figure=fig, diff --git a/examples/slicer_with_3_views.py b/examples/slicer_with_3_views.py index 7c4a799..afd7731 100644 --- a/examples/slicer_with_3_views.py +++ b/examples/slicer_with_3_views.py @@ -3,12 +3,14 @@ in medical applications. In the fourth quadrant you could place an isosurface mesh. """ +import plotly.graph_objects as go import dash import dash_html_components as html +import dash_core_components as dcc from dash_3d_viewer import DashVolumeSlicer +from skimage.measure import marching_cubes import imageio - app = dash.Dash(__name__) vol = imageio.volread("imageio:stent.npz") @@ -16,6 +18,11 @@ slicer2 = DashVolumeSlicer(app, vol, axis=1, id="slicer2") slicer3 = DashVolumeSlicer(app, vol, axis=2, id="slicer3") +verts, faces, _, _ = marching_cubes(vol, 300, step_size=2) +x, y, z = verts.T +i, j, k = faces.T +fig_mesh = go.Figure() +fig_mesh.add_trace(go.Mesh3d(x=z, y=y, z=x, opacity=0.2, i=k, j=j, k=i)) app.layout = html.Div( style={ @@ -50,6 +57,9 @@ *slicer3.stores, ] ), + html.Div( + [html.Center(html.H1("3D")), dcc.Graph(id="graph-helper", figure=fig_mesh)] + ), ], ) From 14ae20f4bec4a140dcbeab725b3cc2a77b68d1d7 Mon Sep 17 00:00:00 2001 From: Almar Klein Date: Thu, 29 Oct 2020 09:45:14 +0100 Subject: [PATCH 06/11] comments / clean up --- dash_3d_viewer/slicer.py | 38 +++++++++++++++++++-------------- examples/slicer_with_3_views.py | 5 ++++- 2 files changed, 26 insertions(+), 17 deletions(-) diff --git a/dash_3d_viewer/slicer.py b/dash_3d_viewer/slicer.py index 59eb4e7..76c77ab 100644 --- a/dash_3d_viewer/slicer.py +++ b/dash_3d_viewer/slicer.py @@ -11,27 +11,28 @@ class DashVolumeSlicer: """A slicer to show 3D image data in Dash.""" def __init__(self, app, volume, axis=0, id=None): - - assert isinstance(app, Dash) - + if not isinstance(app, Dash): + raise TypeError("Expect first arg to be a Dash app.") + # Check and store volume if not (isinstance(volume, np.ndarray) and volume.ndim == 3): - raise TypeError("DashVolumeSlicer expects a 3D numpy array") + raise TypeError("Expected volume to be a 3D numpy array") self._volume = volume - + # Check and store axis + if not (isinstance(axis, int) and 0 <= self._axis <= 2): + raise ValueError("The given axis must be 0, 1, or 2.") + self._axis = int(axis) + # Check and store id if id is None: id = gen_random_id() elif not isinstance(id, str): raise TypeError("Id must be a string") self._id = id - self._axis = int(axis) - self._max_slice = self._volume.shape[self._axis] - assert 0 <= self._axis <= 2 - - # Get the slice size (width, height) + # Get the slice size (width, height), and max index arr_shape = list(volume.shape) arr_shape.pop(self._axis) slice_size = list(reversed(arr_shape)) + self._max_index = self._volume.shape[self._axis] - 1 # Create the figure object fig = Figure() @@ -77,9 +78,9 @@ def __init__(self, app, volume, axis=0, id=None): self.slider = Slider( id=self._subid("slider"), min=0, - max=self._max_slice - 1, + max=self._max_index, step=1, - value=self._max_slice // 2, + value=self._max_index // 2, updatemode="drag", ) # Create the stores that we need (these must be present in the layout) @@ -89,18 +90,22 @@ def __init__(self, app, volume, axis=0, id=None): Store(id=self._subid("_slice-data"), data=""), ] - self._create_server_handlers(app) - self._create_client_handlers(app) + self._create_server_callbacks(app) + self._create_client_callbacks(app) def _subid(self, subid): + """Given a subid, get the full id including the slicer's prefix.""" return self._id + "-" + subid def _slice(self, index): + """Sample a slice from the volume.""" indices = [slice(None), slice(None), slice(None)] indices[self._axis] = index return self._volume[tuple(indices)] - def _create_server_handlers(self, app): + def _create_server_callbacks(self, app): + """Create the callbacks that run server-side.""" + @app.callback( Output(self._subid("_slice-data"), "data"), [Input(self._subid("_requested-slice-index"), "data")], @@ -110,7 +115,8 @@ def upload_requested_slice(slice_index): slice = (slice.astype(np.float32) * (255 / slice.max())).astype(np.uint8) return [slice_index, img_array_to_uri(slice)] - def _create_client_handlers(self, app): + def _create_client_callbacks(self, app): + """Create the callbacks that run client-side.""" app.clientside_callback( """ diff --git a/examples/slicer_with_3_views.py b/examples/slicer_with_3_views.py index afd7731..93e3906 100644 --- a/examples/slicer_with_3_views.py +++ b/examples/slicer_with_3_views.py @@ -1,6 +1,6 @@ """ An example creating three slice-views through a volume, as is common -in medical applications. In the fourth quadrant you could place an isosurface mesh. +in medical applications. In the fourth quadrant we put an isosurface mesh. """ import plotly.graph_objects as go @@ -13,17 +13,20 @@ app = dash.Dash(__name__) +# Read volumes and create slicer objects vol = imageio.volread("imageio:stent.npz") slicer1 = DashVolumeSlicer(app, vol, axis=0, id="slicer1") slicer2 = DashVolumeSlicer(app, vol, axis=1, id="slicer2") slicer3 = DashVolumeSlicer(app, vol, axis=2, id="slicer3") +# Calculate isosurface and create a figure with a mesh object verts, faces, _, _ = marching_cubes(vol, 300, step_size=2) x, y, z = verts.T i, j, k = faces.T fig_mesh = go.Figure() fig_mesh.add_trace(go.Mesh3d(x=z, y=y, z=x, opacity=0.2, i=k, j=j, k=i)) +# Put everything together in a 2x2 grid app.layout = html.Div( style={ "display": "grid", From 60d2cff576cddefdfcab75d7e3bf2313602a8598 Mon Sep 17 00:00:00 2001 From: Almar Klein Date: Thu, 29 Oct 2020 09:48:08 +0100 Subject: [PATCH 07/11] fix --- dash_3d_viewer/slicer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dash_3d_viewer/slicer.py b/dash_3d_viewer/slicer.py index 76c77ab..d367ed4 100644 --- a/dash_3d_viewer/slicer.py +++ b/dash_3d_viewer/slicer.py @@ -18,7 +18,7 @@ def __init__(self, app, volume, axis=0, id=None): raise TypeError("Expected volume to be a 3D numpy array") self._volume = volume # Check and store axis - if not (isinstance(axis, int) and 0 <= self._axis <= 2): + if not (isinstance(axis, int) and 0 <= axis <= 2): raise ValueError("The given axis must be 0, 1, or 2.") self._axis = int(axis) # Check and store id From 94158e25360f93c8b81f27e09fca3b7fbe56f82f Mon Sep 17 00:00:00 2001 From: Almar Klein Date: Thu, 29 Oct 2020 10:34:09 +0100 Subject: [PATCH 08/11] add setup.py --- .gitignore | 1 + setup.py | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+) create mode 100644 setup.py diff --git a/.gitignore b/.gitignore index 022d29b..2d0cba7 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ __pycache__ *.pyc *.pyo +*.egg-info dist/ build/ diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..594f4b3 --- /dev/null +++ b/setup.py @@ -0,0 +1,48 @@ +import re + +from setuptools import find_packages, setup + + +NAME = "dash_3d_viewer" +SUMMARY = ( + "A library to make it easy to build slice-views on 3D image data in Dash apps." +) + +with open(f"{NAME}/__init__.py") as fh: + VERSION = re.search(r"__version__ = \"(.*?)\"", fh.read()).group(1) + + +runtime_deps = [ + "pillow", + "numpy", + "plotly", + "dash", + "dash_core_components", + "scikit-image", # may not be needed eventually? +] + + +setup( + name=NAME, + version=VERSION, + packages=find_packages(exclude=["tests", "tests.*", "examples", "examples.*"]), + python_requires=">=3.6.0", + install_requires=runtime_deps, + license="MIT", + description=SUMMARY, + long_description=open("README.md").read(), + long_description_content_type="text/markdown", + author="Plotly", + author_email="almar.klein@gmail.com", + # url="https://github.com/plotly/will be renamed?", + data_files=[("", ["LICENSE"])], + zip_safe=True, # not if we put JS in a seperate file, I think + classifiers=[ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Topic :: Software Development :: Libraries :: Python Modules", + "Topic :: Scientific/Engineering :: Visualization", + ], +) From 3ae5c5a072ad11e4008cd09e6ae0d24cf544bfba Mon Sep 17 00:00:00 2001 From: Almar Klein Date: Thu, 29 Oct 2020 13:59:21 +0100 Subject: [PATCH 09/11] add example that actually used the sub-components --- examples/use_components.py | 63 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 examples/use_components.py diff --git a/examples/use_components.py b/examples/use_components.py new file mode 100644 index 0000000..229a66a --- /dev/null +++ b/examples/use_components.py @@ -0,0 +1,63 @@ +""" +A small example showing how to write callbacks involving the slicer's +components. The slicer's components are used as both inputs and outputs. +""" + +import dash +import dash_html_components as html +from dash.dependencies import Input, Output, State +from dash_3d_viewer import DashVolumeSlicer +import imageio + + +app = dash.Dash(__name__) + +vol = imageio.volread("imageio:stent.npz") +slicer = DashVolumeSlicer(app, vol) + +# We can access the components, and modify them +slicer.slider.value = 0 + +# Define the layour, including extra buttons +app.layout = html.Div( + [ + slicer.graph, + html.Br(), + html.Div( + style={"display": "flex"}, + children=[ + html.Div("", id="index-show", style={"padding": "0.4em"}), + html.Button("<", id="decrease-index"), + html.Div(slicer.slider, style={"flexGrow": "1"}), + html.Button(">", id="increase-index"), + ], + ), + *slicer.stores, + ] +) + +# New callbacks for our added widgets + + +@app.callback( + Output("index-show", "children"), + [Input(slicer.slider.id, "value")], +) +def show_slider_value(index): + return str(index) + + +@app.callback( + Output(slicer.slider.id, "value"), + [Input("decrease-index", "n_clicks"), Input("increase-index", "n_clicks")], + [State(slicer.slider.id, "value")], +) +def handle_button_input(press1, press2, index): + ctx = dash.callback_context + if ctx.triggered: + index += 1 if "increase" in ctx.triggered[0]["prop_id"] else -1 + return index + + +if __name__ == "__main__": + app.run_server(debug=True) From 07edeb199e2b9420b6618639c1d27c78e69c8b7f Mon Sep 17 00:00:00 2001 From: Almar Klein Date: Thu, 29 Oct 2020 16:49:22 +0100 Subject: [PATCH 10/11] add a note about plotlies _array_to_b64str --- dash_3d_viewer/utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/dash_3d_viewer/utils.py b/dash_3d_viewer/utils.py index 5dd2e8b..61846e1 100644 --- a/dash_3d_viewer/utils.py +++ b/dash_3d_viewer/utils.py @@ -11,6 +11,9 @@ def gen_random_id(n=6): def img_array_to_uri(img_array): img_array = skimage.util.img_as_ubyte(img_array) + # todo: leverage this Plotly util once it becomes part of the public API (also drops the Pillow dependency) + # from plotly.express._imshow import _array_to_b64str + # return _array_to_b64str(img_array) img_pil = PIL.Image.fromarray(img_array) uri = ImageUriValidator.pil_image_to_uri(img_pil) return uri From dc27145fba12171d927fe99d05fef03836a0e146 Mon Sep 17 00:00:00 2001 From: Almar Klein Date: Thu, 29 Oct 2020 16:55:16 +0100 Subject: [PATCH 11/11] update readme some more --- README.md | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 7902d49..e7a5217 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,8 @@ # dash-3d-viewer -A tool to make it easy to build slice-views on 3D image data. +A tool to make it easy to build slice-views on 3D image data, in Dash apps. + +The API is currently a WIP. ## Installation @@ -8,8 +10,23 @@ A tool to make it easy to build slice-views on 3D image data. Eventually, this would be pip-installable. For now, use the developer workflow. +## Usage + +TODO, see the examples. + + +## License + +This code is distributed under MIT license. + + ## Developers -To run the examples: -with an env that has the appropriate requirements, from the repo's root directory, run `python example/xxxx.py`. +* Make sure that you have Python with the appropriate dependencies installed, e.g. via `venv`. +* Run `pip install -e .` to do an in-place install of the package. +* Run the examples using e.g. `python examples/slicer_with_1_view.py` + +* Use `black .` to autoformat. +* Use `flake8 .` to lint. +* Use `pytest .` to run the tests.