diff --git a/piqa/lpips.py b/piqa/lpips.py
index 3cbbe27..522a9b4 100644
--- a/piqa/lpips.py
+++ b/piqa/lpips.py
@@ -11,8 +11,6 @@
 .. [Deng2009] ImageNet: A large-scale hierarchical image database (Deng et al, 2009)
 """

-import inspect
-import os
 import torch
 import torch.nn as nn
 import torchvision.models as models
@@ -119,6 +117,7 @@ class LPIPS(nn.Module):
         network: Specifies the perceptual network :math:`\mathcal{F}` to use:
             `'alex'` | `'squeeze'` | `'vgg'`.
         scaling: Whether the input and target need to be scaled w.r.t. [Deng2009]_.
+        epsilon: A numerical stability term.
         dropout: Whether dropout is used or not.
         pretrained: Whether the official weights :math:`w_l` are used or not.
         eval: Whether to initialize the object in evaluation mode or not.
@@ -144,6 +143,7 @@ def __init__(
         self,
         network: str = 'alex',
         scaling: bool = True,
+        epsilon: float = 1e-10,
         dropout: bool = False,
         pretrained: bool = True,
         eval: bool = True,
@@ -155,6 +155,7 @@
         self.scaling = scaling
         self.register_buffer('shift', SHIFT.reshape(1, -1, 1, 1))
         self.register_buffer('scale', SCALE.reshape(1, -1, 1, 1))
+        self.epsilon = epsilon

         # Perception layers
         if network == 'alex':  # AlexNet
@@ -210,8 +211,8 @@ def forward(self, input: Tensor, target: Tensor) -> Tensor:
         residuals = []

         for lin, fx, fy in zip(self.lins, self.net(input), self.net(target)):
-            fx = fx / l2_norm(fx, dims=[1], keepdim=True)
-            fy = fy / l2_norm(fy, dims=[1], keepdim=True)
+            fx = fx / (l2_norm(fx, dims=[1], keepdim=True) + self.epsilon)
+            fy = fy / (l2_norm(fy, dims=[1], keepdim=True) + self.epsilon)

             mse = ((fx - fy) ** 2).mean(dim=(-1, -2), keepdim=True)
             residuals.append(lin(mse).flatten())
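
The patch guards the channel-wise unit normalization in `forward` against division by zero: if a feature map's channel vector is all zeros at some spatial position (for instance, a fully inactive ReLU layer), `fx / l2_norm(fx, ...)` evaluates 0 / 0 and produces NaNs that propagate through the loss and its gradients. Below is a minimal sketch of that failure mode, not part of the patch; the local `l2_norm` is a stand-in assumed to mirror piqa's helper of the same name.

import torch

def l2_norm(x: torch.Tensor, dims, keepdim: bool = False) -> torch.Tensor:
    # Stand-in for piqa's l2_norm: L2 norm over the given dimensions.
    return (x ** 2).sum(dim=dims, keepdim=keepdim).sqrt()

epsilon = 1e-10

# A feature map whose channel vector is zero at every spatial position.
fx = torch.zeros(1, 3, 4, 4)

bad = fx / l2_norm(fx, dims=[1], keepdim=True)               # 0 / 0 -> NaN
good = fx / (l2_norm(fx, dims=[1], keepdim=True) + epsilon)  # 0 / eps -> 0

print(bad.isnan().any().item())   # True
print(good.isnan().any().item())  # False

Adding epsilon to the denominator rather than clamping keeps the normalization differentiable everywhere, and at 1e-10 it is negligible for any non-degenerate feature vector.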