Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use categorical mask #204

Merged
merged 4 commits into from
Oct 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions src/data_gradients/dataset_adapters/formatters/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from data_gradients.dataset_adapters.utils import check_all_integers
from data_gradients.dataset_adapters.formatters.base import BatchFormatter
from data_gradients.dataset_adapters.formatters.utils import check_images_shape, ensure_channel_first, drop_nan
from data_gradients.dataset_adapters.formatters.utils import check_images_shape, ensure_channel_first
from data_gradients.dataset_adapters.config.data_config import DetectionDataConfig
from data_gradients.dataset_adapters.formatters.utils import DatasetFormatError

Expand Down Expand Up @@ -60,7 +60,6 @@ def format(self, images: Tensor, labels: Tensor) -> Tuple[Tensor, List[Tensor]]:
images *= 255
images = images.to(torch.uint8)

labels = drop_nan(labels)
labels = self.ensure_labels_shape(annotated_bboxes=labels, batch_size=images.shape[0])

# Labels format transformations are only relevant if we have labels
Expand Down
147 changes: 53 additions & 94 deletions src/data_gradients/dataset_adapters/formatters/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from torch import Tensor

from data_gradients.dataset_adapters.formatters.base import BatchFormatter
from data_gradients.dataset_adapters.utils import check_all_integers, to_one_hot
from data_gradients.dataset_adapters.utils import check_all_integers
from data_gradients.dataset_adapters.config.data_config import SegmentationDataConfig
from data_gradients.dataset_adapters.formatters.utils import DatasetFormatError, check_images_shape, ensure_channel_first, drop_nan
from data_gradients.dataset_adapters.formatters.utils import DatasetFormatError, check_images_shape, ensure_channel_first


