From 6ecb9272be34b6331aa3181881fd2b3c5bfa409f Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Thu, 3 Nov 2022 11:21:32 +0000
Subject: [PATCH 1/6] Bbox resize optimization

---
 .../prototype/transforms/functional/_geometry.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index 6656ecfe85b..8e90b1373e1 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -181,9 +181,11 @@ def resize_bounding_box(
 ) -> Tuple[torch.Tensor, Tuple[int, int]]:
     old_height, old_width = spatial_size
     new_height, new_width = _compute_resized_output_size(spatial_size, size=size, max_size=max_size)
-    ratios = torch.tensor((new_width / old_width, new_height / old_height), device=bounding_box.device)
+    w_ratio = new_width / old_width
+    h_ratio = new_height / old_height
+    ratios = torch.tensor([w_ratio, h_ratio, w_ratio, h_ratio], device=bounding_box.device)
     return (
-        bounding_box.reshape(-1, 2, 2).mul(ratios).to(bounding_box.dtype).reshape(bounding_box.shape),
+        bounding_box.mul(ratios).to(bounding_box.dtype),
         (new_height, new_width),
     )

@@ -367,6 +369,7 @@ def _affine_bounding_box_xyxy(
     # 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
     # and compute bounding box from 4 transformed points:
     transformed_points = transformed_points.reshape(-1, 4, 2)
+    # TODO: check if aminmax could help here
     out_bbox_mins, _ = torch.min(transformed_points, dim=1)
     out_bbox_maxs, _ = torch.max(transformed_points, dim=1)
     out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1)
@@ -388,6 +391,7 @@ def _affine_bounding_box_xyxy(
     new_points = torch.matmul(points, transposed_affine_matrix)
     tr, _ = torch.min(new_points, dim=0, keepdim=True)
     # Translate bounding boxes
+    # TODO: performance improvement by using inplace on intermediate results
     out_bboxes[:, 0::2] = out_bboxes[:, 0::2] - tr[:, 0]
     out_bboxes[:, 1::2] = out_bboxes[:, 1::2] - tr[:, 1]
     # Estimate meta-data for image with inverted=True and with center=[0,0]
@@ -947,7 +951,7 @@ def perspective_bounding_box(
     # 2) Now let's transform the points using perspective matrices
     #   x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
     #   y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
-
+    # TODO: Investigate potential optimizations by in-placing intermediate results, aminmax etc
     numer_points = torch.matmul(points, theta1.T)
     denom_points = torch.matmul(points, theta2.T)
     transformed_points = numer_points / denom_points
@@ -1063,7 +1067,7 @@ def elastic_bounding_box(
     # Question (vfdev-5): should we rely on good displacement shape and fetch image size from it
     # Or add spatial_size arg and check displacement shape
     spatial_size = displacement.shape[-3], displacement.shape[-2]
-
+    # TODO: Investigate potential optimizations by in-placing intermediate results, aminmax etc
     id_grid = _FT._create_identity_grid(list(spatial_size)).to(bounding_box.device)
     # We construct an approximation of inverse grid as inv_grid = id_grid - displacement
     # This is not an exact inverse of the grid
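Patch 1 removes the `reshape(-1, 2, 2)` round-trip in `resize_bounding_box`: since XYXY boxes store coordinates as (x1, y1, x2, y2), scaling the (N, 4) tensor with a four-element `[w_ratio, h_ratio, w_ratio, h_ratio]` vector broadcasts in a single multiply and skips two reshapes per call. A minimal standalone sketch of the same idea (the helper name and sample values are illustrative, not part of the patch):

    import torch

    def scale_boxes_xyxy(boxes: torch.Tensor, old_hw, new_hw) -> torch.Tensor:
        # One ratio per XYXY column: x1, y1, x2, y2 scale by w, h, w, h
        w_ratio = new_hw[1] / old_hw[1]
        h_ratio = new_hw[0] / old_hw[0]
        ratios = torch.tensor([w_ratio, h_ratio, w_ratio, h_ratio], device=boxes.device)
        return boxes.mul(ratios).to(boxes.dtype)

    boxes = torch.tensor([[10.0, 20.0, 50.0, 80.0]])
    print(scale_boxes_xyxy(boxes, (100, 100), (200, 50)))
    # tensor([[  5.,  40.,  25., 160.]])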
From cef0a23a94965203ac8b2ec5210b8970fd1d79dd Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Thu, 3 Nov 2022 11:42:35 +0000
Subject: [PATCH 2/6] Other (untested) optimizations on `_affine_bounding_box_xyxy` and `elastic_bounding_box`.

---
 .../transforms/functional/_geometry.py      | 26 +++++++--------
 torchvision/transforms/functional_tensor.py |  6 ++---
 2 files changed, 13 insertions(+), 19 deletions(-)

diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index 8e90b1373e1..6dcb1930cc3 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -369,9 +369,7 @@ def _affine_bounding_box_xyxy(
     # 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
     # and compute bounding box from 4 transformed points:
     transformed_points = transformed_points.reshape(-1, 4, 2)
-    # TODO: check if aminmax could help here
-    out_bbox_mins, _ = torch.min(transformed_points, dim=1)
-    out_bbox_maxs, _ = torch.max(transformed_points, dim=1)
+    out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
     out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1)

     if expand:
@@ -391,9 +389,8 @@ def _affine_bounding_box_xyxy(
     new_points = torch.matmul(points, transposed_affine_matrix)
     tr, _ = torch.min(new_points, dim=0, keepdim=True)
     # Translate bounding boxes
-    # TODO: performance improvement by using inplace on intermediate results
-    out_bboxes[:, 0::2] = out_bboxes[:, 0::2] - tr[:, 0]
-    out_bboxes[:, 1::2] = out_bboxes[:, 1::2] - tr[:, 1]
+    out_bboxes[:, 0::2].sub_(tr[:, 0])
+    out_bboxes[:, 1::2].sub_(tr[:, 1])
     # Estimate meta-data for image with inverted=True and with center=[0,0]
     affine_vector = _get_inverse_affine_matrix([0.0, 0.0], angle, translate, scale, shear)
     new_width, new_height = _FT._compute_affine_output_size(affine_vector, width, height)
@@ -951,7 +948,6 @@ def perspective_bounding_box(
     # 2) Now let's transform the points using perspective matrices
     #   x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
     #   y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
-    # TODO: Investigate potential optimizations by in-placing intermediate results, aminmax etc
     numer_points = torch.matmul(points, theta1.T)
     denom_points = torch.matmul(points, theta2.T)
     transformed_points = numer_points / denom_points
@@ -1067,23 +1063,21 @@ def elastic_bounding_box(
     # Question (vfdev-5): should we rely on good displacement shape and fetch image size from it
     # Or add spatial_size arg and check displacement shape
     spatial_size = displacement.shape[-3], displacement.shape[-2]
-    # TODO: Investigate potential optimizations by in-placing intermediate results, aminmax etc
-    id_grid = _FT._create_identity_grid(list(spatial_size)).to(bounding_box.device)
+    id_grid = _FT._create_identity_grid(list(spatial_size), bounding_box.device)
     # We construct an approximation of inverse grid as inv_grid = id_grid - displacement
     # This is not an exact inverse of the grid
-    inv_grid = id_grid - displacement
+    inv_grid = id_grid.sub_(displacement)

     # Get points from bboxes
-    points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
-    index_x = torch.floor(points[:, 0] + 0.5).to(dtype=torch.long)
-    index_y = torch.floor(points[:, 1] + 0.5).to(dtype=torch.long)
+    points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].ceil_().reshape(-1, 2)
+    index_x = points[:, 0].to(dtype=torch.long)
+    index_y = points[:, 1].to(dtype=torch.long)

     # Transform points:
     t_size = torch.tensor(spatial_size[::-1], device=displacement.device, dtype=displacement.dtype)
-    transformed_points = (inv_grid[0, index_y, index_x, :] + 1) * 0.5 * t_size - 0.5
+    transformed_points = inv_grid[0, index_y, index_x, :].add_(1).mul_(0.5 * t_size).sub_(0.5)
     transformed_points = transformed_points.reshape(-1, 4, 2)
-    out_bbox_mins, _ = torch.min(transformed_points, dim=1)
-    out_bbox_maxs, _ = torch.max(transformed_points, dim=1)
+    out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
     out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype)

     return convert_format_bounding_box(
diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index ca641faf161..9b4a1eb7342 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -940,8 +940,8 @@ def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool
     return img


-def _create_identity_grid(size: List[int]) -> Tensor:
-    hw_space = [torch.linspace((-s + 1) / s, (s - 1) / s, s) for s in size]
+def _create_identity_grid(size: List[int], device: torch.device) -> Tensor:
+    hw_space = [torch.linspace((-s + 1) / s, (s - 1) / s, s, device=device) for s in size]
     grid_y, grid_x = torch.meshgrid(hw_space, indexing="ij")
     return torch.stack([grid_x, grid_y], -1).unsqueeze(0)  # 1 x H x W x 2

@@ -959,6 +959,6 @@ def elastic_transform(
     size = list(img.shape[-2:])
     displacement = displacement.to(img.device)

-    identity_grid = _create_identity_grid(size)
+    identity_grid = _create_identity_grid(size, img.device)
     grid = identity_grid.to(img.device) + displacement
     return _apply_grid_transform(img, grid, interpolation, fill)
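Both min/max hunks in patch 2 lean on `torch.aminmax`, which returns the minimum and maximum over a dimension from a single call instead of separate `torch.min` and `torch.max` reductions; the translation hunk likewise switches to in-place `sub_` on the strided slice views, avoiding two temporaries. A small sketch checking the equivalence (shapes and values are illustrative):

    import torch

    # [N boxes, 4 corner points, x/y coords], as in _affine_bounding_box_xyxy
    transformed_points = torch.randn(8, 4, 2)

    # One fused reduction over the corner dimension
    mins, maxs = torch.aminmax(transformed_points, dim=1)
    assert torch.equal(mins, transformed_points.min(dim=1).values)
    assert torch.equal(maxs, transformed_points.max(dim=1).values)

    # Enclosing boxes back in XYXY layout, shape [N, 4]
    out_bboxes = torch.cat([mins, maxs], dim=1)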
From 8a657a98ac1ab51e67f49399f9611056534c664b Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Thu, 3 Nov 2022 11:46:32 +0000
Subject: [PATCH 3/6] fix conflict

---
 torchvision/prototype/transforms/functional/_geometry.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index 6dcb1930cc3..9814d167159 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -948,6 +948,7 @@ def perspective_bounding_box(
     # 2) Now let's transform the points using perspective matrices
     #   x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
     #   y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
+
     numer_points = torch.matmul(points, theta1.T)
     denom_points = torch.matmul(points, theta2.T)
     transformed_points = numer_points / denom_points

From 27605e241c36124a912473b2c8dd336804b26821 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Thu, 3 Nov 2022 11:54:42 +0000
Subject: [PATCH 4/6] Reverting changes on elastic

---
 .../prototype/transforms/functional/_geometry.py | 14 ++++++++------
 torchvision/transforms/functional_tensor.py      |  6 +++---
 2 files changed, 11 insertions(+), 9 deletions(-)

diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index 66f5185da77..4250130a708 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -1065,21 +1065,23 @@ def elastic_bounding_box(
     # Question (vfdev-5): should we rely on good displacement shape and fetch image size from it
     # Or add spatial_size arg and check displacement shape
     spatial_size = displacement.shape[-3], displacement.shape[-2]
-    id_grid = _FT._create_identity_grid(list(spatial_size), bounding_box.device)
+
+    id_grid = _FT._create_identity_grid(list(spatial_size)).to(bounding_box.device)
     # We construct an approximation of inverse grid as inv_grid = id_grid - displacement
     # This is not an exact inverse of the grid
     inv_grid = id_grid.sub_(displacement)

     # Get points from bboxes
-    points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].ceil_().reshape(-1, 2)
-    index_x = points[:, 0].to(dtype=torch.long)
-    index_y = points[:, 1].to(dtype=torch.long)
+    points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
+    index_x = torch.floor(points[:, 0] + 0.5).to(dtype=torch.long)
+    index_y = torch.floor(points[:, 1] + 0.5).to(dtype=torch.long)

     # Transform points:
     t_size = torch.tensor(spatial_size[::-1], device=displacement.device, dtype=displacement.dtype)
-    transformed_points = inv_grid[0, index_y, index_x, :].add_(1).mul_(0.5 * t_size).sub_(0.5)
+    transformed_points = (inv_grid[0, index_y, index_x, :] + 1) * 0.5 * t_size - 0.5
     transformed_points = transformed_points.reshape(-1, 4, 2)
-    out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
+    out_bbox_mins, _ = torch.min(transformed_points, dim=1)
+    out_bbox_maxs, _ = torch.max(transformed_points, dim=1)
     out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype)

     return convert_format_bounding_box(
diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index 9b4a1eb7342..ca641faf161 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -940,8 +940,8 @@ def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool
     return img


-def _create_identity_grid(size: List[int], device: torch.device) -> Tensor:
-    hw_space = [torch.linspace((-s + 1) / s, (s - 1) / s, s, device=device) for s in size]
+def _create_identity_grid(size: List[int]) -> Tensor:
+    hw_space = [torch.linspace((-s + 1) / s, (s - 1) / s, s) for s in size]
     grid_y, grid_x = torch.meshgrid(hw_space, indexing="ij")
     return torch.stack([grid_x, grid_y], -1).unsqueeze(0)  # 1 x H x W x 2

@@ -959,6 +959,6 @@ def elastic_transform(
     size = list(img.shape[-2:])
     displacement = displacement.to(img.device)

-    identity_grid = _create_identity_grid(size, img.device)
+    identity_grid = _create_identity_grid(size)
     grid = identity_grid.to(img.device) + displacement
     return _apply_grid_transform(img, grid, interpolation, fill)
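Patch 4 walks back the elastic changes from patch 2, restoring the `floor(x + 0.5)` indexing, the out-of-place point transform, and the original `_create_identity_grid` signature. One concrete reason the in-place variant was unsafe to keep (inferred here; the commit message does not spell it out): `ceil_()` always rounds fractional coordinates up, while `floor(x + 0.5)` rounds to nearest, so the two versions pick different grid indices for most inputs. A quick check with made-up values:

    import torch

    x = torch.tensor([10.2, 10.5, 10.8])
    print(torch.floor(x + 0.5))  # tensor([10., 11., 11.])  round to nearest
    print(torch.ceil(x))         # tensor([11., 11., 11.])  always rounds up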
From a2b96817497f25abc04061046392c7c6ecf75530 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Thu, 3 Nov 2022 11:55:46 +0000
Subject: [PATCH 5/6] revert one more change

---
 torchvision/prototype/transforms/functional/_geometry.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index 4250130a708..a20d934520e 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -1069,7 +1069,7 @@ def elastic_bounding_box(
     id_grid = _FT._create_identity_grid(list(spatial_size)).to(bounding_box.device)
     # We construct an approximation of inverse grid as inv_grid = id_grid - displacement
     # This is not an exact inverse of the grid
-    inv_grid = id_grid.sub_(displacement)
+    inv_grid = id_grid - displacement

     # Get points from bboxes
     points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].reshape(-1, 2)
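Patch 5 restores the out-of-place `inv_grid = id_grid - displacement`. A plausible hazard with the in-place form (again inferred, not stated in the commit): `Tensor.to` returns the very same tensor when device and dtype already match, so calling `sub_` on what a helper returned can silently mutate the helper's buffer, which breaks if that buffer is ever shared or cached. A sketch of the aliasing:

    import torch

    grid = torch.zeros(1, 2, 2, 2)
    alias = grid.to(grid.device)        # no copy: same storage as grid
    alias.sub_(torch.ones_like(alias))  # in-place subtract through the alias
    print(grid[0, 0, 0])                # tensor([-1., -1.]) -- grid mutated too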
From fd7f0d59c7dc33f8f1163544748d7c3e8ec3bc4a Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Thu, 3 Nov 2022 12:33:57 +0000
Subject: [PATCH 6/6] Further improvement

---
 torchvision/prototype/transforms/functional/_geometry.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index a20d934520e..40fa904ade2 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -389,8 +389,7 @@ def _affine_bounding_box_xyxy(
     new_points = torch.matmul(points, transposed_affine_matrix)
     tr, _ = torch.min(new_points, dim=0, keepdim=True)
     # Translate bounding boxes
-    out_bboxes[:, 0::2].sub_(tr[:, 0])
-    out_bboxes[:, 1::2].sub_(tr[:, 1])
+    out_bboxes.sub_(tr.repeat((1, 2)))
     # Estimate meta-data for image with inverted=True and with center=[0,0]
     affine_vector = _get_inverse_affine_matrix([0.0, 0.0], angle, translate, scale, shear)
     new_width, new_height = _FT._compute_affine_output_size(affine_vector, width, height)
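The final patch collapses the two strided slice subtractions in `_affine_bounding_box_xyxy` into one contiguous in-place op: `tr` is the (x, y) translation of shape [1, 2], and `tr.repeat((1, 2))` tiles it to `[tr_x, tr_y, tr_x, tr_y]`, which lines up with the XYXY columns of `out_bboxes` and broadcasts over all rows. A worked example with illustrative numbers:

    import torch

    out_bboxes = torch.tensor([[10.0, 20.0, 30.0, 40.0],
                               [ 5.0,  5.0, 15.0, 25.0]])
    tr = torch.tensor([[2.0, 3.0]])     # x and y translation, shape [1, 2]
    out_bboxes.sub_(tr.repeat((1, 2)))  # subtract [2, 3, 2, 3] from every row
    print(out_bboxes)
    # tensor([[ 8., 17., 28., 37.],
    #         [ 3.,  2., 13., 22.]])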