Bugfix code logic to allow for passing of nn.Module to TorchFXFeatureExtractor #935

Merged · 11 commits · Mar 1, 2023
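This PR makes it possible to hand a pre-built `nn.Module` instance straight to `TorchFXFeatureExtractor`, instead of a class-path string or `BackboneParams`. A minimal sketch of the usage this enables, assuming a torchvision backbone (any traceable `nn.Module` whose graph contains the requested nodes should behave the same):

```python
import torch
from torchvision.models import ResNet18_Weights, resnet18

from anomalib.models.components.feature_extractors import TorchFXFeatureExtractor

# Build the backbone yourself, then pass the instance directly; before this
# fix, the instance was wrapped into BackboneParams and mistakenly called
# as if it were a class.
backbone = resnet18(weights=ResNet18_Weights.DEFAULT)
extractor = TorchFXFeatureExtractor(backbone, return_nodes=["layer1", "layer2"])

batch = torch.rand((8, 3, 256, 256))
features = extractor(batch)  # dict: node name -> feature tensor
print({name: tuple(f.shape) for name, f in features.items()})
```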
71 changes: 48 additions & 23 deletions anomalib/models/components/feature_extractors/torchfx.py
@@ -20,7 +20,7 @@
 class BackboneParams:
     """Used for serializing the backbone."""
 
-    class_path: str | nn.Module
+    class_path: str | type[nn.Module]
     init_args: dict = field(default_factory=dict)
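The annotation change above means `class_path` now carries either a dotted import string or the model *class* (`type[nn.Module]`); actual instances are no longer stuffed into `BackboneParams`. A hedged sketch of the two serializable forms (`CustomModel` is a hypothetical user class, and the dataclass below only mirrors the one in this diff):

```python
from __future__ import annotations

from dataclasses import dataclass, field

from torch import nn


class CustomModel(nn.Module):  # hypothetical stand-in for a user model
    ...


@dataclass
class BackboneParams:  # mirrors the dataclass in this diff
    class_path: str | type[nn.Module]
    init_args: dict = field(default_factory=dict)


by_name = BackboneParams(class_path="resnet18")    # resolved via torchvision
by_class = BackboneParams(class_path=CustomModel)  # class object, not an instance
```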


@@ -30,7 +30,8 @@ class TorchFXFeatureExtractor(nn.Module):
     Args:
         backbone (str | BackboneParams | dict | nn.Module): The backbone to which the feature extraction hooks are
             attached. If the name is provided, the model is loaded from torchvision. Otherwise, the model class can be
-            provided and it will try to load the weights from the provided weights file.
+            provided and it will try to load the weights from the provided weights file. Last, an instance of nn.Module
+            can also be passed directly.
         return_nodes (Iterable[str]): List of layer names of the backbone to which the hooks are attached.
             You can find the names of these nodes by using ``get_graph_node_names`` function.
         weights (str | WeightsEnum | None): Weights enum to use for the model. Torchvision models require
@@ -59,6 +60,7 @@ class TorchFXFeatureExtractor(nn.Module):

         With custom models:
 
+            >>> import torch
             >>> from anomalib.models.components.feature_extractors import TorchFXFeatureExtractor
             >>> feature_extractor = TorchFXFeatureExtractor(
                     "path.to.CustomModel", ["linear_relu_stack.3"], weights="path/to/weights.pth"
@@ -67,6 +69,20 @@ class TorchFXFeatureExtractor(nn.Module):
             >>> features = feature_extractor(input)
             >>> [layer for layer in features.keys()]
             ["linear_relu_stack.3"]
+
+        With model instances:
+
+            >>> import torch
+            >>> from anomalib.models.components.feature_extractors import TorchFXFeatureExtractor
+            >>> from timm import create_model
+            >>> model = create_model("resnet18", pretrained=True)
+            >>> feature_extractor = TorchFXFeatureExtractor(model, ["layer1"])
+            >>> input = torch.rand((32, 3, 256, 256))
+            >>> features = feature_extractor(input)
+            >>> [layer for layer in features.keys()]
+            ["layer1"]
+            >>> [feature.shape for feature in features.values()]
+            [torch.Size([32, 64, 64, 64])]
     """

     def __init__(
@@ -79,24 +95,29 @@ def __init__(
         super().__init__()
         if isinstance(backbone, dict):
             backbone = BackboneParams(**backbone)
-        elif not isinstance(backbone, BackboneParams):  # if str or nn.Module
+        elif isinstance(backbone, str):
             backbone = BackboneParams(class_path=backbone)
+        elif not isinstance(backbone, (nn.Module, BackboneParams)):
+            raise ValueError(
+                f"backbone needs to be of type str | BackboneParams | dict | nn.Module, but was type {type(backbone)}"
+            )
 
         self.feature_extractor = self.initialize_feature_extractor(backbone, return_nodes, weights, requires_grad)

def initialize_feature_extractor(
self,
backbone: BackboneParams,
backbone: BackboneParams | nn.Module,
return_nodes: list[str],
weights: str | WeightsEnum | None = None,
requires_grad: bool = False,
) -> GraphModule | nn.Module:
) -> GraphModule:
"""Extract features from a CNN.

         Args:
-            backbone (BackboneParams): The backbone to which the feature extraction hooks are attached.
-                If the name is provided, the model is loaded from torchvision. Otherwise, the model class can be
-                provided and it will try to load the weights from the provided weights file.
+            backbone (BackboneParams | nn.Module): The backbone to which the feature extraction hooks are attached.
+                If the name is provided for BackboneParams, the model is loaded from torchvision. Otherwise, the model
+                class can be provided and it will try to load the weights from the provided weights file. Last, an
+                instance of the model can be provided as well, which will be used as-is.
             return_nodes (Iterable[str]): List of layer names of the backbone to which the hooks are attached.
                 You can find the names of these nodes by using ``get_graph_node_names`` function.
             weights (str | WeightsEnum | None): Weights enum to use for the model. Torchvision models require
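The reworked `__init__` above now dispatches on input type: `dict` and `str` are normalized to `BackboneParams`, `nn.Module` and `BackboneParams` pass through untouched, and anything else raises. A standalone sketch of that dispatch, with `resolve_backbone` as a hypothetical name for illustration:

```python
from __future__ import annotations

from torch import nn

from anomalib.models.components.feature_extractors import BackboneParams


def resolve_backbone(backbone: str | dict | nn.Module | BackboneParams):
    """Hypothetical helper mirroring the type checks __init__ performs."""
    if isinstance(backbone, dict):
        return BackboneParams(**backbone)           # config-style input
    if isinstance(backbone, str):
        return BackboneParams(class_path=backbone)  # torchvision model name
    if isinstance(backbone, (nn.Module, BackboneParams)):
        return backbone                             # used as-is downstream
    raise ValueError(
        f"backbone needs to be of type str | BackboneParams | dict | nn.Module, "
        f"but was type {type(backbone)}"
    )
```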
@@ -108,22 +129,26 @@ def initialize_feature_extractor(
         Returns:
             Feature Extractor based on TorchFX.
         """
-        if isinstance(backbone.class_path, str):
-            backbone_class = self._get_backbone_class(backbone.class_path)
-            backbone_model = backbone_class(weights=weights, **backbone.init_args)
+        if isinstance(backbone, nn.Module):
+            backbone_model = backbone
         else:
-            backbone_class = backbone.class_path
-            backbone_model = backbone_class(**backbone.init_args)
-        if isinstance(weights, WeightsEnum):  # torchvision models
-            feature_extractor = create_feature_extractor(model=backbone_model, return_nodes=return_nodes)
-        else:
-            if weights is not None:
-                assert isinstance(weights, str), "Weights should point to a path"
-                model_weights = torch.load(weights)
-                if "state_dict" in model_weights:
-                    model_weights = model_weights["state_dict"]
-                backbone_model.load_state_dict(model_weights)
-            feature_extractor = create_feature_extractor(backbone_model, return_nodes)
+            if isinstance(backbone.class_path, str):
+                backbone_class = self._get_backbone_class(backbone.class_path)
+                backbone_model = backbone_class(weights=weights, **backbone.init_args)
+            else:
+                backbone_class = backbone.class_path
+                backbone_model = backbone_class(**backbone.init_args)
+            if isinstance(weights, WeightsEnum):  # torchvision models
+                feature_extractor = create_feature_extractor(model=backbone_model, return_nodes=return_nodes)
+            else:
+                if weights is not None:
+                    assert isinstance(weights, str), "Weights should point to a path"
+                    model_weights = torch.load(weights)
+                    if "state_dict" in model_weights:
+                        model_weights = model_weights["state_dict"]
+                    backbone_model.load_state_dict(model_weights)
+
+        feature_extractor = create_feature_extractor(backbone_model, return_nodes)

         if not requires_grad:
             feature_extractor.eval()
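For context on the control-flow change above: previously a passed-in model instance ended up in `backbone.class_path` and was then *called* like a class, which invokes the module's `forward` with constructor kwargs and fails. A minimal repro sketch of that old failure mode, assuming the pre-fix wrapping behavior:

```python
from torch import nn

from anomalib.models.components.feature_extractors import BackboneParams

model = nn.Sequential(nn.Conv2d(3, 8, 3))  # a pre-built instance
params = BackboneParams(class_path=model)  # old __init__ wrapped instances like this

# The old else-branch then did the equivalent of:
#     backbone_model = params.class_path(**params.init_args)
# which calls the *instance* (i.e. runs forward() with no input tensor)
# instead of constructing a model -- hence the crash this PR fixes.
```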
15 changes: 13 additions & 2 deletions tests/pre_merge/models/test_feature_extractor.py
@@ -3,10 +3,12 @@

 import pytest
 import torch
+from torchvision.models import ResNet18_Weights, resnet18
 from torchvision.models.efficientnet import EfficientNet_B5_Weights
-from torchvision.models.resnet import ResNet18_Weights
 
 from anomalib.models.components.feature_extractors import (
+    BackboneParams,
     FeatureExtractor,
     TorchFXFeatureExtractor,
     dryrun_find_featuremap_dims,
@@ -70,17 +72,26 @@ def test_torchfx_feature_extraction(self):
assert features["layer2"].shape == torch.Size((32, 128, 32, 32))
assert features["layer3"].shape == torch.Size((32, 256, 16, 16))

# Test if local model can be loaded using string of weights path
# Test if local model can be instantiated from class and weights can be loaded using string of weights path
with TemporaryDirectory() as tmpdir:
torch.save(DummyModel().state_dict(), tmpdir + "/dummy_model.pt")
model = TorchFXFeatureExtractor(
backbone=DummyModel,
backbone=BackboneParams(class_path=DummyModel),
weights=tmpdir + "/dummy_model.pt",
return_nodes=["conv3"],
)
features = model(test_input)
assert features["conv3"].shape == torch.Size((32, 1, 244, 244))

# Test if nn.Module instance can be passed directly
resnet = resnet18(weights=ResNet18_Weights)
model = TorchFXFeatureExtractor(resnet, ["layer1", "layer2", "layer3"])
test_input = torch.rand((32, 3, 256, 256))
features = model(test_input)
assert features["layer1"].shape == torch.Size((32, 64, 64, 64))
assert features["layer2"].shape == torch.Size((32, 128, 32, 32))
assert features["layer3"].shape == torch.Size((32, 256, 16, 16))


 @pytest.mark.parametrize(
     "backbone",
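Unrelated to the fix itself but useful when writing tests like these: the docstring points at torchvision's `get_graph_node_names` for discovering valid `return_nodes`. A small sketch:

```python
import torch
from torchvision.models import resnet18
from torchvision.models.feature_extraction import get_graph_node_names

model = resnet18()
train_nodes, eval_nodes = get_graph_node_names(model)
# Node names valid for return_nodes, e.g. 'layer1', 'layer2.0.conv1', ...
print(eval_nodes[:8])
```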