class SegmentationBatchFormatter(BatchFormatter):
Expand Down Expand Up @@ -35,11 +35,11 @@ def __init__(
def format(self, images: Tensor, labels: Tensor) -> Tuple[Tensor, Tensor]:
"""Validate batch images and labels format, and ensure that they are in the relevant format for segmentation.

:param images: Batch of images, in (BS, ...) format
:param labels: Batch of labels, in (BS, ...) format
:param images: Batch of images, in (BS, ...) format, or single sample
:param labels: Batch of labels, in (BS, ...) format, or single sample
:return:
- images: Batch of images already formatted into (BS, C, H, W)
- labels: Batch of labels already formatted into (BS, N, H, W)
- labels: Batch of labels already formatted into (BS, H, W) - categorical representation
"""

if self.class_ids_to_ignore is None:
Expand All @@ -52,107 +52,66 @@ def format(self, images: Tensor, labels: Tensor) -> Tuple[Tensor, Tensor]:
images = images.unsqueeze(0)
labels = labels.unsqueeze(0)

images = drop_nan(images)
labels = drop_nan(labels)

images = ensure_channel_first(images, n_image_channels=self.get_n_image_channels(images=images))
labels = ensure_channel_first(labels, n_image_channels=self.get_n_image_channels(images=images))

images = check_images_shape(images, n_image_channels=self.get_n_image_channels(images=images))

labels = self.validate_labels_dim(labels, n_classes=self.data_config.get_n_classes(), ignore_labels=self.ignore_labels)
labels = self.ensure_hard_labels(labels, n_classes=self.data_config.get_n_classes(), threshold_value=self.threshold_value)

if self.require_onehot(labels=labels, n_classes=self.data_config.get_n_classes()):
labels = to_one_hot(labels, n_classes=self.data_config.get_n_classes())
images = self._format_images(images)
labels = self._format_labels(labels)

for class_id_to_ignore in self.class_ids_to_ignore:
labels[:, class_id_to_ignore, ...] = 0

if 0 <= images.min() and images.max() <= 1:
images *= 255
images = images.to(torch.uint8)
elif images.min() < 0: # images were normalized with some unknown mean and std
images -= images.min()
images /= images.max()
images *= 255
images = images.to(torch.uint8)
labels[labels == class_id_to_ignore] = -1

return images, labels

def check_is_batch(self, images: Tensor, labels: Tensor) -> bool:
if images.ndim == 4 or labels.ndim == 4:
# if less any dim is 4, we know it's a batch
self.data_config.is_batch = True
return self.data_config.is_batch
elif images.ndim == 2 or labels.ndim == 2:
# If image or mask only includes 2 dims, we can guess it's a single sample
self.data_config.is_batch = False
return self.data_config.is_batch
else:
# Otherwise, we need to ask the user
hint = f" - Image shape: {images.shape}\n - Mask shape: {labels.shape}"
return self.data_config.get_is_batch(hint=hint)

@staticmethod
def ensure_hard_labels(labels: Tensor, n_classes: int, threshold_value: float) -> Tensor:
unique_values = torch.unique(labels)

if check_all_integers(unique_values):
return labels
elif 0 <= min(unique_values) and max(unique_values) <= 1 and check_all_integers(unique_values * 255):
return labels * 255
hint = f"Image shape: {images.shape}\nMask shape: {labels.shape}"
self.data_config.is_batch = self.data_config.get_is_batch(hint=hint)
return self.data_config.is_batch

def _format_images(self, images: Tensor) -> Tensor:
images = ensure_channel_first(images, n_image_channels=self.get_n_image_channels(images=images))
images = check_images_shape(images, n_image_channels=self.get_n_image_channels(images=images))
images = adjust_image_values(images)
return images

def _format_labels(self, labels: Tensor) -> Tensor:
labels = labels.squeeze((1, -1)) # If (BS, 1, H, W) or (BS, H, W, 1) -> (BS, H, W)
if labels.ndim == 3:
labels = ensure_hard_labels(labels, n_classes=self.data_config.get_n_classes(), threshold_value=self.threshold_value)
elif labels.ndim == 4:
labels = convert_to_categorical(labels, n_classes=self.data_config.get_n_classes())
else:
if n_classes > 1:
raise DatasetFormatError(f"Not supporting soft-labeling for number of classes > 1!\nGot {n_classes} classes.")
labels = SegmentationBatchFormatter.binary_mask_above_threshold(labels=labels, threshold_value=threshold_value)
raise DatasetFormatError(f"Labels should be either 3D (categorical) or 4D (onehot), but got {labels.ndim}D")
return labels

@staticmethod
def is_soft_labels(labels: Tensor) -> bool:
unique_values = torch.unique(labels)
if check_all_integers(unique_values):
return False
elif 0 <= min(unique_values) and max(unique_values) <= 1 and check_all_integers(unique_values * 255):
return False
return True

@staticmethod
def require_onehot(labels: Tensor, n_classes: int) -> bool:
is_binary = n_classes == 1
is_onehot = labels.shape[1] == n_classes
return not (is_binary or is_onehot)

@staticmethod
def validate_labels_dim(labels: Tensor, n_classes: int, ignore_labels: List[int]) -> Tensor:
"""
Validating labels dimensions are (BS, N, H, W) where N is either 1 or number of valid classes
:param labels: Tensor [BS, W, H] or [BS, N, W, H]
:return: labels: Tensor [BS, N, W, H]
"""
if labels.dim() == 3:
return labels # Assuming [BS, W, H]
elif labels.dim() == 4:
total_n_classes = n_classes + len(ignore_labels)

# Check if first or last dim is 1; it can be due to mask being saved with [1, H, W] or [H, W, 1]
if labels.shape[1] == 1 and labels.shape[1] != total_n_classes:
return labels.squeeze(1) # [BS, 1, W, H] -> [BS, W, H] (categorical representation)
elif labels.shape[-1] == 1 and labels.shape[-1] != total_n_classes:
return labels.squeeze(-1) # [BS, W, H, 1] -> [BS, W, H] (categorical representation)
elif not (labels.shape[1] == total_n_classes or labels.shape[-1] == total_n_classes):
# We have 4 dims, but it's neither [BS, N, W, H] nor [BS, W, H, N]
raise DatasetFormatError(f"Labels batch shape should be [BS, N, W, H] where N is n_classes. Got {labels.shape}")
return labels
else:
raise DatasetFormatError(f"Labels batch shape should be [Channels x Width x Height] or [BatchSize x Channels x Width x Height]. Got {labels.shape}")

@staticmethod
def binary_mask_above_threshold(labels: Tensor, threshold_value: float) -> Tensor:
# Support only for binary segmentation
labels = torch.where(
labels > threshold_value,
torch.ones_like(labels),
torch.zeros_like(labels),
)

def adjust_image_values(images: Tensor) -> Tensor:
    """Rescale image values to the [0, 255] range as ``torch.uint8`` when they are not already there.

    Three cases are handled:
        - All values within [0, 1]: multiplied by 255 and cast to ``torch.uint8``.
        - Some value below 0 (e.g. normalized with an unknown mean/std): min-max rescaled to [0, 255] and cast to ``torch.uint8``.
        - Otherwise (non-negative, with some value above 1): returned unchanged.

    :param images: Tensor of image values, any shape.
    :return: Rescaled ``torch.uint8`` tensor, or the input itself when no rescaling applies.
    """
    lowest, highest = images.min(), images.max()

    if lowest >= 0 and highest <= 1:
        return (images * 255).to(torch.uint8)

    if lowest < 0:
        # The divisor must be the range of the *shifted* values (max - min). Dividing by the original
        # max would scale e.g. [-1, 1] to [0, 510] and overflow uint8.
        span = highest - lowest
        if span == 0:
            # Constant image: every shifted value is 0, avoid a 0 / 0 division.
            return torch.zeros_like(images, dtype=torch.uint8)
        return ((images - lowest) / span * 255).to(torch.uint8)

    return images


def ensure_hard_labels(labels: Tensor, n_classes: int, threshold_value: float) -> Tensor:
    """Return labels with hard values, converting them when a supported conversion exists.

    The unique values of `labels` are inspected, in this order:
        1. If `check_all_integers` accepts them, the labels are returned untouched.
        2. If they all lie within [0, 1] and `check_all_integers` accepts them once multiplied by 255, `labels * 255` is returned.
        3. If `n_classes` is 1, the labels are binarized with `binary_mask_above_threshold`.

    :param labels:          Tensor of labels.
    :param n_classes:       Number of classes; thresholding is only applied when it equals 1.
    :param threshold_value: Threshold passed to `binary_mask_above_threshold` in the single-class case.
    :return:                The hard labels.
    :raises DatasetFormatError: If none of the cases above applies.
    """
    distinct_values = torch.unique(labels)

    if check_all_integers(distinct_values):
        return labels

    within_unit_range = 0 <= min(distinct_values) and max(distinct_values) <= 1
    if within_unit_range and check_all_integers(distinct_values * 255):
        return labels * 255

    if n_classes == 1:
        return binary_mask_above_threshold(labels, threshold_value)

    raise DatasetFormatError(f"Not supporting soft-labeling for number of classes > 1! Got {n_classes} classes.")


def convert_to_categorical(labels: Tensor, n_classes: int) -> Tensor:
    """Collapse the class axis of a 4D label tensor by taking the index of its highest value.

    The class axis is dim 1 when `labels.shape[1]` equals `n_classes`; otherwise the last dim is used.

    :param labels:    4D tensor of labels, e.g. (BS, C, H, W) or (BS, H, W, C).
    :param n_classes: Number of classes, used to detect a channel-first layout.
    :return:          Tensor of argmax indices, with the class axis removed.
    """
    class_axis = 1 if labels.shape[1] == n_classes else -1
    return torch.argmax(labels, dim=class_axis)


def binary_mask_above_threshold(labels: Tensor, threshold_value: float) -> Tensor:
    """Binarize labels: 1 where the value is strictly above the threshold, 0 elsewhere.

    :param labels:          Tensor of label values.
    :param threshold_value: Values strictly greater than this become 1.
    :return:                Tensor with the same shape and dtype as `labels`, holding only 0s and 1s.
    """
    is_above = labels > threshold_value
    return is_above.to(labels.dtype)
16 changes: 0 additions & 16 deletions src/data_gradients/dataset_adapters/formatters/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import torch
from torch import Tensor

from data_gradients.dataset_adapters.utils import channels_last_to_first
Expand All @@ -8,21 +7,6 @@ class DatasetFormatError(Exception):
...


def drop_nan(tensor: Tensor) -> Tensor:
"""Remove rows containing NaN values from a given PyTorch tensor.

:param tensor: Tensor with shape (N, M) where N is the number of rows and M is the number of columns.
:return: Tensor with the same number of columns as the input tensor, but without rows containing NaN.
"""
nans = torch.isnan(tensor)
if nans.any():
nan_indices = set(nans.nonzero()[:, 0].tolist())
all_indices = set(i for i in range(tensor.shape[0]))
valid_indices = all_indices - nan_indices
return tensor[valid_indices]
return tensor


def ensure_channel_first(images: Tensor, n_image_channels: int) -> Tensor:
"""Images should be [BS, C, H, W]. If [BS, W, H, C], permute

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from data_gradients.utils.image_processing import resize_in_chunks
from data_gradients.utils.data_classes import SegmentationSample
from data_gradients.feature_extractors.common.heatmap import BaseClassHeatmap
from data_gradients.utils.segmentation import mask_to_onehot


@register_feature_extractor()
Expand All @@ -22,14 +23,16 @@ def update(self, sample: SegmentationSample):
if not self.class_names:
self.class_names = sample.class_names

# Objects are resized to a fix size
mask = sample.mask.transpose((1, 2, 0))
# (H, W) -> (C, H, W)
n_classes = np.max(sample.mask)
mask_onehot = mask_to_onehot(mask_categorical=sample.mask, n_classes=n_classes)
mask_onehot = mask_onehot.transpose((1, 2, 0)) # H, W, C -> C, H, W

target_size = self.heatmap_shape[1], self.heatmap_shape[0]
resized_masks = resize_in_chunks(img=mask.astype(np.uint8), size=target_size, interpolation=cv2.INTER_LINEAR).astype(np.uint8)
resized_masks = resized_masks.transpose((2, 0, 1))
resized_masks = resize_in_chunks(img=mask_onehot, size=target_size, interpolation=cv2.INTER_LINEAR).astype(np.uint8)
resized_masks = resized_masks.transpose((2, 0, 1)) # H, W, C -> C, H, W

split_heatmap = self.heatmaps_per_split.get(sample.split, np.zeros((len(sample.class_names), *self.heatmap_shape)))
split_heatmap = self.heatmaps_per_split.get(sample.split, np.zeros((n_classes, *self.heatmap_shape)))
split_heatmap += resized_masks
self.heatmaps_per_split[sample.split] = split_heatmap

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,17 @@ class SegmentationComponentsErosion(AbstractFeatureExtractor):
def __init__(self):
self.kernel_shape = (3, 3)
self.data = []
self.class_names = None

def update(self, sample: SegmentationSample):
opened_mask = self.apply_mask_opening(mask=sample.mask, kernel_shape=self.kernel_shape)
contours_after_opening = contours.get_contours(opened_mask)
from data_gradients.utils.segmentation import mask_to_onehot

onehot_mask = mask_to_onehot(mask_categorical=sample.mask, n_classes=len(sample.class_names))
opened_onehot_mask = self.apply_mask_opening(onehot_mask=onehot_mask, kernel_shape=self.kernel_shape)
opened_categorical_mask = np.argmax(opened_onehot_mask, axis=-1)

# TODO: This will be removed once we support sparse class representation (e.g. class_ids=[0, 4, 255])
contours_after_opening = contours.get_contours(label=opened_categorical_mask, class_ids=range(len(sample.class_names)))

if sample.contours:
n_components_without_opening = sum(1 for class_channel in sample.contours for _contour in class_channel)
Expand Down Expand Up @@ -71,7 +78,7 @@ def description(self) -> str:
)
# FIXME: Can this also lead to increase of components, when breaking existing component into 2?

def apply_mask_opening(self, mask: np.ndarray, kernel_shape: Tuple[int, int]) -> np.ndarray:
def apply_mask_opening(self, onehot_mask: np.ndarray, kernel_shape: Tuple[int, int]) -> np.ndarray:
"""Opening is just another name of erosion followed by dilation.

It is useful in removing noise, as we explained above. Here we use the function, cv2.morphologyEx(). See [Official OpenCV documentation](
Expand All @@ -81,6 +88,6 @@ def apply_mask_opening(self, mask: np.ndarray, kernel_shape: Tuple[int, int]) ->
:param kernel_shape: Shape of the kernel used for Opening (Eroded + Dilated)
:return: Opened (Eroded + Dilated) mask in shape [N, H, W]
"""
masks = mask.transpose((1, 2, 0)).astype(np.uint8)
masks = onehot_mask.transpose((1, 2, 0)).astype(np.uint8)
masks = cv2.morphologyEx(masks, cv2.MORPH_OPEN, np.ones(kernel_shape, np.uint8))
return masks.transpose((2, 0, 1))
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,9 @@ def _prepare_sample_visualization(self, sample: SegmentationSample) -> np.ndarra

image = sample.image_as_rgb

# Onehot to categorical labels
categorical_labels = np.argmax(sample.mask, axis=0)

# Normalize the labels to the range [0, 255]
normalized_labels = np.ceil((categorical_labels * 255) / np.max(categorical_labels))
mask = sample.mask.astype(np.float32)
normalized_labels = np.ceil((mask * 255) / np.max(mask))
normalized_labels = normalized_labels[:, :, np.newaxis].repeat(3, axis=-1)

# Stack the image and label color map horizontally or vertically
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ def preprocess_samples(self, dataset: Iterable[SupportedDataType], split: str) -
labels = np.uint8(labels.cpu().numpy())

for image, mask in zip(images, labels):
contours = get_contours(mask)
# TODO: This will be removed once we support sparse class representation (e.g. class_ids=[0, 4, 255])
contours = get_contours(mask, class_ids=range(len(self.data_config.get_class_names())))

yield SegmentationSample(
image=image,
Expand Down
21 changes: 10 additions & 11 deletions src/data_gradients/sample_preprocessor/utils/contours.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,32 @@
from typing import List, Tuple, Dict
from typing import List, Tuple, Dict, Sequence

import cv2
import numpy as np

from data_gradients.utils.data_classes.contour import Contour


def get_contours(label: np.ndarray) -> List[list]:
def get_contours(label: np.ndarray, class_ids: Sequence[int]) -> List[list]:
"""
Find contours in each class-channel individually, using opencv findContours method
:param label: Tensor [N, W, H] where N is number of valid classes
:return: List with the shape [N, Nc, P, 1, 2] where N is number of valid classes, Nc are number of contours
:param label: Categorical representation of mask, of shape [H, W]
:param class_ids: List of class-ids.
:return: List with the shape [N, Nc, P, 1, 2] where N is number of valid classes, Nc are number of contours
per class, P are number of points for each contour and (1, 2) are set of points.
"""
if not isinstance(label, np.ndarray):
raise TypeError(f"Expected numpy.ndarray, got {type(label)}")

# Type to INT8 as for Index array
label = label.astype(np.uint8, copy=False)

all_onehot_contour = []
# For each class
for class_channel in range(label.shape[0]):
# Get tensor [class, W, H]
onehot = label[class_channel, ...]

for class_channel in class_ids:

onehot = (label == class_channel).astype(np.uint8) # Boolean mask of shape [H, W]
if np.max(onehot) == 0:
continue
# Find contours and return shape of [N, P, 1, 2] where N is number of contours and P list of points
onehot_contour, _ = cv2.findContours(onehot, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
onehot_contour, _ = cv2.findContours(onehot * 1, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
# Check if contour is OK
valid_onehot_contours = get_valid_contours(onehot_contour, class_channel)
if len(valid_onehot_contours):
Expand Down
2 changes: 1 addition & 1 deletion src/data_gradients/utils/data_classes/data_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class SegmentationSample(ImageSample):
:attr sample_id: The unique identifier of the sample. Could be the image path or the image name.
:attr split: The name of the dataset split. Could be "train", "val", "test", etc.
:attr image: np.ndarray of shape [H,W,C] - The image as a numpy array with channels last.
:attr mask: np.ndarray of shape [N, H, W] representing one-hot encoded mask for each class.
:attr mask: np.ndarray of shape [H, W], categorical representation of the mask.
:attr contours: A list of contours for each class in the mask.
:attr class_names: List of all class names in the dataset. The index should represent the class_id.
"""
Expand Down
16 changes: 16 additions & 0 deletions src/data_gradients/utils/segmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import numpy as np


def mask_to_onehot(mask_categorical: np.ndarray, n_classes: int) -> np.ndarray:
    """Build the one-hot encoding of a categorical segmentation mask.

    Channel `c` of the result is 1 wherever the mask equals `c` and 0 elsewhere. Values outside
    `range(n_classes)` produce no 1 in any channel.

    :param mask_categorical: Categorical representation of mask (H, W).
    :param n_classes:        The total number of classes in the dataset.
    :return:                 Onehot representation of mask (C, H, W), as uint8.
    """
    height, width = mask_categorical.shape[0], mask_categorical.shape[1]
    onehot_mask = np.zeros((n_classes, height, width), dtype=np.uint8)

    # Compare the mask against every class id at once: (1, H, W) == (C, 1, 1) -> (C, H, W)
    class_ids = np.arange(n_classes).reshape(-1, 1, 1)
    onehot_mask[...] = mask_categorical[np.newaxis, ...] == class_ids

    return onehot_mask
Loading