Skip to content

Commit

Permalink
cleaner interface
Browse files Browse the repository at this point in the history
  • Loading branch information
conradry committed Mar 4, 2022
1 parent 159c185 commit f8074c4
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 63 deletions.
46 changes: 21 additions & 25 deletions empanada_napari/_merge_split_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,29 @@
from qtpy.QtWidgets import QWidget, QHBoxLayout, QPushButton
from magicgui import magic_factory, magicgui
from empanada.array_utils import merge_boxes, crop_and_binarize
from skimage.measure import regionprops
from skimage.measure import regionprops
#from napari_tools_menu import register_function
import napari
import numpy as np
import dask.array as da

def delete_labels():

@magicgui(
call_button='Run',
call_button='Delete labels',
layout='vertical',
)

def widget(
viewer: napari.viewer.Viewer,
labels_layer: napari.layers.Labels,
points_layer: napari.layers.Points,

):
if points_layer is None:
points_layer = viewer.add_points([])
points_layer.mode = 'ADD'
points_layer.mode = 'ADD'
return

labels = labels_layer.data
world_points = points_layer.data

Expand Down Expand Up @@ -54,23 +53,22 @@ def widget(
return widget

def merge_labels():

@magicgui(
call_button='Run',
call_button='Merge labels',
layout='vertical',
)

def widget(
viewer: napari.viewer.Viewer,
labels_layer: napari.layers.Labels,
points_layer: napari.layers.Points,

):
if points_layer is None:
points_layer = viewer.add_points([])
points_layer.mode = 'ADD'
points_layer.mode = 'ADD'
return

labels = labels_layer.data
world_points = points_layer.data

Expand Down Expand Up @@ -102,7 +100,7 @@ def widget(

def split_function():
@magicgui(
call_button='Run',
call_button='Split labels',
layout='vertical',
)

Expand Down Expand Up @@ -151,7 +149,7 @@ def widget(

#distance = ndi.distance_transform_edt(binary)
#coords = peak_local_max(distance, footprint=np.ones((3, 3)), labels=binary)

def translate_point_in_box(point, shed_box):
n = len(shed_box)
n_dim = n//2
Expand All @@ -166,8 +164,8 @@ def box_to_slice(shed_box):
for i in range(n_dim):
slices.append(slice(shed_box[i], shed_box[i+n_dim]))

return tuple(slices)
return tuple(slices)

mask = np.zeros(binary.shape, dtype=bool)
for i in local_points:
mask[translate_point_in_box(i, shed_box)] = True
Expand All @@ -182,7 +180,7 @@ def box_to_slice(shed_box):

slices = box_to_slice(shed_box)


if type(labels) == da.core.Array:
#new_labels[binary] += labels.max()
max_label = 0
Expand All @@ -199,9 +197,8 @@ def box_to_slice(shed_box):
return widget

def split_widget_distance():

@magicgui(
call_button='Run',
call_button='Split labels by distance watershed',
layout='vertical',
min_distance=dict(widget_type='Slider', label='Minimum Distance', min=1, max=100, value=10, tooltip='Min Distance between Markers'),
)
Expand Down Expand Up @@ -266,14 +263,14 @@ def box_to_slice(shed_box):
distance = ndi.distance_transform_edt(binary)

if np.squeeze(distance).ndim == distance.ndim - 1:
coords = peak_local_max(np.squeeze(distance), min_distance=min_distance)
coords = peak_local_max(np.squeeze(distance), min_distance=min_distance)
mask = np.zeros(np.squeeze(distance).shape, dtype=bool)
mask[tuple(coords.T)] = True
mask[tuple(coords.T)] = True
mask = mask[None]
else:
coords = peak_local_max(distance, min_distance=min_distance)
mask = np.zeros(distance.shape, dtype=bool)
mask[tuple(coords.T)] = True
mask[tuple(coords.T)] = True

markers, _ = ndi.label(mask)

Expand All @@ -291,15 +288,15 @@ def box_to_slice(shed_box):
labels[slices][binary] = new_labels[binary] + labels.max()

labels_layer.data = labels
points_layer.data = []
points_layer.data = []

print('Done')
print('Done')

return widget

def jump_to_label():
@magicgui(
call_button='Run',
call_button='Jump to label',
layout='vertical',
label_id=dict(widget_type='LineEdit', label='Label ID', value='1', tooltip='Label to jump to'),
)
Expand Down Expand Up @@ -365,4 +362,3 @@ def split_labels_widget():
@napari_hook_implementation(specname='napari_experimental_provide_dock_widget')
def delete_labels_widget():
return delete_labels, {'name': 'Delete Labels'}

71 changes: 33 additions & 38 deletions empanada_napari/_volume_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,62 +53,63 @@ def orthoplane_inference(engine, volume):

return trackers_dict

gui_params = dict(
@magicgui(
label_head=dict(widget_type='Label', label=f'<h1 style="text-align:center"><img src="{logo}"></h1>'),
call_button='Run 3D Inference',
layout='vertical',
model_config=dict(widget_type='ComboBox', label='model', choices=list(model_configs.keys()), value=list(model_configs.keys())[0], tooltip='Model to use for inference'),
store_dir=dict(widget_type='FileEdit', value='no zarr storage', label='Zarr Directory (optional)', mode='d', tooltip='location to store segmentations on disk'),
use_gpu=dict(widget_type='CheckBox', text='Use GPU', value=device_count() >= 1, tooltip='If checked, run on GPU 0'),
multigpu=dict(widget_type='CheckBox', text='Multi GPU', value=False, tooltip='If checked, run on all available GPUs'),
parameters2d_head=dict(widget_type='Label', label=f'<h3 text-align="center">2D Parameters</h3>'),
downsampling=dict(widget_type='ComboBox', choices=[1, 2, 4, 8, 16, 32, 64], value=1, label='Image Downsampling', tooltip='Downsampling factor to apply before inference'),
confidence_thr=dict(widget_type='FloatSpinBox', value=0.5, min=0.1, max=0.9, step=0.1, label='Segmentation Confidence Thr'),
center_confidence_thr=dict(widget_type='FloatSpinBox', value=0.1, min=0.05, max=0.9, step=0.05, label='Center Confidence Thr'),
min_distance_object_centers=dict(widget_type='SpinBox', value=3, min=1, max=21, step=1, label='Centers Min Distance'),
fine_boundaries=dict(widget_type='CheckBox', text='Fine boundaries', value=False, tooltip='Finer boundaries between objects'),
semantic_only=dict(widget_type='CheckBox', text='Semantic only', value=False, tooltip='Only run semantic segmentation for all classes.'),
fine_boundaries=dict(widget_type='CheckBox', text='Fine Boundaries', value=False, tooltip='Finer boundaries between objects'),
semantic_only=dict(widget_type='CheckBox', text='Semantic Only', value=False, tooltip='Only run semantic segmentation for all classes.'),
parameters_stack_head=dict(widget_type='Label', label=f'<h3 text-align="center">Stack Parameters</h3>'),
median_slices=dict(widget_type='ComboBox', choices=[1, 3, 5, 7, 9, 11], value=3, label='Median Filter Size', tooltip='Median filter size'),
pixel_vote_thr=dict(widget_type='SpinBox', value=2, min=1, max=3, step=1, label='Voxel Vote Thr'),
merge_iou_thr=dict(widget_type='FloatSpinBox', value=0.25, min=0.1, max=0.9, step=0.05, label= 'IoU Matching Thr'),
merge_ioa_thr=dict(widget_type='FloatSpinBox', value=0.25, min=0.1, max=0.9, step=0.05, label= 'IoA Matching Thr'),
min_size=dict(widget_type='SpinBox', value=500, min=0, max=1e6, step=100, label='Min Size (Voxels)'),
min_extent=dict(widget_type='SpinBox', value=5, min=0, max=1000, step=1, label='Min Box Extent'),
cluster_iou_thr=dict(widget_type='FloatSpinBox', value=0.75, min=0.1, max=0.9, step=0.05, label='Cluster IoU Thr'),
allow_one_view=dict(widget_type='CheckBox', text='Allow detections from 1 stack', value=False, tooltip='Whether to allow detections into consensus that were picked up by inference in just 1 stack'),
maximum_objects_per_class=dict(widget_type='LineEdit', value='100000', label='Max objects per class'),
return_panoptic=dict(widget_type='CheckBox', text='Return panoptic', value=False, tooltip='whether to return the panoptic segmentations'),
orthoplane=dict(widget_type='CheckBox', text='Run orthoplane', value=False, tooltip='whether to run orthoplane inference'),
)

gui_params['use_gpu'] = dict(widget_type='CheckBox', text='Use GPU', value=device_count() >= 1, tooltip='If checked, run on GPU 0')
gui_params['multigpu'] = dict(widget_type='CheckBox', text='Multi GPU', value=False, tooltip='If checked, run on all available GPUs')
gui_params['store_dir']=dict(widget_type='FileEdit', value='no zarr storage', label='Zarr Directory (optional)', mode='d', tooltip='location to store segmentations on disk')
maximum_objects_per_class=dict(widget_type='LineEdit', value='100000', label='Max objects per class in 3D'),
@magicgui(
label_head=dict(widget_type='Label', label=f'<h1 style="text-align:center"><img src="{logo}"></h1>'),
call_button='Run 3D Inference',
layout='vertical',
**gui_params
parameters_ortho_head=dict(widget_type='Label', label=f'<h3 text-align="center">Ortho-plane Parameters (Optional)</h3>'),
orthoplane=dict(widget_type='CheckBox', text='Run ortho-plane', value=False, tooltip='Whether to run orthoplane inference'),
return_panoptic=dict(widget_type='CheckBox', text='Return xy, xz, yz stacks', value=False, tooltip='Whether to return the inference stacks.'),
pixel_vote_thr=dict(widget_type='SpinBox', value=2, min=1, max=3, step=1, label='Voxel Vote Thr Out of 3', tooltip='Number of votes out of 3 for a voxel to be labeled in the consensus'),
allow_one_view=dict(widget_type='CheckBox', text='Permit detections found in 1 stack into consensus', value=False, tooltip='Whether to allow detections into consensus that were picked up by inference in just 1 stack')
)
def widget(
viewer: napari.viewer.Viewer,
label_head,
image_layer: Image,
model_config,
store_dir,
use_gpu,
multigpu,

parameters2d_head,
downsampling,
confidence_thr,
center_confidence_thr,
min_distance_object_centers,
fine_boundaries,
semantic_only,

parameters_stack_head,
median_slices,
pixel_vote_thr,
merge_iou_thr,
merge_ioa_thr,
min_size,
min_extent,
cluster_iou_thr,
allow_one_view,
maximum_objects_per_class,
return_panoptic,

parameters_ortho_head,
orthoplane,
use_gpu,
multigpu,
store_dir
return_panoptic,
pixel_vote_thr,
allow_one_view
):
# load the model config
model_config_name = model_config
Expand Down Expand Up @@ -140,8 +141,6 @@ def widget(
nms_kernel=min_distance_object_centers,
nms_threshold=center_confidence_thr,
confidence_thr=confidence_thr,
merge_iou_thr=merge_iou_thr,
merge_ioa_thr=merge_ioa_thr,
min_size=min_size,
min_extent=min_extent,
fine_boundaries=fine_boundaries,
Expand All @@ -159,8 +158,6 @@ def widget(
nms_kernel=min_distance_object_centers,
nms_threshold=center_confidence_thr,
confidence_thr=confidence_thr,
merge_iou_thr=merge_iou_thr,
merge_ioa_thr=merge_ioa_thr,
min_size=min_size,
min_extent=min_extent,
fine_boundaries=fine_boundaries,
Expand All @@ -180,8 +177,6 @@ def widget(
nms_kernel=min_distance_object_centers,
nms_threshold=center_confidence_thr,
confidence_thr=confidence_thr,
merge_iou_thr=merge_iou_thr,
merge_ioa_thr=merge_ioa_thr,
min_size=min_size,
min_extent=min_extent,
fine_boundaries=fine_boundaries,
Expand Down Expand Up @@ -238,8 +233,8 @@ def start_postprocess_worker(*args):

def start_consensus_worker(trackers_dict):
consensus_worker = tracker_consensus(
trackers_dict, store_url, model_config, label_divisor=maximum_objects_per_class, pixel_vote_thr=pixel_vote_thr,
cluster_iou_thr=cluster_iou_thr, allow_one_view=allow_one_view,
trackers_dict, store_url, model_config, label_divisor=maximum_objects_per_class,
pixel_vote_thr=pixel_vote_thr, allow_one_view=allow_one_view,
min_size=min_size, min_extent=min_extent, dtype=widget.engine.dtype
)
consensus_worker.yielded.connect(_new_class_stack)
Expand Down
File renamed without changes.
3 changes: 3 additions & 0 deletions empanada_napari/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@

from empanada_napari.utils import Preprocessor

MODEL_DIR = os.path.join(os.path.expanduser('~'), '.empanada/configs')
torch.hub.set_dir(MODEL_DIR)

def instance_relabel(tracker):
r"""Relabels instances starting from 1"""
instance_id = 1
Expand Down
3 changes: 3 additions & 0 deletions empanada_napari/multigpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@

from empanada_napari.utils import Preprocessor

MODEL_DIR = os.path.join(os.path.expanduser('~'), '.empanada/configs')
torch.hub.set_dir(MODEL_DIR)

#----------------------------------------------------------
# Utilities for all gathering outputs from each GPU process
#----------------------------------------------------------
Expand Down

0 comments on commit f8074c4

Please sign in to comment.