Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

πŸ”¨ v2 - Refactor: Add missing auxiliary attributes to AnomalibModule #2460

17 changes: 2 additions & 15 deletions src/anomalib/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from anomalib.deploy import CompressionType, ExportType
from anomalib.models import AnomalibModule
from anomalib.utils.path import create_versioned_dir
from anomalib.visualization import ImageVisualizer

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -258,13 +257,13 @@ def _setup_trainer(self, model: AnomalibModule) -> None:
self._cache.update(model)

# Setup anomalib callbacks to be used with the trainer
self._setup_anomalib_callbacks(model)
self._setup_anomalib_callbacks()

# Instantiate the trainer if it is not already instantiated
if self._trainer is None:
self._trainer = Trainer(**self._cache.args)

def _setup_anomalib_callbacks(self, model: AnomalibModule) -> None:
def _setup_anomalib_callbacks(self) -> None:
"""Set up callbacks for the trainer."""
_callbacks: list[Callback] = []

Expand All @@ -279,18 +278,6 @@ def _setup_anomalib_callbacks(self, model: AnomalibModule) -> None:
),
)

# Add the post-processor callback.
if isinstance(model.post_processor, Callback):
_callbacks.append(model.post_processor)

# Add the metrics callback.
if isinstance(model.evaluator, Callback):
_callbacks.append(model.evaluator)

# Add the image visualizer callback if it is passed by the user.
if not any(isinstance(callback, ImageVisualizer) for callback in self._cache.args["callbacks"]):
_callbacks.append(ImageVisualizer())

_callbacks.append(TimerCallback())

# Combine the callbacks, and update the trainer callbacks.
Expand Down
2 changes: 1 addition & 1 deletion src/anomalib/models/components/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Copyright (C) 2022-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from .anomaly_module import AnomalibModule
from .anomalib_module import AnomalibModule
from .buffer_list import BufferListMixin
from .dynamic_buffer import DynamicBufferMixin
from .memory_bank_module import MemoryBankMixin
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from anomalib.metrics.threshold import Threshold
from anomalib.post_processing import OneClassPostProcessor, PostProcessor
from anomalib.pre_processing import PreProcessor
from anomalib.visualization import ImageVisualizer, Visualizer

from .export_mixin import ExportMixin

Expand All @@ -40,8 +41,9 @@
def __init__(
self,
pre_processor: PreProcessor | bool = True,
post_processor: PostProcessor | None = None,
post_processor: PostProcessor | bool = True,
evaluator: Evaluator | bool = True,
visualizer: Visualizer | bool = True,
) -> None:
super().__init__()
logger.info("Initializing %s model.", self.__class__.__name__)
Expand All @@ -52,11 +54,12 @@
self.callbacks: list[Callback]

self.pre_processor = self._resolve_pre_processor(pre_processor)
self.post_processor = post_processor or self.default_post_processor()
self.post_processor = self._resolve_post_processor(post_processor)
self.evaluator = self._resolve_evaluator(evaluator)
self.visualizer = self._resolve_visualizer(visualizer)

self._input_size: tuple[int, int] | None = None
self._is_setup = False # flag to track if setup has been called from the trainer
self._is_setup = False

@property
def name(self) -> str:
Expand All @@ -79,28 +82,20 @@
initialization.
"""

def _resolve_pre_processor(self, pre_processor: PreProcessor | bool) -> PreProcessor | None:
"""Resolve and validate which pre-processor to use.

Args:
pre_processor: Pre-processor configuration
- True -> use default pre-processor
- False -> no pre-processor
- PreProcessor -> use the provided pre-processor
def configure_callbacks(self) -> Sequence[Callback] | Callback:
"""Configure default callbacks for AnomalibModule.

Returns:
Configured pre-processor
List of callbacks that includes the pre-processor, post-processor, evaluator,
and visualizer if they are available and inherit from Callback.
"""
if isinstance(pre_processor, PreProcessor):
return pre_processor
if isinstance(pre_processor, bool):
return self.configure_pre_processor() if pre_processor else None
msg = f"Invalid pre-processor type: {type(pre_processor)}"
raise TypeError(msg)

def configure_callbacks(self) -> Sequence[Callback] | Callback:
"""Configure default callbacks for AnomalibModule."""
return [self.pre_processor] if self.pre_processor else []
callbacks: list[Callback] = []
callbacks.extend(
component
for component in (self.pre_processor, self.post_processor, self.evaluator, self.visualizer)
if isinstance(component, Callback)
)
return callbacks

def forward(self, batch: torch.Tensor, *args, **kwargs) -> InferenceBatch:
"""Perform the forward-pass by passing input tensor to the module.
Expand Down Expand Up @@ -170,6 +165,25 @@
"""Learning type of the model."""
raise NotImplementedError

def _resolve_pre_processor(self, pre_processor: PreProcessor | bool) -> PreProcessor | None:
"""Resolve and validate which pre-processor to use.

Args:
pre_processor: Pre-processor configuration
- True -> use default pre-processor
- False -> no pre-processor
- PreProcessor -> use the provided pre-processor

