From a00c58987a33eefbaa4eb3dca9361dcf11e9f486 Mon Sep 17 00:00:00 2001 From: Adrian Boguszewski Date: Fri, 17 May 2024 18:04:24 +0200 Subject: [PATCH 1/9] Added accuracy control quantization Signed-off-by: Adrian Boguszewski --- src/anomalib/deploy/export.py | 6 +++ src/anomalib/engine/engine.py | 6 +++ .../models/components/base/export_mixin.py | 40 ++++++++++++++++++- 3 files changed, 50 insertions(+), 2 deletions(-) diff --git a/src/anomalib/deploy/export.py b/src/anomalib/deploy/export.py index 87066c9ef9..c53fb964dd 100644 --- a/src/anomalib/deploy/export.py +++ b/src/anomalib/deploy/export.py @@ -57,6 +57,12 @@ class CompressionType(str, Enum): Full integer post-training quantization (INT8) All weights and operations are quantized to INT8. Inference is done in INT8 precision. """ + INT8_ACQ = "int8_acq" + """ + Accuracy-control quantization (INT8) + Weights and operations are quantized to INT8, except those that would degrade quality of the model more than is + acceptable. Inference is done in mixed precision. 
+ """ class InferenceModel(nn.Module): diff --git a/src/anomalib/engine/engine.py b/src/anomalib/engine/engine.py index c41ecdf531..e8e5c6c40b 100644 --- a/src/anomalib/engine/engine.py +++ b/src/anomalib/engine/engine.py @@ -14,6 +14,7 @@ from lightning.pytorch.trainer import Trainer from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT, EVAL_DATALOADERS, TRAIN_DATALOADERS from torch.utils.data import DataLoader, Dataset +from torchmetrics import Metric from torchvision.transforms.v2 import Transform from anomalib import LearningType, TaskType @@ -871,6 +872,7 @@ def export( transform: Transform | None = None, compression_type: CompressionType | None = None, datamodule: AnomalibDataModule | None = None, + metric: Metric | str | None = None, ov_args: dict[str, Any] | None = None, ckpt_path: str | Path | None = None, ) -> Path | None: @@ -891,6 +893,9 @@ def export( datamodule (AnomalibDataModule | None, optional): Lightning datamodule. Must be provided if CompressionType.INT8_PTQ is selected. Defaults to ``None``. + metric (Metric | str | None, optional): Metric to measure quality loss when quantizing. + Must be provided if CompressionType.INT8_ACQ is selected. + Defaults to ``None``. ov_args (dict[str, Any] | None, optional): This is optional and used only for OpenVINO's model optimizer. Defaults to None. ckpt_path (str | Path | None): Checkpoint path. If provided, the model will be loaded from this path. 
@@ -954,6 +959,7 @@ def export( task=self.task, compression_type=compression_type, datamodule=datamodule, + metric=metric, ov_args=ov_args, ) else: diff --git a/src/anomalib/models/components/base/export_mixin.py b/src/anomalib/models/components/base/export_mixin.py index 9b0c2d41e2..7b5dfb7718 100644 --- a/src/anomalib/models/components/base/export_mixin.py +++ b/src/anomalib/models/components/base/export_mixin.py @@ -6,7 +6,7 @@ import json import logging -from collections.abc import Callable +from collections.abc import Callable, Iterable from pathlib import Path from tempfile import TemporaryDirectory from typing import TYPE_CHECKING, Any @@ -14,11 +14,13 @@ import numpy as np import torch from torch import nn +from torchmetrics import Metric from torchvision.transforms.v2 import Transform from anomalib import TaskType from anomalib.data import AnomalibDataModule from anomalib.deploy.export import CompressionType, ExportType, InferenceModel +from anomalib.metrics import create_metric_collection from anomalib.utils.exceptions import try_import if TYPE_CHECKING: @@ -159,6 +161,7 @@ def to_openvino( transform: Transform | None = None, compression_type: CompressionType | None = None, datamodule: AnomalibDataModule | None = None, + metric: Metric | str | None = None, ov_args: dict[str, Any] | None = None, task: TaskType | None = None, ) -> Path: @@ -176,6 +179,9 @@ def to_openvino( datamodule (AnomalibDataModule | None, optional): Lightning datamodule. Must be provided if CompressionType.INT8_PTQ is selected. Defaults to ``None``. + metric (Metric | str | None, optional): Metric to measure quality loss when quantizing. + Must be provided if CompressionType.INT8_ACQ is selected. + Defaults to ``None``. ov_args (dict | None): Model optimizer arguments for OpenVINO model conversion. Defaults to ``None``. task (TaskType | None): Task type. 
@@ -242,13 +248,43 @@ def to_openvino( msg = "Datamodule must be provided for OpenVINO INT8_PTQ compression" raise ValueError(msg) - dataloader = datamodule.val_dataloader() + dataloader = datamodule.train_dataloader() if len(dataloader.dataset) < 300: logger.warning( f">300 images recommended for INT8 quantization, found only {len(dataloader.dataset)} images", ) + calibration_dataset = nncf.Dataset(dataloader, lambda x: x["image"]) model = nncf.quantize(model, calibration_dataset) + elif compression_type == CompressionType.INT8_ACQ: + if datamodule is None: + msg = "Datamodule must be provided for OpenVINO INT8_PTQ compression" + raise ValueError(msg) + if metric is None: + msg = "Metric must be provided for OpenVINO INT8_ACQ compression" + raise ValueError(msg) + + dataloader = datamodule.train_dataloader() + if len(dataloader.dataset) < 300: + logger.warning( + f">300 images recommended for INT8 quantization, found only {len(dataloader.dataset)} images", + ) + + calibration_dataset = nncf.Dataset(dataloader, lambda x: x["image"]) + validation_dataset = nncf.Dataset(datamodule.val_dataloader()) + + if isinstance(metric, str): + metric = create_metric_collection([metric])[metric] + + # validation function to evaluate the quality loss after quantization + def val_fn(nncf_model: ov.CompiledModel, validation_data: Iterable) -> float: + for batch in validation_data: + preds = torch.from_numpy(nncf_model(batch["image"])[0]) + target = batch["mask"] + metric.update(preds, target) + return metric.compute() + + model = nncf.quantize_with_accuracy_control(model, calibration_dataset, validation_dataset, val_fn) # fp16 compression is enabled by default compress_to_fp16 = compression_type == CompressionType.FP16 From c277a1d4d760ded5f80dcca5f239eaadae3c447a Mon Sep 17 00:00:00 2001 From: Adrian Boguszewski Date: Mon, 20 May 2024 18:08:23 +0200 Subject: [PATCH 2/9] Fixed issues with static shape models Signed-off-by: Adrian Boguszewski --- src/anomalib/deploy/export.py | 2 +- 
.../models/components/base/export_mixin.py | 19 ++++++++++++------- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/src/anomalib/deploy/export.py b/src/anomalib/deploy/export.py index c53fb964dd..34edc1ac47 100644 --- a/src/anomalib/deploy/export.py +++ b/src/anomalib/deploy/export.py @@ -61,7 +61,7 @@ class CompressionType(str, Enum): """ Accuracy-control quantization (INT8) Weights and operations are quantized to INT8, except those that would degrade quality of the model more than is - acceptable. Inference is done in mixed precision. + acceptable. Inference is done in a mixed precision. """ diff --git a/src/anomalib/models/components/base/export_mixin.py b/src/anomalib/models/components/base/export_mixin.py index 7b5dfb7718..2534089a2a 100644 --- a/src/anomalib/models/components/base/export_mixin.py +++ b/src/anomalib/models/components/base/export_mixin.py @@ -180,7 +180,8 @@ def to_openvino( Must be provided if CompressionType.INT8_PTQ is selected. Defaults to ``None``. metric (Metric | str | None, optional): Metric to measure quality loss when quantizing. - Must be provided if CompressionType.INT8_ACQ is selected. + Must be provided if CompressionType.INT8_ACQ is selected and must return higher value for better + performance of the model. Defaults to ``None``. ov_args (dict | None): Model optimizer arguments for OpenVINO model conversion. Defaults to ``None``. @@ -224,11 +225,8 @@ def to_openvino( ... task="segmentation", ... ) """ - if not try_import("openvino"): - logger.exception("Could not find OpenVINO. Please check OpenVINO installation.") - raise ModuleNotFoundError - if not try_import("nncf"): - logger.exception("Could not find NNCF. Please check NNCF installation.") + if not try_import("openvino") or not try_import("nncf"): + logger.exception("Could not find OpenVINO or NCCF. 
Please check OpenVINO and NNCF installation.") raise ModuleNotFoundError import nncf @@ -241,6 +239,8 @@ def to_openvino( ov_args = {} if ov_args is None else ov_args model = ov.convert_model(model_path, **ov_args) + model_input = model.input(0) + if compression_type == CompressionType.INT8: model = nncf.compress_weights(model) elif compression_type == CompressionType.INT8_PTQ: @@ -248,6 +248,8 @@ def to_openvino( msg = "Datamodule must be provided for OpenVINO INT8_PTQ compression" raise ValueError(msg) + if model_input.partial_shape[0].is_static: + datamodule.train_batch_size = model_input.shape[0] dataloader = datamodule.train_dataloader() if len(dataloader.dataset) < 300: logger.warning( @@ -264,6 +266,9 @@ def to_openvino( msg = "Metric must be provided for OpenVINO INT8_ACQ compression" raise ValueError(msg) + if model_input.partial_shape[0].is_static: + datamodule.train_batch_size = model_input.shape[0] + datamodule.eval_batch_size = model_input.shape[0] dataloader = datamodule.train_dataloader() if len(dataloader.dataset) < 300: logger.warning( @@ -280,7 +285,7 @@ def to_openvino( def val_fn(nncf_model: ov.CompiledModel, validation_data: Iterable) -> float: for batch in validation_data: preds = torch.from_numpy(nncf_model(batch["image"])[0]) - target = batch["mask"] + target = batch["mask"][:, None, :, :] metric.update(preds, target) return metric.compute() From 57e778b1d14076d76890ce116773b291b0550ea3 Mon Sep 17 00:00:00 2001 From: Adrian Boguszewski Date: Mon, 20 May 2024 18:53:15 +0200 Subject: [PATCH 3/9] Support classification task Signed-off-by: Adrian Boguszewski --- src/anomalib/models/components/base/export_mixin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anomalib/models/components/base/export_mixin.py b/src/anomalib/models/components/base/export_mixin.py index 2534089a2a..d74314507b 100644 --- a/src/anomalib/models/components/base/export_mixin.py +++ b/src/anomalib/models/components/base/export_mixin.py @@ -285,7 
+285,7 @@ def to_openvino( def val_fn(nncf_model: ov.CompiledModel, validation_data: Iterable) -> float: for batch in validation_data: preds = torch.from_numpy(nncf_model(batch["image"])[0]) - target = batch["mask"][:, None, :, :] + target = batch["mask"][:, None, :, :] if task == TaskType.SEGMENTATION else batch["label"] metric.update(preds, target) return metric.compute() From c0de302d9628aceaab3832834af342a0d054a9bb Mon Sep 17 00:00:00 2001 From: Adrian Boguszewski Date: Fri, 24 May 2024 14:28:00 +0200 Subject: [PATCH 4/9] Moved nncf quantization to a separate function Signed-off-by: Adrian Boguszewski --- src/anomalib/engine/engine.py | 5 +- .../models/components/base/export_mixin.py | 151 +++++++++++------- 2 files changed, 100 insertions(+), 56 deletions(-) diff --git a/src/anomalib/engine/engine.py b/src/anomalib/engine/engine.py index e8e5c6c40b..37aa7e634b 100644 --- a/src/anomalib/engine/engine.py +++ b/src/anomalib/engine/engine.py @@ -891,10 +891,11 @@ def export( compression_type (CompressionType | None, optional): Compression type for OpenVINO exporting only. Defaults to ``None``. datamodule (AnomalibDataModule | None, optional): Lightning datamodule. - Must be provided if CompressionType.INT8_PTQ is selected. + Must be provided if CompressionType.INT8_PTQ is selected (OpenVINO export only). Defaults to ``None``. metric (Metric | str | None, optional): Metric to measure quality loss when quantizing. - Must be provided if CompressionType.INT8_ACQ is selected. + Must be provided if CompressionType.INT8_ACQ is selected and must return higher value for better + performance of the model (OpenVINO export only). Defaults to ``None``. ov_args (dict[str, Any] | None, optional): This is optional and used only for OpenVINO's model optimizer. Defaults to None. 
diff --git a/src/anomalib/models/components/base/export_mixin.py b/src/anomalib/models/components/base/export_mixin.py index d74314507b..e29f451e43 100644 --- a/src/anomalib/models/components/base/export_mixin.py +++ b/src/anomalib/models/components/base/export_mixin.py @@ -24,8 +24,13 @@ from anomalib.utils.exceptions import try_import if TYPE_CHECKING: + from importlib.util import find_spec + from torch.types import Number + if find_spec("openvino") is not None: + from openvino import CompiledModel + logger = logging.getLogger(__name__) @@ -225,11 +230,10 @@ def to_openvino( ... task="segmentation", ... ) """ - if not try_import("openvino") or not try_import("nncf"): - logger.exception("Could not find OpenVINO or NCCF. Please check OpenVINO and NNCF installation.") + if not try_import("openvino"): + logger.exception("Could not find OpenVINO. Please check OpenVINO installation.") raise ModuleNotFoundError - import nncf import openvino as ov with TemporaryDirectory() as onnx_directory: @@ -239,57 +243,7 @@ def to_openvino( ov_args = {} if ov_args is None else ov_args model = ov.convert_model(model_path, **ov_args) - model_input = model.input(0) - - if compression_type == CompressionType.INT8: - model = nncf.compress_weights(model) - elif compression_type == CompressionType.INT8_PTQ: - if datamodule is None: - msg = "Datamodule must be provided for OpenVINO INT8_PTQ compression" - raise ValueError(msg) - - if model_input.partial_shape[0].is_static: - datamodule.train_batch_size = model_input.shape[0] - dataloader = datamodule.train_dataloader() - if len(dataloader.dataset) < 300: - logger.warning( - f">300 images recommended for INT8 quantization, found only {len(dataloader.dataset)} images", - ) - - calibration_dataset = nncf.Dataset(dataloader, lambda x: x["image"]) - model = nncf.quantize(model, calibration_dataset) - elif compression_type == CompressionType.INT8_ACQ: - if datamodule is None: - msg = "Datamodule must be provided for OpenVINO INT8_PTQ compression" 
- raise ValueError(msg) - if metric is None: - msg = "Metric must be provided for OpenVINO INT8_ACQ compression" - raise ValueError(msg) - - if model_input.partial_shape[0].is_static: - datamodule.train_batch_size = model_input.shape[0] - datamodule.eval_batch_size = model_input.shape[0] - dataloader = datamodule.train_dataloader() - if len(dataloader.dataset) < 300: - logger.warning( - f">300 images recommended for INT8 quantization, found only {len(dataloader.dataset)} images", - ) - - calibration_dataset = nncf.Dataset(dataloader, lambda x: x["image"]) - validation_dataset = nncf.Dataset(datamodule.val_dataloader()) - - if isinstance(metric, str): - metric = create_metric_collection([metric])[metric] - - # validation function to evaluate the quality loss after quantization - def val_fn(nncf_model: ov.CompiledModel, validation_data: Iterable) -> float: - for batch in validation_data: - preds = torch.from_numpy(nncf_model(batch["image"])[0]) - target = batch["mask"][:, None, :, :] if task == TaskType.SEGMENTATION else batch["label"] - metric.update(preds, target) - return metric.compute() - - model = nncf.quantize_with_accuracy_control(model, calibration_dataset, validation_dataset, val_fn) + model = self._compress_ov_model(model, compression_type, datamodule, metric, task) # fp16 compression is enabled by default compress_to_fp16 = compression_type == CompressionType.FP16 @@ -298,6 +252,95 @@ def val_fn(nncf_model: ov.CompiledModel, validation_data: Iterable) -> float: return ov_model_path + def _compress_ov_model( + self, + model: "CompiledModel", + compression_type: CompressionType | None = None, + datamodule: AnomalibDataModule | None = None, + metric: Metric | str | None = None, + task: TaskType | None = None, + ) -> "CompiledModel": + """Compress OpenVINO model with NNCF. + + model (CompiledModel): Model already exported to OpenVINO format. + compression_type (CompressionType, optional): Compression type for better inference performance. 
+ Defaults to ``None``. + datamodule (AnomalibDataModule | None, optional): Lightning datamodule. + Must be provided if CompressionType.INT8_PTQ is selected. + Defaults to ``None``. + metric (Metric | str | None, optional): Metric to measure quality loss when quantizing. + Must be provided if CompressionType.INT8_ACQ is selected and must return higher value for better + performance of the model. + Defaults to ``None``. + task (TaskType | None): Task type. + Defaults to ``None``. + + Returns: + model (CompiledModel): Model in the OpenVINO format compressed with NNCF quantization. + """ + if not try_import("nncf"): + logger.exception("Could not find NCCF. Please check NNCF installation.") + raise ModuleNotFoundError + + import nncf + + model_input = model.input(0) + + # weights compression + if compression_type == CompressionType.INT8: + model = nncf.compress_weights(model) + # post-training quantization + elif compression_type == CompressionType.INT8_PTQ: + if datamodule is None: + msg = "Datamodule must be provided for OpenVINO INT8_PTQ compression" + raise ValueError(msg) + + if model_input.partial_shape[0].is_static: + datamodule.train_batch_size = model_input.shape[0] + dataloader = datamodule.train_dataloader() + if len(dataloader.dataset) < 300: + logger.warning( + f">300 images recommended for INT8 quantization, found only {len(dataloader.dataset)} images", + ) + + calibration_dataset = nncf.Dataset(dataloader, lambda x: x["image"]) + model = nncf.quantize(model, calibration_dataset) + # accuracy-control quantization + elif compression_type == CompressionType.INT8_ACQ: + if datamodule is None: + msg = "Datamodule must be provided for OpenVINO INT8_PTQ compression" + raise ValueError(msg) + if metric is None: + msg = "Metric must be provided for OpenVINO INT8_ACQ compression" + raise ValueError(msg) + + if model_input.partial_shape[0].is_static: + datamodule.train_batch_size = model_input.shape[0] + datamodule.eval_batch_size = model_input.shape[0] + 
dataloader = datamodule.train_dataloader() + if len(dataloader.dataset) < 300: + logger.warning( + f">300 images recommended for INT8 quantization, found only {len(dataloader.dataset)} images", + ) + + calibration_dataset = nncf.Dataset(dataloader, lambda x: x["image"]) + validation_dataset = nncf.Dataset(datamodule.val_dataloader()) + + if isinstance(metric, str): + metric = create_metric_collection([metric])[metric] + + # validation function to evaluate the quality loss after quantization + def val_fn(nncf_model: "CompiledModel", validation_data: Iterable) -> float: + for batch in validation_data: + preds = torch.from_numpy(nncf_model(batch["image"])[0]) + target = batch["label"] if task == TaskType.CLASSIFICATION else batch["mask"][:, None, :, :] + metric.update(preds, target) + return metric.compute() + + model = nncf.quantize_with_accuracy_control(model, calibration_dataset, validation_dataset, val_fn) + + return model + def _get_metadata( self, task: TaskType | None = None, From 7a2b0674612098b6f11d82f7dfb3e9b65165a788 Mon Sep 17 00:00:00 2001 From: Adrian Boguszewski Date: Mon, 3 Jun 2024 15:41:42 +0200 Subject: [PATCH 5/9] Improved code according to the review Signed-off-by: Adrian Boguszewski --- src/anomalib/deploy/export.py | 29 ++-- src/anomalib/engine/engine.py | 5 +- .../models/components/base/export_mixin.py | 161 ++++++++++++------ 3 files changed, 121 insertions(+), 74 deletions(-) diff --git a/src/anomalib/deploy/export.py b/src/anomalib/deploy/export.py index 34edc1ac47..aae359c035 100644 --- a/src/anomalib/deploy/export.py +++ b/src/anomalib/deploy/export.py @@ -36,6 +36,18 @@ class ExportType(str, Enum): class CompressionType(str, Enum): """Model compression type when exporting to OpenVINO. + Attributes: + FP16 (str): Weight compression (FP16). All weights are converted to FP16. + INT8 (str): Weight compression (INT8). All weights are quantized to INT8, + but are dequantized to floating point before inference. 
+ INT8_PTQ (str): Full integer post-training quantization (INT8). + All weights and operations are quantized to INT8. Inference is done + in INT8 precision. + INT8_ACQ (str): Accuracy-control quantization (INT8). Weights and + operations are quantized to INT8, except those that would degrade + quality of the model more than is acceptable. Inference is done in + a mixed precision. + Examples: >>> from anomalib.deploy import CompressionType >>> CompressionType.INT8_PTQ @@ -43,26 +55,9 @@ class CompressionType(str, Enum): """ FP16 = "fp16" - """ - Weight compression (FP16) - All weights are converted to FP16. - """ INT8 = "int8" - """ - Weight compression (INT8) - All weights are quantized to INT8, but are dequantized to floating point before inference. - """ INT8_PTQ = "int8_ptq" - """ - Full integer post-training quantization (INT8) - All weights and operations are quantized to INT8. Inference is done in INT8 precision. - """ INT8_ACQ = "int8_acq" - """ - Accuracy-control quantization (INT8) - Weights and operations are quantized to INT8, except those that would degrade quality of the model more than is - acceptable. Inference is done in a mixed precision. - """ class InferenceModel(nn.Module): diff --git a/src/anomalib/engine/engine.py b/src/anomalib/engine/engine.py index b3f5e9f0bf..f0c2fdfe23 100644 --- a/src/anomalib/engine/engine.py +++ b/src/anomalib/engine/engine.py @@ -891,10 +891,11 @@ def export( compression_type (CompressionType | None, optional): Compression type for OpenVINO exporting only. Defaults to ``None``. datamodule (AnomalibDataModule | None, optional): Lightning datamodule. - Must be provided if CompressionType.INT8_PTQ is selected (OpenVINO export only). + Must be provided if ``CompressionType.INT8_PTQ`` or ``CompressionType.INT8_ACQ`` is selected + (OpenVINO export only). Defaults to ``None``. metric (Metric | str | None, optional): Metric to measure quality loss when quantizing. 
- Must be provided if CompressionType.INT8_ACQ is selected and must return higher value for better + Must be provided if ``CompressionType.INT8_ACQ`` is selected and must return higher value for better performance of the model (OpenVINO export only). Defaults to ``None``. ov_args (dict[str, Any] | None, optional): This is optional and used only for OpenVINO's model optimizer. diff --git a/src/anomalib/models/components/base/export_mixin.py b/src/anomalib/models/components/base/export_mixin.py index e29f451e43..49143172f5 100644 --- a/src/anomalib/models/components/base/export_mixin.py +++ b/src/anomalib/models/components/base/export_mixin.py @@ -182,10 +182,10 @@ def to_openvino( compression_type (CompressionType, optional): Compression type for better inference performance. Defaults to ``None``. datamodule (AnomalibDataModule | None, optional): Lightning datamodule. - Must be provided if CompressionType.INT8_PTQ is selected. + Must be provided if ``CompressionType.INT8_PTQ`` or ``CompressionType.INT8_ACQ`` is selected. Defaults to ``None``. metric (Metric | str | None, optional): Metric to measure quality loss when quantizing. - Must be provided if CompressionType.INT8_ACQ is selected and must return higher value for better + Must be provided if ``CompressionType.INT8_ACQ`` is selected and must return higher value for better performance of the model. Defaults to ``None``. ov_args (dict | None): Model optimizer arguments for OpenVINO model conversion. 
@@ -243,7 +243,8 @@ def to_openvino( ov_args = {} if ov_args is None else ov_args model = ov.convert_model(model_path, **ov_args) - model = self._compress_ov_model(model, compression_type, datamodule, metric, task) + if compression_type != CompressionType.FP16: + model = self._compress_ov_model(model, compression_type, datamodule, metric, task) # fp16 compression is enabled by default compress_to_fp16 = compression_type == CompressionType.FP16 @@ -266,10 +267,10 @@ def _compress_ov_model( compression_type (CompressionType, optional): Compression type for better inference performance. Defaults to ``None``. datamodule (AnomalibDataModule | None, optional): Lightning datamodule. - Must be provided if CompressionType.INT8_PTQ is selected. + Must be provided if ``CompressionType.INT8_PTQ`` or ``CompressionType.INT8_ACQ`` is selected. Defaults to ``None``. metric (Metric | str | None, optional): Metric to measure quality loss when quantizing. - Must be provided if CompressionType.INT8_ACQ is selected and must return higher value for better + Must be provided if ``CompressionType.INT8_ACQ`` is selected and must return higher value for better performance of the model. Defaults to ``None``. task (TaskType | None): Task type. 
@@ -284,63 +285,113 @@ def _compress_ov_model( import nncf - model_input = model.input(0) - - # weights compression if compression_type == CompressionType.INT8: model = nncf.compress_weights(model) - # post-training quantization elif compression_type == CompressionType.INT8_PTQ: - if datamodule is None: - msg = "Datamodule must be provided for OpenVINO INT8_PTQ compression" - raise ValueError(msg) - - if model_input.partial_shape[0].is_static: - datamodule.train_batch_size = model_input.shape[0] - dataloader = datamodule.train_dataloader() - if len(dataloader.dataset) < 300: - logger.warning( - f">300 images recommended for INT8 quantization, found only {len(dataloader.dataset)} images", - ) - - calibration_dataset = nncf.Dataset(dataloader, lambda x: x["image"]) - model = nncf.quantize(model, calibration_dataset) - # accuracy-control quantization + model = self._post_training_quantization_ov(model, datamodule) elif compression_type == CompressionType.INT8_ACQ: - if datamodule is None: - msg = "Datamodule must be provided for OpenVINO INT8_PTQ compression" - raise ValueError(msg) - if metric is None: - msg = "Metric must be provided for OpenVINO INT8_ACQ compression" - raise ValueError(msg) - - if model_input.partial_shape[0].is_static: - datamodule.train_batch_size = model_input.shape[0] - datamodule.eval_batch_size = model_input.shape[0] - dataloader = datamodule.train_dataloader() - if len(dataloader.dataset) < 300: - logger.warning( - f">300 images recommended for INT8 quantization, found only {len(dataloader.dataset)} images", - ) - - calibration_dataset = nncf.Dataset(dataloader, lambda x: x["image"]) - validation_dataset = nncf.Dataset(datamodule.val_dataloader()) - - if isinstance(metric, str): - metric = create_metric_collection([metric])[metric] - - # validation function to evaluate the quality loss after quantization - def val_fn(nncf_model: "CompiledModel", validation_data: Iterable) -> float: - for batch in validation_data: - preds = 
torch.from_numpy(nncf_model(batch["image"])[0]) - target = batch["label"] if task == TaskType.CLASSIFICATION else batch["mask"][:, None, :, :] - metric.update(preds, target) - return metric.compute() - - model = nncf.quantize_with_accuracy_control(model, calibration_dataset, validation_dataset, val_fn) + model = self._accuracy_control_quantization_ov(model, datamodule, metric, task) + else: + msg = f"Unrecognized compression type: {compression_type}" + raise ValueError(msg) return model + def _post_training_quantization_ov( + self, + model: "CompiledModel", + datamodule: AnomalibDataModule | None = None, + ) -> "CompiledModel": + """Post-Training Quantization model with NNCF. + + model (CompiledModel): Model already exported to OpenVINO format. + datamodule (AnomalibDataModule | None, optional): Lightning datamodule. + Must be provided if ``CompressionType.INT8_PTQ`` or ``CompressionType.INT8_ACQ`` is selected. + Defaults to ``None``. + + Returns: + model (CompiledModel): Quantized model. + """ + import nncf + + if datamodule is None: + msg = "Datamodule must be provided for OpenVINO INT8_PTQ compression" + raise ValueError(msg) + + model_input = model.input(0) + + if model_input.partial_shape[0].is_static: + datamodule.train_batch_size = model_input.shape[0] + + dataloader = datamodule.train_dataloader() + if len(dataloader.dataset) < 300: + logger.warning( + f">300 images recommended for INT8 quantization, found only {len(dataloader.dataset)} images", + ) + + calibration_dataset = nncf.Dataset(dataloader, lambda x: x["image"]) + return nncf.quantize(model, calibration_dataset) + + def _accuracy_control_quantization_ov( + self, + model: "CompiledModel", + datamodule: AnomalibDataModule | None = None, + metric: Metric | str | None = None, + task: TaskType | None = None, + ) -> "CompiledModel": + """Accuracy-Control Quantization with NNCF. + + model (CompiledModel): Model already exported to OpenVINO format. 
+ datamodule (AnomalibDataModule | None, optional): Lightning datamodule. + Must be provided if ``CompressionType.INT8_PTQ`` or ``CompressionType.INT8_ACQ`` is selected. + Defaults to ``None``. + metric (Metric | str | None, optional): Metric to measure quality loss when quantizing. + Must be provided if ``CompressionType.INT8_ACQ`` is selected and must return higher value for better + performance of the model. + Defaults to ``None``. + task (TaskType | None): Task type. + Defaults to ``None``. + + Returns: + model (CompiledModel): Quantized model. + """ + import nncf + + if datamodule is None: + msg = "Datamodule must be provided for OpenVINO INT8_PTQ compression" + raise ValueError(msg) + if metric is None: + msg = "Metric must be provided for OpenVINO INT8_ACQ compression" + raise ValueError(msg) + + model_input = model.input(0) + + if model_input.partial_shape[0].is_static: + datamodule.train_batch_size = model_input.shape[0] + datamodule.eval_batch_size = model_input.shape[0] + + dataloader = datamodule.train_dataloader() + if len(dataloader.dataset) < 300: + logger.warning( + f">300 images recommended for INT8 quantization, found only {len(dataloader.dataset)} images", + ) + + calibration_dataset = nncf.Dataset(dataloader, lambda x: x["image"]) + validation_dataset = nncf.Dataset(datamodule.val_dataloader()) + + if isinstance(metric, str): + metric = create_metric_collection([metric])[metric] + + # validation function to evaluate the quality loss after quantization + def val_fn(nncf_model: "CompiledModel", validation_data: Iterable) -> float: + for batch in validation_data: + preds = torch.from_numpy(nncf_model(batch["image"])[0]) + target = batch["label"] if task == TaskType.CLASSIFICATION else batch["mask"][:, None, :, :] + metric.update(preds, target) + return metric.compute() + + return nncf.quantize_with_accuracy_control(model, calibration_dataset, validation_dataset, val_fn) + def _get_metadata( self, task: TaskType | None = None, From 
fd866b51b6f141e88a7caa0f2be1a5a89a861ec8 Mon Sep 17 00:00:00 2001 From: Adrian Boguszewski Date: Mon, 3 Jun 2024 15:53:58 +0200 Subject: [PATCH 6/9] Added usage examples Signed-off-by: Adrian Boguszewski --- src/anomalib/engine/engine.py | 6 +++--- .../models/components/base/export_mixin.py | 16 ++++++++++++++++ 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/src/anomalib/engine/engine.py b/src/anomalib/engine/engine.py index f0c2fdfe23..05b1d1d6af 100644 --- a/src/anomalib/engine/engine.py +++ b/src/anomalib/engine/engine.py @@ -922,12 +922,12 @@ def export( 3. To export as an OpenVINO ``.xml`` and ``.bin`` file you can run the following command. ```python anomalib export --model Padim --export_mode openvino --ckpt_path \ - --input_size "[256,256]" + --input_size "[256,256]" --compression_type "fp16" ``` - 4. You can also override OpenVINO model optimizer by adding the ``--ov_args.`` arguments. + 4. You can also quantize the OpenVINO model with the following command. ```python anomalib export --model Padim --export_mode openvino --ckpt_path \ - --input_size "[256,256]" --ov_args.compress_to_fp16 False + --input_size "[256,256]" --compression_type "int8_ptq" --data MVTec ``` """ export_type = ExportType(export_type) diff --git a/src/anomalib/models/components/base/export_mixin.py b/src/anomalib/models/components/base/export_mixin.py index 49143172f5..a1a00164cd 100644 --- a/src/anomalib/models/components/base/export_mixin.py +++ b/src/anomalib/models/components/base/export_mixin.py @@ -218,6 +218,22 @@ def to_openvino( ... task=datamodule.test_data.task ... ) + Export and Quantize the Model (OpenVINO IR): + This example demonstrates how to export and quantize the model to OpenVINO IR. + + >>> from anomalib.models import Patchcore + >>> from anomalib.data import Visa + ... + >>> datamodule = Visa() + >>> model = Patchcore() + ... + >>> model.to_openvino( + ... export_root="path/to/export", + ... compression_type=CompressionType.INT8_PTQ, + ... 
datamodule=datamodule, + ... task=datamodule.test_data.task + ... ) + Using Custom Transforms: This example shows how to use a custom ``Transform`` object for the ``transform`` argument. From 17dfb293b1b935b0fa1e44d55976e43ecb4b9aeb Mon Sep 17 00:00:00 2001 From: Samet Akcay Date: Mon, 3 Jun 2024 15:10:55 +0100 Subject: [PATCH 7/9] Update src/anomalib/models/components/base/export_mixin.py --- src/anomalib/models/components/base/export_mixin.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/anomalib/models/components/base/export_mixin.py b/src/anomalib/models/components/base/export_mixin.py index a1a00164cd..fb122c9853 100644 --- a/src/anomalib/models/components/base/export_mixin.py +++ b/src/anomalib/models/components/base/export_mixin.py @@ -223,7 +223,6 @@ def to_openvino( >>> from anomalib.models import Patchcore >>> from anomalib.data import Visa - ... >>> datamodule = Visa() >>> model = Patchcore() ... From 4512dbdcb90019b579bdb4185cc1c092f93650f1 Mon Sep 17 00:00:00 2001 From: Samet Akcay Date: Mon, 3 Jun 2024 15:11:01 +0100 Subject: [PATCH 8/9] Update src/anomalib/models/components/base/export_mixin.py --- src/anomalib/models/components/base/export_mixin.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/anomalib/models/components/base/export_mixin.py b/src/anomalib/models/components/base/export_mixin.py index fb122c9853..35aec78e1a 100644 --- a/src/anomalib/models/components/base/export_mixin.py +++ b/src/anomalib/models/components/base/export_mixin.py @@ -225,7 +225,6 @@ def to_openvino( >>> from anomalib.data import Visa >>> datamodule = Visa() >>> model = Patchcore() - ... >>> model.to_openvino( ... export_root="path/to/export", ... 
compression_type=CompressionType.INT8_PTQ, From 7bf3f2ca02ff5ed343bb46fde9ae4573855397c2 Mon Sep 17 00:00:00 2001 From: Adrian Boguszewski Date: Mon, 3 Jun 2024 16:54:00 +0200 Subject: [PATCH 9/9] Update src/anomalib/models/components/base/export_mixin.py --- src/anomalib/models/components/base/export_mixin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anomalib/models/components/base/export_mixin.py b/src/anomalib/models/components/base/export_mixin.py index 35aec78e1a..3d6f5088da 100644 --- a/src/anomalib/models/components/base/export_mixin.py +++ b/src/anomalib/models/components/base/export_mixin.py @@ -257,7 +257,7 @@ def to_openvino( ov_args = {} if ov_args is None else ov_args model = ov.convert_model(model_path, **ov_args) - if compression_type != CompressionType.FP16: + if compression_type and compression_type != CompressionType.FP16: model = self._compress_ov_model(model, compression_type, datamodule, metric, task) # fp16 compression is enabled by default