πŸš€ Add Auto-Encoder based FRE #2025

Merged: 24 commits, merged May 20, 2024

Commits (24)

eee7ab5  Initial commit of FRE (nahuja-intel, Apr 26, 2024)
82d6c91  Merge branch 'nilesh/dfm' into main (nahuja-intel, Apr 26, 2024)
3b89c86  Fixing code-quality check failures (nahuja-intel, Apr 26, 2024)
dcf1e66  More pre-commit fixes (nahuja-intel, Apr 26, 2024)
f254e87  Address ruff issues (samet-akcay, Apr 26, 2024)
df47b91  Merge pull request #1 from samet-akcay/fix/address-pre-commit-in-fre (nahuja-intel, Apr 26, 2024)
9554b28  Update configs/model/fre.yaml (samet-akcay, Apr 29, 2024)
64ae814  Update configs/model/fre.yaml (samet-akcay, Apr 29, 2024)
7514ce8  Update src/anomalib/models/image/fre/lightning_model.py (samet-akcay, Apr 29, 2024)
442269d  Update src/anomalib/models/image/fre/lightning_model.py (samet-akcay, Apr 29, 2024)
6cb31f0  Update src/anomalib/models/image/fre/lightning_model.py (samet-akcay, Apr 29, 2024)
f260abc  Update src/anomalib/models/image/fre/torch_model.py (samet-akcay, Apr 29, 2024)
4149092  Update src/anomalib/models/image/fre/lightning_model.py (samet-akcay, Apr 29, 2024)
8c3b5c2  Update src/anomalib/models/image/fre/torch_model.py (samet-akcay, Apr 29, 2024)
850142e  Update src/anomalib/models/image/fre/torch_model.py (samet-akcay, Apr 29, 2024)
a93f3dd  Update src/anomalib/models/image/fre/torch_model.py (samet-akcay, Apr 29, 2024)
e31c469  Updating README with results for ResNet50 and Wide ResNet50 (nahuja-intel, May 17, 2024)
131fa89  Fixing max_epochs to 220 (nahuja-intel, May 17, 2024)
82fda1f  Update configs/model/fre.yaml (nahuja-intel, May 17, 2024)
3e093c5  Update configs/model/fre.yaml (nahuja-intel, May 17, 2024)
2a81ee2  Merge branch 'main' into main (samet-akcay, May 19, 2024)
4abe656  Update src/anomalib/models/image/fre/README.md (samet-akcay, May 20, 2024)
a52b195  Update src/anomalib/models/image/fre/README.md (samet-akcay, May 20, 2024)
32f2083  Merge branch 'main' into main (samet-akcay, May 20, 2024)

12 changes: 12 additions & 0 deletions configs/model/fre.yaml
@@ -0,0 +1,12 @@
model:
  class_path: anomalib.models.Fre
  init_args:
    backbone: resnet50
    layer: layer3
    pre_trained: true
    pooling_kernel_size: 2
    input_dim: 65536
    latent_dim: 220

trainer:
  max_epochs: 220
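
For reference, a minimal Python sketch equivalent to the YAML above, using the `Fre` class and arguments added in this PR (shown only as an illustration of the config, not as a replacement for the CLI):

```python
from anomalib.models import Fre

# Mirrors configs/model/fre.yaml: ResNet-50 layer3 features, pooled with kernel
# size 2, flattened to 65536 dimensions and compressed to a 220-dim latent space.
model = Fre(
    backbone="resnet50",
    layer="layer3",
    pre_trained=True,
    pooling_kernel_size=2,
    input_dim=65536,
    latent_dim=220,
)
```
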
5 changes: 3 additions & 2 deletions src/anomalib/models/__init__.py
@@ -1,9 +1,8 @@
"""Load Anomaly Model."""

-# Copyright (C) 2022-2023 Intel Corporation
+# Copyright (C) 2022-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


import logging
from importlib import import_module

@@ -23,6 +22,7 @@
    Dsr,
    EfficientAd,
    Fastflow,
+   Fre,
    Ganomaly,
    Padim,
    Patchcore,
@@ -49,6 +49,7 @@ class UnknownModelError(ModuleNotFoundError):
"Dsr",
"EfficientAd",
"Fastflow",
"Fre",
"Ganomaly",
"Padim",
"Patchcore",
2 changes: 2 additions & 0 deletions src/anomalib/models/image/__init__.py
@@ -12,6 +12,7 @@
from .dsr import Dsr
from .efficient_ad import EfficientAd
from .fastflow import Fastflow
+from .fre import Fre
from .ganomaly import Ganomaly
from .padim import Padim
from .patchcore import Patchcore
@@ -31,6 +32,7 @@
"Dsr",
"EfficientAd",
"Fastflow",
"Fre",
"Ganomaly",
"Padim",
"Patchcore",
43 changes: 43 additions & 0 deletions src/anomalib/models/image/fre/README.md
@@ -0,0 +1,43 @@
# FRE: A Fast Method For Anomaly Detection And Segmentation

This is the implementation of the [FRE](https://papers.bmvc2023.org/0614.pdf) paper.

Model Type: Segmentation

## Description

A fast anomaly detection algorithm consisting of a deep feature-extraction stage followed by an anomaly-classification stage built around a shallow linear autoencoder.

### Feature Extraction

Features are extracted by feeding images through a ResNet50 backbone pre-trained on ImageNet. The output of an intermediate layer (layer3 by default) is average-pooled and flattened to obtain a semantic feature vector of fixed length 65536.
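
As a rough shape check (not part of the PR, and assuming Anomalib's default 256x256 MVTec inputs): layer3 of ResNet-50 produces a 1024-channel map at stride 16, so pooling with kernel size 2 and flattening gives the 65536-dimensional vector used as `input_dim`.

```python
import torch
from torch.nn import functional as F

# Dummy layer3 output for a batch of 8 images at 256x256 resolution:
# 1024 channels at stride 16 -> a 16x16 spatial grid.
features = torch.randn(8, 1024, 16, 16)
pooled = F.avg_pool2d(features, kernel_size=2)  # -> (8, 1024, 8, 8)
flat = pooled.view(pooled.shape[0], -1)         # -> (8, 65536), i.e. input_dim
print(flat.shape)
```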

### Anomaly Detection

In the anomaly classification stage, a shallow linear autoencoder is trained on the features of the chosen layer. The feature-reconstruction error, i.e. the norm of the difference between the original high-dimensional feature and its reconstruction by the autoencoder, is used as the anomaly score. The anomaly map is generated by reshaping and resizing the error tensor to match the input image dimensions.
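
Condensed into code, the scoring step looks roughly like the following sketch (it mirrors `FREModel.forward` in `torch_model.py` further down this PR):

```python
import torch
from torch.nn import functional as F


def fre_score(features_in, features_out, feature_shapes, image_size):
    """Image-level scores and anomaly maps from the feature-reconstruction error."""
    fre = torch.square(features_in - features_out).reshape(feature_shapes)  # N x C x H x W
    anomaly_map = fre.sum(dim=1)         # per-pixel error: N x H x W
    score = anomaly_map.sum(dim=(1, 2))  # per-image score: N
    anomaly_map = F.interpolate(
        anomaly_map.unsqueeze(1), size=image_size, mode="bilinear", align_corners=False
    )
    return score, anomaly_map
```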

## Usage

`anomalib train --model Fre --data anomalib.data.MVTec`
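
The same run can also be launched from Python; a short sketch assuming the standard Anomalib v1 `Engine` and `MVTec` interfaces:

```python
from anomalib.data import MVTec
from anomalib.engine import Engine
from anomalib.models import Fre

datamodule = MVTec()  # MVTec AD datamodule with default settings
model = Fre()         # defaults match configs/model/fre.yaml
engine = Engine()

engine.fit(model=model, datamodule=datamodule)
engine.test(model=model, datamodule=datamodule)
```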

## Benchmark

All results gathered with seed `42`.

## [MVTec AD Dataset](https://www.mvtec.com/company/research/datasets/mvtec-ad)

> Note: Metrics for ResNet-50 were calculated with a pooling kernel size of 2, while a kernel size of 4 was used for Wide ResNet-50.

### Image-Level AUC

| Model | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper |
| -------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: |
| ResNet-50 | 0.933 | 0.951 | 0.817 | 1 | 0.991 | 0.988 | 0.996 | 0.931 | 0.887 | 0.974 | 0.902 | 0.951 | 0.798 | 0.908 | 0.943 | 0.957 |
| Wide ResNet-50 | 0.947 | 0.928 | 0.909 | 1 | 0.991 | 0.950 | 0.996 | 0.944 | 0.908 | 0.973 | 0.933 | 0.971 | 0.827 | 0.950 | 0.963 | 0.968 |

### Image F1 Score

| Model | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper |
| -------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: |
| ResNet-50 | 0.935 | 0.928 | 0.887 | 0.995 | 0.982 | 0.975 | 0.984 | 0.906 | 0.943 | 0.943 | 0.914 | 0.943 | 0.872 | 0.912 | 0.880 | 0.958 |
| Wide ResNet-50 | 0.941 | 0.919 | 0.906 | 0.989 | 0.982 | 0.948 | 0.984 | 0.911 | 0.951 | 0.950 | 0.934 | 0.960 | 0.885 | 0.931 | 0.895 | 0.967 |
8 changes: 8 additions & 0 deletions src/anomalib/models/image/fre/__init__.py
@@ -0,0 +1,8 @@
"""Deep Feature Extraction (DFM) model."""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from .lightning_model import Fre

__all__ = ["Fre"]
121 changes: 121 additions & 0 deletions src/anomalib/models/image/fre/lightning_model.py
@@ -0,0 +1,121 @@
"""FRE: Feature-Reconstruction Error.

https://papers.bmvc2023.org/0614.pdf
"""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import logging
from typing import Any

import torch
from lightning.pytorch.utilities.types import STEP_OUTPUT
from torch import optim

from anomalib import LearningType
from anomalib.models.components import AnomalyModule

from .torch_model import FREModel

logger = logging.getLogger(__name__)


class Fre(AnomalyModule):
"""FRE: Feature-reconstruction error using Tied AutoEncoder.

Args:
backbone (str): Backbone CNN network
Defaults to ``resnet50``.
layer (str): Layer to extract features from the backbone CNN
Defaults to ``layer3``.
pre_trained (bool, optional): Boolean to check whether to use a pre_trained backbone.
Defaults to ``True``.
pooling_kernel_size (int, optional): Kernel size to pool features extracted from the CNN.
Defaults to ``2``.
input_dim (int, optional): Dimension of feature at output of layer specified in layer.
Defaults to ``65536``.
latent_dim (int, optional): Reduced size of feature after applying dimensionality reduction
via shallow linear autoencoder.
Defaults to ``220``.
"""

def __init__(
self,
backbone: str = "resnet50",
layer: str = "layer3",
pre_trained: bool = True,
pooling_kernel_size: int = 2,
input_dim: int = 65536,
latent_dim: int = 220,
) -> None:
super().__init__()

self.model: FREModel = FREModel(
backbone=backbone,
pre_trained=pre_trained,
layer=layer,
pooling_kernel_size=pooling_kernel_size,
input_dim=input_dim,
latent_dim=latent_dim,
)
self.loss_fn = torch.nn.MSELoss()

def configure_optimizers(self) -> torch.optim.Optimizer:
"""Configure optimizers.

Returns:
Optimizer: Adam optimizer
"""
return optim.Adam(params=self.model.fre_model.parameters(), lr=1e-3)

def training_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> STEP_OUTPUT:
"""Perform the training step of FRE.

For each batch, features are extracted from the CNN.

Args:
batch (dict[str, str | torch.Tensor]): Input batch
args: Arguments.
kwargs: Keyword arguments.

Returns:
Deep CNN features.
"""
del args, kwargs # These variables are not used.
features_in, features_out, _ = self.model.get_features(batch["image"])
loss = self.loss_fn(features_in, features_out)
self.log("train_loss", loss.item(), on_epoch=True, prog_bar=True, logger=True)
return {"loss": loss}

def validation_step(self, batch: dict[str, str | torch.Tensor], *args, **kwargs) -> STEP_OUTPUT:
"""Perform the validation step of FRE.

Similar to the training step, features are extracted from the CNN for each batch.

Args:
batch (dict[str, str | torch.Tensor]): Input batch
args: Arguments.
kwargs: Keyword arguments.

Returns:
Dictionary containing FRE anomaly scores and anomaly maps.
"""
del args, kwargs # These variables are not used.

batch["pred_scores"], batch["anomaly_maps"] = self.model(batch["image"])
return batch

@property
def trainer_arguments(self) -> dict[str, Any]:
"""Return FRE-specific trainer arguments."""
return {"gradient_clip_val": 0, "max_epochs": 220, "num_sanity_val_steps": 0}

@property
def learning_type(self) -> LearningType:
"""Return the learning type of the model.

Returns:
LearningType: Learning type of the model.
"""
return LearningType.ONE_CLASS
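
Note that the Adam optimizer above receives only `model.fre_model.parameters()`, so the timm backbone stays frozen and only the tied autoencoder is trained. A quick sketch to confirm the trainable parameter count (an illustrative check, assuming `Fre` can be instantiated outside a trainer; `pre_trained=False` avoids downloading backbone weights):

```python
from anomalib.models import Fre

module = Fre(pre_trained=False)
optimizer = module.configure_optimizers()
n_trainable = sum(p.numel() for group in optimizer.param_groups for p in group["params"])
# Tied weight (220 x 65536) plus the encoder (220) and decoder (65536) biases.
print(n_trainable)  # 14483676
```
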
114 changes: 114 additions & 0 deletions src/anomalib/models/image/fre/torch_model.py
@@ -0,0 +1,114 @@
"""PyTorch model for FRE model implementation."""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import torch
from torch import nn
from torch.nn import functional as F # noqa: N812

from anomalib.models.components import TimmFeatureExtractor


class TiedAE(nn.Module):
"""Model for the Tied AutoEncoder used for FRE calculation.

Args:
input_dim (int): Dimension of input to the tied auto-encoder.
latent_dim (int): Dimension of the reduced-dimension latent space of the tied auto-encoder.
"""

def __init__(self, input_dim: int, latent_dim: int) -> None:
super().__init__()
self.input_dim = input_dim
self.latent_dim = latent_dim
self.weight = nn.Parameter(torch.empty(latent_dim, input_dim))
torch.nn.init.xavier_uniform_(self.weight)
self.decoder_bias = nn.Parameter(torch.zeros(input_dim))
self.encoder_bias = nn.Parameter(torch.zeros(latent_dim))

def forward(self, features: torch.Tensor) -> torch.Tensor:
"""Run input features through the autoencoder.

Args:
features (torch.Tensor): Feature batch.

Returns:
Tensor: torch.Tensor containing reconstructed features.
"""
encoded = F.linear(features, self.weight, self.encoder_bias)
return F.linear(encoded, self.weight.t(), self.decoder_bias)


class FREModel(nn.Module):
"""Model for the FRE algorithm.

Args:
backbone (str): Pre-trained model backbone.
layer (str): Layer from which to extract features.
pre_trained (bool, optional): Boolean to check whether to use a pre_trained backbone.
Defaults to ``True``.
pooling_kernel_size (int, optional): Kernel size to pool features extracted from the CNN.
Defaults to ``4``.
input_dim (int, optional): Dimension of feature at output of layer specified in layer.
Defaults to ``65536``.
latent_dim (int, optional): Reduced size of feature after applying dimensionality reduction
via shallow linear autoencoder.
Defaults to ``220``.
"""

def __init__(
self,
backbone: str,
layer: str,
input_dim: int = 65536,
latent_dim: int = 220,
pre_trained: bool = True,
pooling_kernel_size: int = 4,
) -> None:
super().__init__()
self.backbone = backbone
self.pooling_kernel_size = pooling_kernel_size
self.fre_model = TiedAE(input_dim, latent_dim)
self.layer = layer
self.feature_extractor = TimmFeatureExtractor(
backbone=self.backbone,
pre_trained=pre_trained,
layers=[layer],
).eval()

def get_features(self, batch: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Extract features from the pretrained network.

Args:
batch (torch.Tensor): Image batch.

Returns:
Tensor: torch.Tensor containing extracted features.
"""
self.feature_extractor.eval()
features_in = self.feature_extractor(batch)[self.layer]
batch_size = len(features_in)
if self.pooling_kernel_size > 1:
features_in = F.avg_pool2d(input=features_in, kernel_size=self.pooling_kernel_size)
feature_shapes = features_in.shape
features_in = features_in.view(batch_size, -1).detach()
features_out = self.fre_model(features_in)
return features_in, features_out, feature_shapes

def forward(self, batch: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute score from input images.

Args:
batch (torch.Tensor): Input images

Returns:
tuple[torch.Tensor, torch.Tensor]: Scores, Anomaly Map
"""
features_in, features_out, feature_shapes = self.get_features(batch)
fre = torch.square(features_in - features_out).reshape(feature_shapes)
anomaly_map = torch.sum(fre, 1) # NxCxHxW --> NxHxW
score = torch.sum(anomaly_map, (1, 2)) # NxHxW --> N
anomaly_map = torch.unsqueeze(anomaly_map, 1)
anomaly_map = F.interpolate(anomaly_map, size=batch.shape[-2:], mode="bilinear", align_corners=False)
return score, anomaly_map
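
For reference, a small standalone sketch of the tied autoencoder above (assuming Anomalib is installed from this branch): the encoder applies `weight` and the decoder reuses `weight.t()`, so the reconstruction has the same dimensionality as the input features.

```python
import torch

from anomalib.models.image.fre.torch_model import TiedAE

tied_ae = TiedAE(input_dim=65536, latent_dim=220)
features = torch.randn(4, 65536)               # a batch of flattened CNN features
reconstructed = tied_ae(features)              # encode with W, decode with W.t()
assert reconstructed.shape == features.shape   # (4, 65536)

# Per-sample feature-reconstruction error, the quantity FRE scores are built from.
fre = torch.square(features - reconstructed).sum(dim=1)
print(fre.shape)  # torch.Size([4])
```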