Returns:
Configured pre-processor
"""
if isinstance(pre_processor, PreProcessor):
return pre_processor

Check warning on line 181 in src/anomalib/models/components/base/anomalib_module.py

View check run for this annotation

Codecov / codecov/patch

src/anomalib/models/components/base/anomalib_module.py#L181

Added line #L181 was not covered by tests
if isinstance(pre_processor, bool):
return self.configure_pre_processor() if pre_processor else None
msg = f"Invalid pre-processor type: {type(pre_processor)}"
raise TypeError(msg)

Check warning on line 185 in src/anomalib/models/components/base/anomalib_module.py

View check run for this annotation

Codecov / codecov/patch

src/anomalib/models/components/base/anomalib_module.py#L184-L185

Added lines #L184 - L185 were not covered by tests

@classmethod
def configure_pre_processor(cls, image_size: tuple[int, int] | None = None) -> PreProcessor:
"""Configure the pre-processor.
Expand Down Expand Up @@ -214,15 +228,54 @@
]),
)

def default_post_processor(self) -> PostProcessor | None:
"""Default post processor.
def _resolve_post_processor(self, post_processor: PostProcessor | bool) -> PostProcessor | None:
"""Resolve and validate which post-processor to use.

Override in subclass for model-specific post-processing behaviour.
Args:
post_processor: Post-processor configuration
- True -> use default post-processor
- False -> no post-processor
- PostProcessor -> use the provided post-processor

Returns:
Configured post-processor
"""
if isinstance(post_processor, PostProcessor):
return post_processor

Check warning on line 244 in src/anomalib/models/components/base/anomalib_module.py

View check run for this annotation

Codecov / codecov/patch

src/anomalib/models/components/base/anomalib_module.py#L244

Added line #L244 was not covered by tests
if isinstance(post_processor, bool):
return self.configure_post_processor() if post_processor else None
msg = f"Invalid post-processor type: {type(post_processor)}"
raise TypeError(msg)

Check warning on line 248 in src/anomalib/models/components/base/anomalib_module.py

View check run for this annotation

Codecov / codecov/patch

src/anomalib/models/components/base/anomalib_module.py#L247-L248

Added lines #L247 - L248 were not covered by tests

def configure_post_processor(self) -> PostProcessor | None:
"""Configure the default post-processor based on the learning type.

Returns:
PostProcessor: Configured post-processor instance.

Raises:
NotImplementedError: If no default post-processor is available for the model's learning type.

Examples:
Get default post-processor:

>>> post_processor = AnomalibModule.configure_post_processor()

Create model with custom post-processor:

>>> custom_post_processor = CustomPostProcessor()
>>> model = PatchCore(post_processor=custom_post_processor)

Disable post-processing:

>>> model = PatchCore(post_processor=False)
"""
if self.learning_type == LearningType.ONE_CLASS:
return OneClassPostProcessor()
msg = f"No default post-processor available for model {self.__name__} with learning type {self.learning_type}. \
Please override the default_post_processor method in the model implementation."
msg = (

Check warning on line 275 in src/anomalib/models/components/base/anomalib_module.py

View check run for this annotation

Codecov / codecov/patch

src/anomalib/models/components/base/anomalib_module.py#L275

Added line #L275 was not covered by tests
f"No default post-processor available for model with learning type {self.learning_type}. "
"Please override the configure_post_processor method in the model implementation."
)
raise NotImplementedError(msg)

def _resolve_evaluator(self, evaluator: Evaluator | bool) -> Evaluator | None:
Expand Down Expand Up @@ -251,6 +304,63 @@
test_metrics = [image_auroc, image_f1score, pixel_auroc, pixel_f1score]
return Evaluator(test_metrics=test_metrics)

def _resolve_visualizer(self, visualizer: Visualizer | bool) -> Visualizer | None:
"""Resolve and validate which visualizer to use.

Args:
visualizer: Visualizer configuration
- True -> use default visualizer
- False -> no visualizer
- Visualizer -> use the provided visualizer

Returns:
Configured visualizer
"""
if isinstance(visualizer, Visualizer):
return visualizer

Check warning on line 320 in src/anomalib/models/components/base/anomalib_module.py

View check run for this annotation

Codecov / codecov/patch

src/anomalib/models/components/base/anomalib_module.py#L320

Added line #L320 was not covered by tests
if isinstance(visualizer, bool):
return self.configure_visualizer() if visualizer else None
msg = f"Visualizer must be of type Visualizer or bool, got {type(visualizer)}"
raise TypeError(msg)

Check warning on line 324 in src/anomalib/models/components/base/anomalib_module.py

View check run for this annotation

Codecov / codecov/patch

src/anomalib/models/components/base/anomalib_module.py#L323-L324

Added lines #L323 - L324 were not covered by tests

@classmethod
def configure_visualizer(cls) -> ImageVisualizer:
"""Configure the default visualizer.

By default, this method returns an ImageVisualizer instance, which is suitable for
visualizing image-based anomaly detection results. However, the visualizer can be
customized based on your needs - for example, using VideoVisualizer for video data
or implementing a custom visualizer for specific visualization requirements.

Returns:
Visualizer: Configured visualizer instance (ImageVisualizer by default).

