diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index 5b71c79d34a..b4a528c5478 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -36,17 +36,16 @@ def horizontal_flip_bounding_box(
 ) -> torch.Tensor:
     shape = bounding_box.shape
 
-    # TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
-    # BoundingBoxFormat instead of converting back and forth
-    bounding_box = convert_format_bounding_box(
-        bounding_box.clone(), old_format=format, new_format=features.BoundingBoxFormat.XYXY, inplace=True
-    ).reshape(-1, 4)
+    bounding_box = bounding_box.clone().reshape(-1, 4)
 
-    bounding_box[:, [0, 2]] = spatial_size[1] - bounding_box[:, [2, 0]]
+    if format == features.BoundingBoxFormat.XYXY:
+        bounding_box[:, [2, 0]] = bounding_box[:, [0, 2]].sub_(spatial_size[1]).neg_()
+    elif format == features.BoundingBoxFormat.XYWH:
+        bounding_box[:, 0].add_(bounding_box[:, 2]).sub_(spatial_size[1]).neg_()
+    else:  # format == features.BoundingBoxFormat.CXCYWH:
+        bounding_box[:, 0].sub_(spatial_size[1]).neg_()
 
-    return convert_format_bounding_box(
-        bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, inplace=True
-    ).reshape(shape)
+    return bounding_box.reshape(shape)
 
 
 def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor:
@@ -75,17 +74,16 @@ def vertical_flip_bounding_box(
 ) -> torch.Tensor:
     shape = bounding_box.shape
 
-    # TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
-    # BoundingBoxFormat instead of converting back and forth
-    bounding_box = convert_format_bounding_box(
-        bounding_box.clone(), old_format=format, new_format=features.BoundingBoxFormat.XYXY, inplace=True
-    ).reshape(-1, 4)
+    bounding_box = bounding_box.clone().reshape(-1, 4)
 
-    bounding_box[:, [1, 3]] = spatial_size[0] - bounding_box[:, [3, 1]]
+    if format == features.BoundingBoxFormat.XYXY:
+        bounding_box[:, [1, 3]] = bounding_box[:, [3, 1]].sub_(spatial_size[0]).neg_()
+    elif format == features.BoundingBoxFormat.XYWH:
+        bounding_box[:, 1].add_(bounding_box[:, 3]).sub_(spatial_size[0]).neg_()
+    else:  # format == features.BoundingBoxFormat.CXCYWH:
+        bounding_box[:, 1].sub_(spatial_size[0]).neg_()
 
-    return convert_format_bounding_box(
-        bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, inplace=True
-    ).reshape(shape)
+    return bounding_box.reshape(shape)
 
 
 def vertical_flip_video(video: torch.Tensor) -> torch.Tensor: