diff --git a/configs/model/fre.yaml b/configs/model/fre.yaml
new file mode 100755
index 0000000000..deb92a117d
--- /dev/null
+++ b/configs/model/fre.yaml
@@ -0,0 +1,12 @@
+model:
+  class_path: anomalib.models.Fre
+  init_args:
+    backbone: resnet50
+    layer: layer3
+    pre_trained: true
+    pooling_kernel_size: 2
+    input_dim: 65536
+    latent_dim: 220
+
+trainer:
+  max_epochs: 220
diff --git a/src/anomalib/models/__init__.py b/src/anomalib/models/__init__.py
index 722cd1dfe5..b4bb36a875 100644
--- a/src/anomalib/models/__init__.py
+++ b/src/anomalib/models/__init__.py
@@ -1,9 +1,8 @@
 """Load Anomaly Model."""
 
-# Copyright (C) 2022-2023 Intel Corporation
+# Copyright (C) 2022-2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
-
 import logging
 from importlib import import_module
 
@@ -23,6 +22,7 @@
     Dsr,
     EfficientAd,
     Fastflow,
+    Fre,
     Ganomaly,
     Padim,
     Patchcore,
@@ -49,6 +49,7 @@ class UnknownModelError(ModuleNotFoundError):
     "Dsr",
     "EfficientAd",
     "Fastflow",
+    "Fre",
     "Ganomaly",
     "Padim",
     "Patchcore",
diff --git a/src/anomalib/models/image/__init__.py b/src/anomalib/models/image/__init__.py
index 8478747f01..f3a5435038 100644
--- a/src/anomalib/models/image/__init__.py
+++ b/src/anomalib/models/image/__init__.py
@@ -12,6 +12,7 @@
 from .dsr import Dsr
 from .efficient_ad import EfficientAd
 from .fastflow import Fastflow
+from .fre import Fre
 from .ganomaly import Ganomaly
 from .padim import Padim
 from .patchcore import Patchcore
@@ -31,6 +32,7 @@
     "Dsr",
     "EfficientAd",
     "Fastflow",
+    "Fre",
     "Ganomaly",
     "Padim",
     "Patchcore",
diff --git a/src/anomalib/models/image/fre/README.md b/src/anomalib/models/image/fre/README.md
new file mode 100755
index 0000000000..ed4f73cecc
--- /dev/null
+++ b/src/anomalib/models/image/fre/README.md
@@ -0,0 +1,43 @@
+# FRE: A Fast Method For Anomaly Detection And Segmentation
+
+This is the implementation of the [FRE](https://papers.bmvc2023.org/0614.pdf) paper.
+
+Model Type: Segmentation
+
+## Description
+
+FRE is a fast anomaly detection algorithm consisting of a deep feature extraction stage followed by an anomaly classification stage comprising a shallow linear autoencoder.
+
+### Feature Extraction
+
+Features are extracted by feeding the images through a ResNet50 backbone pre-trained on ImageNet. The output of an intermediate layer of the network (`layer3` by default) is used to obtain a semantic feature vector with a fixed length of 65536.
+
+### Anomaly Detection
+
+In the anomaly classification stage, a shallow linear autoencoder is trained on the features of the chosen layer. The feature-reconstruction error (the norm of the error between the autoencoder's reconstructed output and the original high-dimensional feature) is used as the anomaly score. The anomaly map is generated by reshaping the error tensor and resizing it to match the input image dimensions.
+
+## Usage
+
+`anomalib train --model Fre --data anomalib.data.MVTec`
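+
+A sketch of the equivalent Python API usage, assuming the standard anomalib `Engine` workflow:
+
+```python
+from anomalib.data import MVTec
+from anomalib.engine import Engine
+from anomalib.models import Fre
+
+# Train FRE on MVTec AD and evaluate it on the test split.
+datamodule = MVTec()
+model = Fre()
+engine = Engine()
+engine.fit(model=model, datamodule=datamodule)
+engine.test(model=model, datamodule=datamodule)
+```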
+
+## Benchmark
+
+All results gathered with seed `42`.
+
+## [MVTec AD Dataset](https://www.mvtec.com/company/research/datasets/mvtec-ad)
+
+> Note: Metrics for ResNet-50 were computed with a pooling kernel size of 2, while a kernel size of 4 was used for Wide ResNet-50.
+
+### Image-Level AUC
+
+|                |  Avg  | Carpet | Grid  | Leather | Tile  | Wood  | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill  | Screw | Toothbrush | Transistor | Zipper |
+| -------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: |
+| ResNet-50      | 0.933 | 0.951  | 0.817 |    1    | 0.991 | 0.988 | 0.996  | 0.931 |  0.887  |  0.974   |   0.902   | 0.951 | 0.798 |   0.908    |   0.943    | 0.957  |
+| Wide ResNet-50 | 0.947 | 0.928  | 0.909 |    1    | 0.991 | 0.950 | 0.996  | 0.944 |  0.908  |  0.973   |   0.933   | 0.971 | 0.827 |   0.950    |   0.963    | 0.968  |
+
+### Image F1 Score
+
+|                |  Avg  | Carpet | Grid  | Leather | Tile  | Wood  | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill  | Screw | Toothbrush | Transistor | Zipper |
+| -------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: |
+| ResNet-50      | 0.935 | 0.928  | 0.887 |  0.995  | 0.982 | 0.975 | 0.984  | 0.906 |  0.943  |  0.943   |   0.914   | 0.943 | 0.872 |   0.912    |   0.880    | 0.958  |
+| Wide ResNet-50 | 0.941 | 0.919  | 0.906 |  0.989  | 0.982 | 0.948 | 0.984  | 0.911 |  0.951  |  0.950   |   0.934   | 0.960 | 0.885 |   0.931    |   0.895    | 0.967  |
diff --git a/src/anomalib/models/image/fre/__init__.py b/src/anomalib/models/image/fre/__init__.py
new file mode 100755
index 0000000000..7de3b5b399
--- /dev/null
+++ b/src/anomalib/models/image/fre/__init__.py
@@ -0,0 +1,8 @@
+"""FRE: Feature-Reconstruction Error model."""
+
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from .lightning_model import Fre
+
+__all__ = ["Fre"]
diff --git a/src/anomalib/models/image/fre/lightning_model.py b/src/anomalib/models/image/fre/lightning_model.py
new file mode 100755
index 0000000000..355844f0f7
--- /dev/null
+++ b/src/anomalib/models/image/fre/lightning_model.py
@@ -0,0 +1,121 @@
+"""FRE: Feature-Reconstruction Error.
+
+https://papers.bmvc2023.org/0614.pdf
+"""
+
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+from typing import Any
+
+import torch
+from lightning.pytorch.utilities.types import STEP_OUTPUT
+from torch import optim
+
+from anomalib import LearningType
+from anomalib.models.components import AnomalyModule
+
+from .torch_model import FREModel
+
+logger = logging.getLogger(__name__)
+
+
+class Fre(AnomalyModule):
+    """FRE: Feature-Reconstruction Error using a Tied AutoEncoder.
+
+    Args:
+        backbone (str): Backbone CNN network.
+            Defaults to ``resnet50``.
+        layer (str): Layer from which to extract features in the backbone CNN.
+            Defaults to ``layer3``.
+        pre_trained (bool, optional): Whether to use a pre-trained backbone.
+            Defaults to ``True``.
+        pooling_kernel_size (int, optional): Kernel size used to pool the features extracted from the CNN.
+            Defaults to ``2``.
+        input_dim (int, optional): Dimension of the feature at the output of the specified layer.
+            Defaults to ``65536``.
+        latent_dim (int, optional): Reduced feature size after dimensionality reduction
+            via the shallow linear autoencoder.
+            Defaults to ``220``.
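+
+    Example:
+        A sketch of configuring FRE for a different backbone. ``input_dim`` must match
+        the flattened feature size: assuming 256x256 inputs, ``wide_resnet50_2`` yields
+        1024x16x16 features at ``layer3``, and a pooling kernel size of 4 reduces them
+        to 1024 * 4 * 4 = 16384:
+
+        >>> from anomalib.models import Fre
+        >>> model = Fre(
+        ...     backbone="wide_resnet50_2",
+        ...     layer="layer3",
+        ...     pooling_kernel_size=4,
+        ...     input_dim=16384,
+        ... )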
+ """ + + def __init__( + self, + backbone: str = "resnet50", + layer: str = "layer3", + pre_trained: bool = True, + pooling_kernel_size: int = 2, + input_dim: int = 65536, + latent_dim: int = 220, + ) -> None: + super().__init__() + + self.model: FREModel = FREModel( + backbone=backbone, + pre_trained=pre_trained, + layer=layer, + pooling_kernel_size=pooling_kernel_size, + input_dim=input_dim, + latent_dim=latent_dim, + ) + self.loss_fn = torch.nn.MSELoss() + + def configure_optimizers(self) -> torch.optim.Optimizer: + """Configure optimizers. + + Returns: + Optimizer: Adam optimizer + """ + return optim.Adam(params=self.model.fre_model.parameters(), lr=1e-3) + + def training_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> STEP_OUTPUT: + """Perform the training step of FRE. + + For each batch, features are extracted from the CNN. + + Args: + batch (dict[str, str | torch.Tensor]): Input batch + args: Arguments. + kwargs: Keyword arguments. + + Returns: + Deep CNN features. + """ + del args, kwargs # These variables are not used. + features_in, features_out, _ = self.model.get_features(batch["image"]) + loss = self.loss_fn(features_in, features_out) + self.log("train_loss", loss.item(), on_epoch=True, prog_bar=True, logger=True) + return {"loss": loss} + + def validation_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> STEP_OUTPUT: + """Perform the validation step of FRE. + + Similar to the training step, features are extracted from the CNN for each batch. + + Args: + batch (dict[str, str | torch.Tensor]): Input batch + args: Arguments. + kwargs: Keyword arguments. + + Returns: + Dictionary containing FRE anomaly scores and anomaly maps. + """ + del args, kwargs # These variables are not used. + + batch["pred_scores"], batch["anomaly_maps"] = self.model(batch["image"]) + return batch + + @property + def trainer_arguments(self) -> dict[str, Any]: + """Return FRE-specific trainer arguments.""" + return {"gradient_clip_val": 0, "max_epochs": 220, "num_sanity_val_steps": 0} + + @property + def learning_type(self) -> LearningType: + """Return the learning type of the model. + + Returns: + LearningType: Learning type of the model. + """ + return LearningType.ONE_CLASS diff --git a/src/anomalib/models/image/fre/torch_model.py b/src/anomalib/models/image/fre/torch_model.py new file mode 100755 index 0000000000..534521dd01 --- /dev/null +++ b/src/anomalib/models/image/fre/torch_model.py @@ -0,0 +1,114 @@ +"""PyTorch model for FRE model implementation.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import torch +from torch import nn +from torch.nn import functional as F # noqa: N812 + +from anomalib.models.components import TimmFeatureExtractor + + +class TiedAE(nn.Module): + """Model for the Tied AutoEncoder used for FRE calculation. + + Args: + input_dim (int): Dimension of input to the tied auto-encoder. + latent_dim (int): Dimension of the reduced-dimension latent space of the tied auto-encoder. + """ + + def __init__(self, input_dim: int, latent_dim: int) -> None: + super().__init__() + self.input_dim = input_dim + self.latent_dim = latent_dim + self.weight = nn.Parameter(torch.empty(latent_dim, input_dim)) + torch.nn.init.xavier_uniform_(self.weight) + self.decoder_bias = nn.Parameter(torch.zeros(input_dim)) + self.encoder_bias = nn.Parameter(torch.zeros(latent_dim)) + + def forward(self, features: torch.Tensor) -> torch.Tensor: + """Run input features through the autoencoder. 
+    """
+
+    def __init__(self, input_dim: int, latent_dim: int) -> None:
+        super().__init__()
+        self.input_dim = input_dim
+        self.latent_dim = latent_dim
+        self.weight = nn.Parameter(torch.empty(latent_dim, input_dim))
+        torch.nn.init.xavier_uniform_(self.weight)
+        self.decoder_bias = nn.Parameter(torch.zeros(input_dim))
+        self.encoder_bias = nn.Parameter(torch.zeros(latent_dim))
+
+    def forward(self, features: torch.Tensor) -> torch.Tensor:
+        """Run input features through the autoencoder.
+
+        Args:
+            features (torch.Tensor): Feature batch.
+
+        Returns:
+            torch.Tensor: Reconstructed features.
+        """
+        encoded = F.linear(features, self.weight, self.encoder_bias)
+        return F.linear(encoded, self.weight.t(), self.decoder_bias)
+
+
+class FREModel(nn.Module):
+    """Model for the FRE algorithm.
+
+    Args:
+        backbone (str): Pre-trained model backbone.
+        layer (str): Layer from which to extract features.
+        pre_trained (bool, optional): Whether to use a pre-trained backbone.
+            Defaults to ``True``.
+        pooling_kernel_size (int, optional): Kernel size used to pool the features extracted from the CNN.
+            Defaults to ``4``.
+        input_dim (int, optional): Dimension of the feature at the output of the specified layer.
+            Defaults to ``65536``.
+        latent_dim (int, optional): Reduced feature size after dimensionality reduction
+            via the shallow linear autoencoder.
+            Defaults to ``220``.
+    """
+
+    def __init__(
+        self,
+        backbone: str,
+        layer: str,
+        input_dim: int = 65536,
+        latent_dim: int = 220,
+        pre_trained: bool = True,
+        pooling_kernel_size: int = 4,
+    ) -> None:
+        super().__init__()
+        self.backbone = backbone
+        self.pooling_kernel_size = pooling_kernel_size
+        self.fre_model = TiedAE(input_dim, latent_dim)
+        self.layer = layer
+        self.feature_extractor = TimmFeatureExtractor(
+            backbone=self.backbone,
+            pre_trained=pre_trained,
+            layers=[layer],
+        ).eval()
+
+    def get_features(self, batch: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Extract features from the pre-trained network and reconstruct them with the tied autoencoder.
+
+        Args:
+            batch (torch.Tensor): Image batch.
+
+        Returns:
+            tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Flattened input features, their
+                reconstruction, and the shape of the pooled feature tensor.
+        """
+        self.feature_extractor.eval()
+        features_in = self.feature_extractor(batch)[self.layer]
+        batch_size = len(features_in)
+        if self.pooling_kernel_size > 1:
+            features_in = F.avg_pool2d(input=features_in, kernel_size=self.pooling_kernel_size)
+        feature_shapes = features_in.shape
+        features_in = features_in.view(batch_size, -1).detach()  # flatten; keep the backbone frozen
+        features_out = self.fre_model(features_in)
+        return features_in, features_out, feature_shapes
+
+    def forward(self, batch: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        """Compute anomaly scores and anomaly maps from input images.
+
+        Args:
+            batch (torch.Tensor): Input images.
+
+        Returns:
+            tuple[torch.Tensor, torch.Tensor]: Anomaly scores and anomaly maps.
+        """
+        features_in, features_out, feature_shapes = self.get_features(batch)
+        fre = torch.square(features_in - features_out).reshape(feature_shapes)
+        anomaly_map = torch.sum(fre, 1)  # NxCxHxW --> NxHxW
+        score = torch.sum(anomaly_map, (1, 2))  # NxHxW --> N
+        anomaly_map = torch.unsqueeze(anomaly_map, 1)
+        anomaly_map = F.interpolate(anomaly_map, size=batch.shape[-2:], mode="bilinear", align_corners=False)
+        return score, anomaly_map
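+
+
+# A minimal usage sketch of FREModel outside the Lightning wrapper. It assumes
+# 256x256 inputs, for which resnet50/layer3 yields 1024x16x16 features; a pooling
+# kernel size of 2 then flattens them to the default input_dim of 65536:
+#
+# >>> model = FREModel(backbone="resnet50", layer="layer3", pooling_kernel_size=2).eval()
+# >>> with torch.no_grad():
+# ...     scores, anomaly_maps = model(torch.randn(2, 3, 256, 256))
+# >>> scores.shape, anomaly_maps.shape
+# (torch.Size([2]), torch.Size([2, 1, 256, 256]))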