Examples:
Get default ImageVisualizer:

>>> visualizer = AnomalibModule.configure_visualizer()

Create model with VideoVisualizer:

>>> from custom_module import VideoVisualizer
>>> video_visualizer = VideoVisualizer()
>>> model = PatchCore(visualizer=video_visualizer)

Create model with custom visualizer:

>>> class CustomVisualizer(Visualizer):
... def __init__(self):
... super().__init__()
... # Custom visualization logic
>>> custom_visualizer = CustomVisualizer()
>>> model = PatchCore(visualizer=custom_visualizer)

Disable visualization:

>>> model = PatchCore(visualizer=False)
"""
return ImageVisualizer()

@property
def input_size(self) -> tuple[int, int] | None:
"""Return the effective input size of the model.
Expand Down
12 changes: 10 additions & 2 deletions src/anomalib/models/image/cfa/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from anomalib.models.components import AnomalibModule
from anomalib.post_processing import PostProcessor
from anomalib.pre_processing import PreProcessor
from anomalib.visualization import Visualizer

from .loss import CfaLoss
from .torch_model import CfaModel
Expand Down Expand Up @@ -58,11 +59,18 @@ def __init__(
num_nearest_neighbors: int = 3,
num_hard_negative_features: int = 3,
radius: float = 1e-5,
# Anomalib's Auxiliary Components
pre_processor: PreProcessor | bool = True,
post_processor: PostProcessor | None = None,
post_processor: PostProcessor | bool = True,
evaluator: Evaluator | bool = True,
visualizer: Visualizer | bool = True,
) -> None:
super().__init__(pre_processor=pre_processor, post_processor=post_processor, evaluator=evaluator)
super().__init__(
pre_processor=pre_processor,
post_processor=post_processor,
evaluator=evaluator,
visualizer=visualizer,
)
self.model: CfaModel = CfaModel(
backbone=backbone,
gamma_c=gamma_c,
Expand Down
11 changes: 9 additions & 2 deletions src/anomalib/models/image/cflow/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from anomalib.models.components import AnomalibModule
from anomalib.post_processing import PostProcessor
from anomalib.pre_processing import PreProcessor
from anomalib.visualization import Visualizer

from .torch_model import CflowModel
from .utils import get_logp, positional_encoding_2d
Expand Down Expand Up @@ -71,10 +72,16 @@ def __init__(
permute_soft: bool = False,
lr: float = 0.0001,
pre_processor: PreProcessor | bool = True,
post_processor: PostProcessor | None = None,
post_processor: PostProcessor | bool = True,
evaluator: Evaluator | bool = True,
visualizer: Visualizer | bool = True,
) -> None:
super().__init__(pre_processor=pre_processor, post_processor=post_processor, evaluator=evaluator)
super().__init__(
pre_processor=pre_processor,
post_processor=post_processor,
evaluator=evaluator,
visualizer=visualizer,
)

self.model: CflowModel = CflowModel(
backbone=backbone,
Expand Down
11 changes: 9 additions & 2 deletions src/anomalib/models/image/csflow/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from anomalib.models.components import AnomalibModule
from anomalib.post_processing import PostProcessor
from anomalib.pre_processing import PreProcessor
from anomalib.visualization import Visualizer

from .loss import CsFlowLoss
from .torch_model import CsFlowModel
Expand Down Expand Up @@ -48,10 +49,16 @@ def __init__(
clamp: int = 3,
num_channels: int = 3,
pre_processor: PreProcessor | bool = True,
post_processor: PostProcessor | None = None,
post_processor: PostProcessor | bool = True,
evaluator: Evaluator | bool = True,
visualizer: Visualizer | bool = True,
) -> None:
super().__init__(pre_processor=pre_processor, post_processor=post_processor, evaluator=evaluator)
super().__init__(
pre_processor=pre_processor,
post_processor=post_processor,
evaluator=evaluator,
visualizer=visualizer,
)
if self.input_size is None:
msg = "CsFlow needs input size to build torch model."
raise ValueError(msg)
Expand Down
11 changes: 9 additions & 2 deletions src/anomalib/models/image/dfkde/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from anomalib.models.components.classification import FeatureScalingMethod
from anomalib.post_processing import PostProcessor
from anomalib.pre_processing import PreProcessor
from anomalib.visualization import Visualizer

from .torch_model import DfkdeModel

Expand Down Expand Up @@ -50,10 +51,16 @@ def __init__(
feature_scaling_method: FeatureScalingMethod = FeatureScalingMethod.SCALE,
max_training_points: int = 40000,
pre_processor: PreProcessor | bool = True,
post_processor: PostProcessor | None = None,
post_processor: PostProcessor | bool = True,
evaluator: Evaluator | bool = True,
visualizer: Visualizer | bool = True,
) -> None:
super().__init__(pre_processor=pre_processor, post_processor=post_processor, evaluator=evaluator)
super().__init__(
pre_processor=pre_processor,
post_processor=post_processor,
evaluator=evaluator,
visualizer=visualizer,
)

self.model = DfkdeModel(
layers=layers,
Expand Down
Loading