diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py
index 5518322080f..12fa5288abc 100644
--- a/torchvision/prototype/transforms/functional/_color.py
+++ b/torchvision/prototype/transforms/functional/_color.py
@@ -1,4 +1,5 @@
 import torch
+from torch.nn.functional import conv2d
 from torchvision.prototype import features
 from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
 
@@ -111,6 +112,8 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float)
     if image.numel() == 0 or height <= 2 or width <= 2:
         return image
 
+    bound = _FT._max_value(image.dtype)
+    fp = image.is_floating_point()
     shape = image.shape
 
     if image.ndim > 4:
@@ -119,7 +122,30 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float)
     else:
         needs_unsquash = False
 
-    output = _blend(image, _FT._blurred_degenerate_image(image), sharpness_factor)
+    # The following is a normalized 3x3 kernel with 1s on the edges and a 5 in the middle.
+    kernel_dtype = image.dtype if fp else torch.float32
+    a, b = 1.0 / 13.0, 5.0 / 13.0
+    kernel = torch.tensor([[a, a, a], [a, b, a], [a, a, a]], dtype=kernel_dtype, device=image.device)
+    kernel = kernel.expand(num_channels, 1, 3, 3)
+
+    # We copy and cast at the same time to avoid modifying the original data.
+    output = image.to(dtype=kernel_dtype, copy=True)
+    blurred_degenerate = conv2d(output, kernel, groups=num_channels)
+    if not fp:
+        # It is better to round before casting.
+        blurred_degenerate = blurred_degenerate.round_()
+
+    # Create a view into output that points at the same underlying data. We do this to avoid indexing twice.
+    view = output[..., 1:-1, 1:-1]
+
+    # We speed up blending by minimizing flops and operating in-place. The two blend forms are mathematically
+    # equivalent: x + (1 - r) * (y - x) = x + (1 - r) * y - (1 - r) * x = x * r + y * (1 - r)
+    view.add_(blurred_degenerate.sub_(view), alpha=(1.0 - sharpness_factor))
+
+    # The actual data of output has been modified by the above. We only need to clamp and cast now.
+    output = output.clamp_(0, bound)
+    if not fp:
+        output = output.to(image.dtype)
 
     if needs_unsquash:
         output = output.reshape(shape)
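
Note: the patch replaces the generic _blend plus _FT._blurred_degenerate_image path with a fused, in-place version. Below is a minimal standalone sketch of the in-place blend identity it relies on; the names r, x, and y are illustrative and not part of the patch. The textbook blend x * r + y * (1 - r) is rewritten as x + (1 - r) * (y - x), which saves one multiplication per element and maps directly onto Tensor.add_ with an alpha argument.

import torch

# Illustrative sketch: check that the in-place blend used in the patch,
# x.add_(y - x, alpha=1 - r), matches the textbook blend x * r + y * (1 - r).
r = 1.7                   # stands in for sharpness_factor
x = torch.rand(3, 5, 5)   # stands in for the original image region
y = torch.rand(3, 5, 5)   # stands in for the blurred degenerate image

expected = x * r + y * (1.0 - r)

blended = x.clone()
blended.add_(y - x, alpha=(1.0 - r))  # one fewer multiply, done in-place

assert torch.allclose(blended, expected, atol=1e-6)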