
[feature request] transforms for object detection #3286

Open

ydcjeff opened this issue Jan 24, 2021 · 2 comments

ydcjeff (Contributor) commented Jan 24, 2021

🚀 Feature

I would like to start adding/supporting transforms (both functional and class forms) for object detection. I know I can take some of them from the references folder, but it would be nice to have them out of the box. Here are a few basic transforms I would like to add first:

  • RandomHorizontalFlipWithBBox
  • RandomVerticalFlipWithBBox
  • LetterBox

Pitch

All of the above transforms accept two arguments when they are called. This breaks compatibility with Compose and nn.Sequential, but aren't we currently writing custom Compose or nn.Sequential classes anyway? So I think it is fine to start introducing the two-argument transforms that detection, segmentation, etc. need, and let users write a custom Compose or nn.Sequential that calls the transforms the way they like.
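For example, a minimal custom Compose that threads the (img, target) pair through each transform could look like the sketch below (the name ComposeWithTarget is illustrative; the references/detection folder has a similar helper):

class ComposeWithTarget:
    """Chain transforms whose __call__/forward takes and returns (img, target)."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, target):
        for t in self.transforms:
            img, target = t(img, target)
        return img, target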

Additional context

Current code:

import random
from typing import Tuple, Union

import numpy as np
import torch
from PIL import Image
from torch import Tensor, nn
from torchvision import transforms as T
from torchvision.transforms import functional as FT


class RandomHorizontalFlipWithBBox(nn.Module):
    def __init__(self, prob: float = 0.5):
        super().__init__()
        self.prob = prob

    def forward(self, img, target):
        if random.random() < self.prob:
            width = img.width
            # Boxes are [xmin, ymin, xmax, ymax]: mirror the x coordinates.
            xmin, xmax = target[..., 0], target[..., 2]
            diff = abs(xmax - xmin)
            target[..., 0] = width - xmin - diff  # = width - xmax
            target[..., 2] = width - xmax + diff  # = width - xmin
            return FT.hflip(img), target
        return img, target

    def __repr__(self):
        return self.__class__.__name__ + "(p={})".format(self.prob)


class RandomVerticalFlipWithBBox(nn.Module):
    def __init__(self, prob: float = 0.5):
        super().__init__()
        self.prob = prob

    def forward(self, img, target):
        if random.random() < self.prob:
            height = img.height
            # Mirror the y coordinates.
            ymin, ymax = target[..., 1], target[..., 3]
            diff = abs(ymax - ymin)
            target[..., 1] = height - ymin - diff  # = height - ymax
            target[..., 3] = height - ymax + diff  # = height - ymin
            return FT.vflip(img), target
        return img, target

    def __repr__(self):
        return self.__class__.__name__ + "(p={})".format(self.prob)


class LetterBox(nn.Module):
    """
    Apply a letterbox transform to an image and its bounding box target.

    Args:
        size (int or tuple of int): the size of the transformed image.
    """

    def __init__(self, size: Union[int, Tuple[int, int]]):
        super().__init__()
        self.size = size
        if isinstance(size, int):
            self.size = (size, size)

    def forward(self, img: Image.Image, target: Union[np.ndarray, Tensor]):
        """
        Args:
            img (PIL Image): Image to be transformed.
            target (np.ndarray or Tensor): bounding box target to be transformed.

        Returns:
            tuple: (image, target)
        """
        old_width, old_height = img.size
        width, height = self.size

        # Resize while keeping the aspect ratio.
        ratio = min(width / old_width, height / old_height)
        new_width = int(old_width * ratio)
        new_height = int(old_height * ratio)
        img = T.functional.resize(img, (new_height, new_width))

        # Split the leftover space into (nearly) equal padding on each side.
        pad_x = (width - new_width) * 0.5
        pad_y = (height - new_height) * 0.5
        left, right = round(pad_x + 0.1), round(pad_x - 0.1)
        top, bottom = round(pad_y + 0.1), round(pad_y - 0.1)
        padding = (left, top, right, bottom)
        img = T.functional.pad(img, padding, 255 // 2)

        # All box coordinates shift by the left/top padding offset.
        if isinstance(target, torch.Tensor):
            target[..., 0] = torch.round(ratio * target[..., 0]) + left
            target[..., 1] = torch.round(ratio * target[..., 1]) + top
            target[..., 2] = torch.round(ratio * target[..., 2]) + left
            target[..., 3] = torch.round(ratio * target[..., 3]) + top
        elif isinstance(target, np.ndarray):
            target[..., 0] = np.rint(ratio * target[..., 0]) + left
            target[..., 1] = np.rint(ratio * target[..., 1]) + top
            target[..., 2] = np.rint(ratio * target[..., 2]) + left
            target[..., 3] = np.rint(ratio * target[..., 3]) + top
        return img, target

    def __repr__(self):
        return self.__class__.__name__ + f"({self.size})"
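
For illustration, here is how the transforms above could be chained with the ComposeWithTarget sketch from the pitch (boxes are assumed to be [xmin, ymin, xmax, ymax] in absolute pixels; the sizes are made up):

import torch
from PIL import Image

transforms = ComposeWithTarget([
    RandomHorizontalFlipWithBBox(prob=0.5),
    RandomVerticalFlipWithBBox(prob=0.5),
    LetterBox(416),
])

img = Image.new("RGB", (640, 480))
boxes = torch.tensor([[100.0, 50.0, 300.0, 200.0]])  # one box per row
img, boxes = transforms(img, boxes)
print(img.size, boxes)  # (416, 416) plus the letterboxed coordinates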

Thank you!

cc @vfdev-5, @fmassa

fmassa (Member) commented Jan 25, 2021

Hi,

Thanks for opening this issue.

This has been on our radar for a while already, but we never really managed to find the right balance between simplicity and generality in the API.
For example, the API you propose wouldn't be enough if we wanted to work on image + boxes + keypoints, or even image + segmentation map, so we would need a number of repeated implementations to cover the models in torchvision.
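To make the trade-off concrete, one way to generalize (purely illustrative, not an actual torchvision API) would be to pass all annotations in a dict and let each transform handle the keys it understands:

import random

import torch.nn as nn
import torchvision.transforms.functional as F


class RandomHorizontalFlipDict(nn.Module):
    # Illustrative sketch: flip the image plus whichever target types are present.
    def __init__(self, prob: float = 0.5):
        super().__init__()
        self.prob = prob

    def forward(self, img, target: dict):
        if random.random() < self.prob:
            width = img.width
            img = F.hflip(img)
            if "boxes" in target:  # [N, 4] boxes as [xmin, ymin, xmax, ymax]
                boxes = target["boxes"].clone()
                boxes[:, [0, 2]] = width - target["boxes"][:, [2, 0]]
                target["boxes"] = boxes
            if "masks" in target:  # [N, H, W] segmentation masks
                target["masks"] = target["masks"].flip(-1)
            if "keypoints" in target:  # [..., (x, y, ...)] keypoints
                kps = target["keypoints"].clone()
                kps[..., 0] = width - kps[..., 0]
                target["keypoints"] = kps
        return img, target

The downside is exactly the one mentioned above: every transform has to know about every target type, which is part of what makes the API hard to get right.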

For an earlier attempt at the API, see #1406 and the discussion within.

I would love to hear your thoughts on this.

zhiqwang (Contributor) commented Feb 11, 2021

FYI, it seems that the existing batch_images in GeneralizedRCNNTransform plays the same role as the proposed LetterBox here:

def batch_images(self, images, size_divisible=32):
    # type: (List[Tensor], int) -> Tensor
    if torchvision._is_tracing():
        # batch_images() does not export well to ONNX
        # call _onnx_batch_images() instead
        return self._onnx_batch_images(images, size_divisible)
    max_size = self.max_by_axis([list(img.shape) for img in images])
    stride = float(size_divisible)
    max_size = list(max_size)
    max_size[1] = int(math.ceil(float(max_size[1]) / stride) * stride)
    max_size[2] = int(math.ceil(float(max_size[2]) / stride) * stride)
    batch_shape = [len(images)] + max_size
    batched_imgs = images[0].new_full(batch_shape, 0)
    for img, pad_img in zip(images, batched_imgs):
        pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
    return batched_imgs
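
As a quick illustration of its behavior (assuming the usual GeneralizedRCNNTransform constructor arguments), the method zero-pads a list of differently sized CHW tensors to a common shape rounded up to the size_divisible stride:

import torch
from torchvision.models.detection.transform import GeneralizedRCNNTransform

transform = GeneralizedRCNNTransform(
    min_size=800, max_size=1333,
    image_mean=[0.485, 0.456, 0.406], image_std=[0.229, 0.224, 0.225],
)

images = [torch.rand(3, 300, 400), torch.rand(3, 500, 350)]  # mixed sizes
batched = transform.batch_images(images, size_divisible=32)
print(batched.shape)  # torch.Size([2, 3, 512, 416]); 500 -> 512, 400 -> 416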
