Skip to content

Commit f7b10c7

Browse files
ziw-liuedyoshikun
authored andcommitted
filter empty detections (#74)
1 parent cf0333d commit f7b10c7

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

viscy/evaluation/evaluation_metrics.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -126,13 +126,15 @@ def labels_to_masks(labels: torch.ShortTensor) -> torch.BoolTensor:
126126
"""
127127
if labels.ndim != 2:
128128
raise ValueError(f"Labels must be 2D, got shape {labels.shape}.")
129+
segments = torch.unique(labels)
130+
n_instances = segments.numel() - 1
129131
masks = torch.zeros(
130-
(labels.max(), *labels.shape), dtype=torch.bool, device=labels.device
132+
(n_instances, *labels.shape), dtype=torch.bool, device=labels.device
131133
)
132134
# TODO: optimize this?
133-
for segment in range(labels.max()):
135+
for s, segment in enumerate(segments):
134136
# start from label value 1, i.e. skip background label
135-
masks[segment] = labels == (segment + 1)
137+
masks[s - 1] = labels == segment
136138
return masks
137139

138140

0 commit comments

Comments
 (0)