diff --git a/viscy/evaluation/evaluation_metrics.py b/viscy/evaluation/evaluation_metrics.py index fb83c06b..921b0e4e 100644 --- a/viscy/evaluation/evaluation_metrics.py +++ b/viscy/evaluation/evaluation_metrics.py @@ -126,13 +126,15 @@ def labels_to_masks(labels: torch.ShortTensor) -> torch.BoolTensor: """ if labels.ndim != 2: raise ValueError(f"Labels must be 2D, got shape {labels.shape}.") + segments = torch.unique(labels) + n_instances = segments.numel() - 1 masks = torch.zeros( - (labels.max(), *labels.shape), dtype=torch.bool, device=labels.device + (n_instances, *labels.shape), dtype=torch.bool, device=labels.device ) # TODO: optimize this? - for segment in range(labels.max()): + for s, segment in enumerate(segments): # start from label value 1, i.e. skip background label - masks[segment] = labels == (segment + 1) + masks[s - 1] = labels == segment return masks