diff --git a/nncf/experimental/torch/sparsify_activations/ema_aggregator.py b/nncf/experimental/torch/sparsify_activations/ema_aggregator.py new file mode 100644 index 00000000000..cf90a47ab83 --- /dev/null +++ b/nncf/experimental/torch/sparsify_activations/ema_aggregator.py @@ -0,0 +1,39 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +import nncf.tensor.functions as fns +from nncf.experimental.common.tensor_statistics.collectors import AggregationAxes +from nncf.experimental.common.tensor_statistics.collectors import OnlineAggregatorBase +from nncf.tensor import Tensor + + +# TODO: add tests +class EMAAggregator(OnlineAggregatorBase): + def __init__( + self, + alpha: float, + num_samples: Optional[int] = None, + window_size: Optional[int] = None, + ): + self._alpha = alpha + super().__init__(aggregation_axes=(0,), num_samples=num_samples, window_size=window_size) + + def _aggregation_fn(self, stacked_value: Tensor, axis: AggregationAxes, keepdims: bool) -> Tensor: + if self._collected_samples == 0: + return stacked_value + else: + beta = 1.0 - self._alpha + new_value = fns.expand_dims(stacked_value[0], 0) + old_value = fns.expand_dims(stacked_value[1], 0) + return new_value * self._alpha + old_value * beta * (1 - beta**self._collected_samples) / ( + 1 - beta ** (self._collected_samples + 1) + ) diff --git a/nncf/experimental/torch/sparsify_activations/openvino_backend.py b/nncf/experimental/torch/sparsify_activations/openvino_backend.py new file mode 100644 index 00000000000..f6a1c3dcb99 --- /dev/null +++ b/nncf/experimental/torch/sparsify_activations/openvino_backend.py @@ -0,0 +1,103 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List, Optional, Type, Union + +import openvino.runtime +from openvino.runtime import opset13 as opset + +from nncf.common.graph.graph import NNCFGraph +from nncf.common.graph.graph import NNCFNode +from nncf.common.graph.operator_metatypes import OperatorMetatype +from nncf.common.graph.transformations.commands import TargetType +from nncf.experimental.torch.sparsify_activations.sparsify_activations_impl import SparsifyActivationsAlgoBackend +from nncf.openvino.graph.metatypes import openvino_metatypes as om +from nncf.openvino.graph.model_transformer import OVModelTransformer +from nncf.openvino.graph.transformations.commands import OVTargetPoint +from nncf.openvino.statistics.collectors import OVAbsQuantileReducer +from nncf.torch.nncf_network import NNCFNetwork + +ACTIVATIONS_SPARSIFIER_PREFIX = "activations_sparsifier" + + +class OVSparsifyActivationsAlgoBackend(SparsifyActivationsAlgoBackend): + """ + OpenVINO backend for the activation sparsification algorithm. + """ + + @property + def supported_metatypes(self) -> List[Type[OperatorMetatype]]: + return [om.OVMatMulMetatype] + + def abs_quantile_reducer(self, quantile: Optional[Union[float, List[float]]] = None) -> OVAbsQuantileReducer: + return OVAbsQuantileReducer(quantile=quantile) + + def target_point(self, target_type: TargetType, target_node_name: str, port_id: int) -> OVTargetPoint: + return OVTargetPoint(TargetType.PRE_LAYER_OPERATION, target_node_name, port_id=port_id) + + def insert_sparsifiers( + self, + model: openvino.Model, + graph: NNCFGraph, + threshold_by_node: Dict[NNCFNode, float], + ) -> NNCFNetwork: + name_to_node_mapping = OVModelTransformer._get_name_to_node_mapping(model) + for nncf_node, threshold in threshold_by_node.items(): + activation_port_id = self.get_activation_port_id(nncf_node, graph) + matmul_node = name_to_node_mapping[nncf_node.node_name] + dense_activation = matmul_node.input(activation_port_id).get_source_output().get_node() + + dtype = dense_activation.get_element_type() + threshold_const = opset.constant(threshold, dtype=dtype, name=f"{matmul_node.name}/sparsity_threshold") + zero_const = opset.constant(0.0, dtype=dtype) + + less_mask = opset.less_equal(opset.abs(dense_activation), threshold_const) + sparse_activation = opset.select( + less_mask, zero_const, dense_activation, name=f"{matmul_node.name}/sparse_input" + ) + matmul_node.input(activation_port_id).replace_source_output(sparse_activation.output(0)) + + return model + + @staticmethod + def get_activation_port_id(matmul_node: NNCFNode, nncf_graph: NNCFGraph) -> int: + return 0 + n_inputs = len(nncf_graph.get_input_edges(matmul_node)) + if n_inputs != 2: + raise RuntimeError(f"Expected node to have two inputs, but found {n_inputs} for node {matmul_node}.") + + is_const_node_on_port = [ + nncf_graph.get_input_edges(matmul_node)[i].from_node.node_type == "Constant" for i in range(2) + ] + if is_const_node_on_port[0] != is_const_node_on_port[1]: + assert not is_const_node_on_port[0], matmul_node.node_name + return 1 if is_const_node_on_port[0] else 0 + + # Try to match compressed constant subgraph + for i in range(2): + node = nncf_graph.get_input_edges(matmul_node)[i].from_node + if node.node_type == "Convert": + node = nncf_graph.get_input_edges(node)[0].from_node + if node.node_type == "Reshape": + node = nncf_graph.get_input_edges(node)[0].from_node + if node.node_type == "Multiply": + node = nncf_graph.get_input_edges(node)[0].from_node + if node.node_type == "Subtract": + node = nncf_graph.get_input_edges(node)[0].from_node + if node.node_type == "Convert": + node = nncf_graph.get_input_edges(node)[0].from_node + else: + continue + if node.node_type == "Constant": + assert i == 1, matmul_node.node_name + return int(i == 0) + + raise RuntimeError(f"Could not find activation port id for node {matmul_node}.") diff --git a/nncf/experimental/torch/sparsify_activations/sparsify_activations_impl.py b/nncf/experimental/torch/sparsify_activations/sparsify_activations_impl.py index 83a7a418911..750b8715176 100644 --- a/nncf/experimental/torch/sparsify_activations/sparsify_activations_impl.py +++ b/nncf/experimental/torch/sparsify_activations/sparsify_activations_impl.py @@ -11,19 +11,25 @@ from abc import ABC from abc import abstractmethod -from typing import Dict, List, Optional, Type, TypeVar +from typing import Dict, List, Optional, Type, TypeVar, Union import nncf -from nncf.common import factory from nncf.common.factory import NNCFGraphFactory +from nncf.common.factory import StatisticsAggregatorFactory from nncf.common.graph.graph import NNCFGraph from nncf.common.graph.graph import NNCFNode from nncf.common.graph.operator_metatypes import OperatorMetatype -from nncf.common.logging.track_progress import track +from nncf.common.graph.transformations.commands import TargetPoint +from nncf.common.graph.transformations.commands import TargetType from nncf.common.scopes import should_consider_scope +from nncf.common.tensor_statistics.statistic_point import StatisticPoint +from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer from nncf.common.utils.backend import BackendType from nncf.common.utils.backend import get_backend from nncf.data import Dataset +from nncf.experimental.common.tensor_statistics.collectors import AbsQuantileReducer +from nncf.experimental.common.tensor_statistics.collectors import TensorCollector +from nncf.experimental.torch.sparsify_activations.ema_aggregator import EMAAggregator from nncf.experimental.torch.sparsify_activations.target_scope import TargetScope from nncf.experimental.torch.sparsify_activations.target_scope import get_target_node_names_from_target_scope from nncf.scopes import IgnoredScope @@ -32,6 +38,8 @@ from nncf.torch.model_creation import wrap_model TModel = TypeVar("TModel") +STATISTIC_BRANCH_KEY = "abs_quantile" +ALGORITHM_KEY = "AS" class SparsifyActivationsAlgoBackend(ABC): @@ -39,24 +47,6 @@ class SparsifyActivationsAlgoBackend(ABC): Abstract class for activation sparsification algorithm backend. """ - CALIBRATION_TRACKING_DESC = "Conducting Activations Sparsifier Calibration" - - @staticmethod - def do_inference(model: TModel, dataset: Dataset): - """ - Conducts model inference on given dataset to calibrate the activation sparsifiers. - - :param model: The model with activation sparsifiers. - :param dataset: The calibration dataset to update the sparsifiers. - """ - engine = factory.EngineFactory.create(model) - for input_data in track( - dataset.get_inference_data(), - total=dataset.get_length(), - description=SparsifyActivationsAlgoBackend.CALIBRATION_TRACKING_DESC, - ): - engine.infer(input_data) - @property @abstractmethod def supported_metatypes(self) -> List[Type[OperatorMetatype]]: @@ -64,12 +54,20 @@ def supported_metatypes(self) -> List[Type[OperatorMetatype]]: Property for the backend-specific metatypes for supported layers. """ + @abstractmethod + def abs_quantile_reducer(self, quantile: Optional[Union[float, List[float]]] = None) -> AbsQuantileReducer: + """ """ + + @abstractmethod + def target_point(self, target_type: TargetType, target_node_name: str, port_id: int) -> TargetPoint: + """ """ + @abstractmethod def insert_sparsifiers( self, model: TModel, graph: NNCFGraph, - target_sparsity_by_node: Dict[NNCFNode, float], + threshold_by_node: Dict[NNCFNode, float], ) -> TModel: """ Inserts the activation sparsifiers to the model. @@ -80,15 +78,15 @@ def insert_sparsifiers( :return: The model with inserted activation sparsifiers. """ + @staticmethod @abstractmethod - def calibrate_sparsifiers(self, model: TModel, graph: NNCFGraph, dataset: Dataset) -> TModel: + def get_activation_port_id(node: NNCFNode, graph: NNCFGraph) -> int: """ - Calibrates the thresholds in the activation sparsifiers. + Finds the input activation port id for the node. - :param model: The model with inserted activation sparsifiers. - :param graph: The model's NNCF graph. - :param dataset: The calibration dataset to update the thresholds in the sparsifiers. - :return: The model with calibrated activation sparsifiers. + :param node: The node to find its activation port id. + :param graph: The NNCF graph containing the node. + :return: The activation port id. """ @@ -116,7 +114,7 @@ def available_backends(self) -> List[BackendType]: """ Supported backends for this algorithm. """ - return [BackendType.TORCH] + return [BackendType.TORCH, BackendType.OPENVINO] def apply( self, @@ -134,30 +132,10 @@ def apply( """ self._set_backend_entity(model) target_sparsity_by_node = self._get_target_sparsity_by_node(graph) - sparse_model = self.do_sparsification(model, graph, target_sparsity_by_node, dataset) + threshold_by_node = self._get_threshold_by_node(model, graph, target_sparsity_by_node, dataset) + sparse_model = self._backend_entity.insert_sparsifiers(model, graph, threshold_by_node) return sparse_model - def do_sparsification( - self, - model: TModel, - graph: NNCFGraph, - target_sparsity_by_node: Dict[NNCFNode, float], - dataset: Dataset, - ): - """ - Transforms the model into a sparsified one with node-specific target activation sparsity levels. - - :param model: The model to be sparsified. - :param graph: The model's NNCF graph. - :param target_sparsity_by_node: A dictionary that defines the target sparsity level - for specified node layers. - :param dataset: The dataset to calibrate the activation sparsifiers. - :return: The sparsified model. - """ - model = self._backend_entity.insert_sparsifiers(model, graph, target_sparsity_by_node) - model = self._backend_entity.calibrate_sparsifiers(model, graph, dataset) - return model - def _set_backend_entity(self, model: TModel) -> None: """ Creates a helper class with a backend-specific logic of the algorithm. @@ -169,6 +147,10 @@ def _set_backend_entity(self, model: TModel) -> None: from nncf.experimental.torch.sparsify_activations.torch_backend import PTSparsifyActivationsAlgoBackend self._backend_entity = PTSparsifyActivationsAlgoBackend() + elif model_backend == BackendType.OPENVINO: + from nncf.experimental.torch.sparsify_activations.openvino_backend import OVSparsifyActivationsAlgoBackend + + self._backend_entity = OVSparsifyActivationsAlgoBackend() else: raise nncf.UnsupportedBackendError( f"{model_backend.value} backend is not supported for `sparsify_activations`." @@ -203,6 +185,46 @@ def _get_target_sparsity_by_node(self, graph: NNCFGraph) -> Dict[NNCFNode, float raise nncf.ValidationError("No layers to conduct activation sparsification.") return target_sparsity_by_node + def _get_threshold_by_node(self, model, graph, target_sparsity_by_node, dataset: Dataset) -> Dict[NNCFNode, float]: + statistic_points_container = StatisticPointsContainer() + for node, sparsity in target_sparsity_by_node.items(): + stat_collector = TensorCollector() + stat_collector.register_statistic_branch( + container_key=STATISTIC_BRANCH_KEY, + reducer=self._backend_entity.abs_quantile_reducer( + quantile=[ + sparsity, + ] + ), + aggregator=EMAAggregator(alpha=0.2), + ) + activation_port_id = self._backend_entity.get_activation_port_id(node, graph) + statistic_point = StatisticPoint( + target_point=self._backend_entity.target_point( + TargetType.PRE_LAYER_OPERATION, node.node_name, port_id=activation_port_id + ), + tensor_collector=stat_collector, + algorithm=ALGORITHM_KEY, + ) + statistic_points_container.add_statistic_point(statistic_point) + + statistics_aggregator = StatisticsAggregatorFactory.create(model, dataset) + statistics_aggregator.register_statistic_points(statistic_points_container) + statistics_aggregator.collect_statistics(model, graph) + + threshold_by_node = {} + for nncf_node in target_sparsity_by_node: + tensor_collector = next( + iter( + statistic_points_container.get_algo_statistics_for_node( + nncf_node.node_name, lambda args: True, ALGORITHM_KEY + ) + ) + ) + threshold_by_node[nncf_node] = tensor_collector.get_statistics()[STATISTIC_BRANCH_KEY].data + + return threshold_by_node + def sparsify_activations( model: TModel, diff --git a/nncf/experimental/torch/sparsify_activations/torch_backend.py b/nncf/experimental/torch/sparsify_activations/torch_backend.py index a10f12c6518..87478db0ede 100644 --- a/nncf/experimental/torch/sparsify_activations/torch_backend.py +++ b/nncf/experimental/torch/sparsify_activations/torch_backend.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Type, TypeVar +from typing import Dict, List, Optional, Type, Union import torch import torch.nn as nn @@ -20,19 +20,16 @@ from nncf.common.graph.operator_metatypes import CONST_NOOP_METATYPES from nncf.common.graph.operator_metatypes import OperatorMetatype from nncf.common.graph.transformations.commands import TargetType -from nncf.data import Dataset +from nncf.experimental.common.tensor_statistics.collectors import AbsQuantileReducer from nncf.experimental.torch.sparsify_activations.sparsify_activations_impl import SparsifyActivationsAlgoBackend -from nncf.tensor.functions.torch_numeric import quantile from nncf.torch.graph import operator_metatypes as om from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand from nncf.torch.graph.transformations.commands import PTTargetPoint from nncf.torch.graph.transformations.layout import PTTransformationLayout from nncf.torch.model_transformer import PTModelTransformer from nncf.torch.nncf_network import NNCFNetwork -from nncf.torch.utils import training_mode_switcher ACTIVATIONS_SPARSIFIER_PREFIX = "activations_sparsifier" -TModel = TypeVar("TModel") class ActivationsSparsifier(nn.Module): @@ -40,82 +37,19 @@ class ActivationsSparsifier(nn.Module): Sparsifies input activations by masking out values around zero. """ - def __init__(self, target_sparsity: float, alpha: float = 0.2): - """ - :param target_sparsity: The target activation sparsity level. - :param alpha: The exponential moving average decay factor in range (0, 1) for calibrating - the threshold. A larger alpha will give more weight to the most recent batches. - """ + def __init__(self, threshold: float): + """ """ super().__init__() - self.target_sparsity = target_sparsity - if alpha <= 0.0 or alpha >= 1.0: - raise ValueError("The decay factor `alpha` should be in range (0, 1).") - self.alpha = alpha - self.register_buffer("running_threshold", torch.tensor(float("-inf"))) - self.register_buffer("num_batches_tracked", torch.tensor(0)) - self.running_threshold: torch.Tensor - self.num_batches_tracked: torch.Tensor - self._freeze = True - - @staticmethod - def calculate_threshold(x: torch.Tensor, target_sparsity: float) -> torch.Tensor: - """ - Calculates the threshold to sparsify the input tensor with target sparsity if locations of - `x.abs() <= threshold` are zeroed out. - - :param x: The input tensor. - :param target_sparsity: The target sparsity level on the input tensor. - :return: The threshold value. - """ - return quantile(x.detach().abs().view(-1), q=target_sparsity, axis=0) - - @property - def freeze(self): - return self._freeze - - @freeze.setter - def freeze(self, value: bool): - self._freeze = value + self.register_buffer("threshold", torch.tensor(threshold, dtype=torch.float32)) + self.threshold: torch.Tensor def forward(self, x: torch.Tensor) -> torch.Tensor: - if not self.freeze: - threshold = self.calculate_threshold(x, self.target_sparsity) - self._update(threshold, dtype=x.dtype) - mask = torch.le(x.abs(), self.running_threshold) + mask = torch.le(x.abs(), self.threshold) x = torch.masked_fill(x, mask, 0.0) return x - def reset_running_stats(self): - """ - Resets the running threshold and the number of tracked batches to the initial stage. - """ - self.running_threshold.fill_(float("-inf")) - self.num_batches_tracked.zero_() - def extra_repr(self) -> str: - return f"target_sparsity={self.target_sparsity}" - - def _update(self, threshold: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: - """ - Updates the running threshold by exponential moving average with decaying adjustment. - The updating logic is similar to `pandas.DataFrame.ewm(adjust=True)`. - - :param threshold: The threshold value derived from this batch to update the running threshold. - :param dtype: Data type of the updated running threshold. - :return: The updated running threshold. - """ - if self.num_batches_tracked == 0: - running_threshold = threshold - else: - beta = 1.0 - self.alpha - old_running_threshold = self.running_threshold.to(device=threshold.device, dtype=torch.float64) - running_threshold = ( - threshold.to(torch.float64) * self.alpha - + old_running_threshold * beta * (1 - beta**self.num_batches_tracked) - ) / (1 - beta ** (self.num_batches_tracked + 1)) - self.running_threshold = running_threshold.type(dtype) - self.num_batches_tracked += 1 - return self.running_threshold + return f"target_sparsity={self.threshold}" class PTSparsifyActivationsAlgoBackend(SparsifyActivationsAlgoBackend): @@ -123,32 +57,26 @@ class PTSparsifyActivationsAlgoBackend(SparsifyActivationsAlgoBackend): Torch backend for the activation sparsification algorithm. """ - SUPPORTED_METATYPES = [om.PTLinearMetatype] - - @staticmethod - def get_sparsifiers(model: NNCFNetwork) -> List[ActivationsSparsifier]: - """ - Finds all the activation sparsifiers in the model. - - :param model: The model with activation sparsifiers. - :return: List of activation sparsifiers. - """ - return [m for m in model.nncf.modules() if isinstance(m, ActivationsSparsifier)] - @property def supported_metatypes(self) -> List[Type[OperatorMetatype]]: - return PTSparsifyActivationsAlgoBackend.SUPPORTED_METATYPES + return [om.PTLinearMetatype] + + def abs_quantile_reducer(self, quantile: Optional[Union[float, List[float]]] = None) -> AbsQuantileReducer: + return AbsQuantileReducer(quantile=quantile) + + def target_point(self, target_type: TargetType, target_node_name: str, port_id: int) -> PTTargetPoint: + return PTTargetPoint(TargetType.PRE_LAYER_OPERATION, target_node_name, input_port_id=port_id) def insert_sparsifiers( self, model: NNCFNetwork, graph: NNCFGraph, - target_sparsity_by_node: Dict[NNCFNode, float], + threshold_by_node: Dict[NNCFNode, float], ) -> NNCFNetwork: transformation_layout = PTTransformationLayout() - for node, target_sparsity in target_sparsity_by_node.items(): - activation_port_id = self._get_activation_port_id(node, graph) - sparsifier = ActivationsSparsifier(target_sparsity=target_sparsity) + for node, threshold in threshold_by_node.items(): + activation_port_id = self.get_activation_port_id(node, graph) + sparsifier = ActivationsSparsifier(threshold=threshold) sparsifier_name = f"{ACTIVATIONS_SPARSIFIER_PREFIX}_{node.node_name.replace('.', '_')}" transformation_layout.register( PTSharedFnInsertionCommand( @@ -167,20 +95,8 @@ def insert_sparsifiers( transformed_model = PTModelTransformer(model).transform(transformation_layout) return transformed_model - def calibrate_sparsifiers(self, model: NNCFNetwork, graph: NNCFGraph, dataset: Dataset) -> NNCFNetwork: - sparsifiers = self.get_sparsifiers(model) - for sparsifier in sparsifiers: - sparsifier.reset_running_stats() - sparsifier.freeze = False - with training_mode_switcher(model, is_training=False): - with torch.no_grad(): - self.do_inference(model, dataset) - for sparsifier in sparsifiers: - sparsifier.freeze = True - return model - @staticmethod - def _get_activation_port_id(node: NNCFNode, graph: NNCFGraph) -> int: + def get_activation_port_id(node: NNCFNode, graph: NNCFGraph) -> int: """ Finds the input activation port id for the node. diff --git a/tests/torch/experimental/sparsify_activations/test_algo.py b/tests/torch/experimental/sparsify_activations/test_algo.py index 2059d7ef947..0bc77af98e7 100644 --- a/tests/torch/experimental/sparsify_activations/test_algo.py +++ b/tests/torch/experimental/sparsify_activations/test_algo.py @@ -164,8 +164,6 @@ def test_inserted_sparsifier(self): num_sparsifiers = 0 for name, op in model.nncf.external_op.items(): if isinstance(op, ActivationsSparsifier): - assert op.target_sparsity == desc.ref_sparsifier_target_sparsity[name] - assert op.num_batches_tracked == desc.ref_num_batches_tracked num_sparsifiers += 1 assert num_sparsifiers == len(desc.ref_sparsifier_target_sparsity) @@ -196,7 +194,7 @@ def test_export_openvino(self): ov_outputs = compiled_model(example_input.cpu()).to_tuple() assert len(torch_outputs) == len(ov_outputs) for torch_output, ov_output in zip(torch_outputs, ov_outputs): - torch.testing.assert_close(torch_output.cpu(), torch.from_numpy(ov_output), rtol=1e-3, atol=1e-3) + torch.testing.assert_close(torch_output.cpu(), torch.from_numpy(ov_output), rtol=1e-2, atol=1e-2) @dataclass diff --git a/tests/torch/experimental/sparsify_activations/test_components.py b/tests/torch/experimental/sparsify_activations/test_components.py index 9c5fde1c9e5..d6dc399279b 100644 --- a/tests/torch/experimental/sparsify_activations/test_components.py +++ b/tests/torch/experimental/sparsify_activations/test_components.py @@ -108,72 +108,30 @@ def setup(self, use_cuda: bool): self.device = torch.device("cuda" if use_cuda else "cpu") @pytest.mark.parametrize("dtype", [torch.float32, torch.float16]) - def test_forward_before_calibration(self, use_cuda: bool, dtype: torch.dtype): + def test_sparsifier_forward(self, use_cuda: bool, dtype: torch.dtype): device = self.device - input_tensor = torch.rand([3, 3], device=device, dtype=dtype) - sparsifier = ActivationsSparsifier(target_sparsity=0.9).to(device) - assert sparsifier.freeze is True - assert not sparsifier.num_batches_tracked.is_nonzero() - assert sparsifier.running_threshold.isneginf() - output_tensor = sparsifier(input_tensor) - # The output tensor is a new tensor - assert not output_tensor.is_set_to(input_tensor) - # Before calibration, the sparsifier does not change the input - torch.testing.assert_close(output_tensor, input_tensor, rtol=1e-4, atol=1e-4) - - @pytest.mark.parametrize( - "desc", - sparsifier_forward_during_calibration_test_descs.values(), - ids=sparsifier_forward_during_calibration_test_descs.keys(), - ) - def test_forward_during_calibration(self, use_cuda: bool, desc: SparsifierForwardTestDesc): - device = self.device - sparsifier = ActivationsSparsifier(desc.target_sparsity, desc.alpha).to(device) - sparsifier.freeze = False - running_thresholds = [] - outputs = [] - with torch.no_grad(): - for batch in desc.input_batches: - output = sparsifier(batch.to(device)) - running_thresholds.append(sparsifier.running_threshold) - outputs.append(output) - assert sparsifier.num_batches_tracked == len(desc.input_batches) - assert len(running_thresholds) == len(desc.ref_running_thresholds) - for threshold, ref_threshold in zip(running_thresholds, desc.ref_running_thresholds): - assert threshold.device.type == device.type - torch.testing.assert_close(threshold, ref_threshold, rtol=1e-4, atol=1e-4, check_device=False) - assert len(outputs) == len(desc.ref_outputs) - for output, ref_output in zip(outputs, desc.ref_outputs): - assert output.device.type == device.type - torch.testing.assert_close(output, ref_output, rtol=1e-4, atol=1e-4, check_device=False) - - @pytest.mark.parametrize("dtype", [torch.float32, torch.float16]) - def test_forward_after_calibration(self, use_cuda: bool, dtype: torch.dtype): - device = self.device - sparsifier = ActivationsSparsifier(target_sparsity=0.9).to(device) - sparsifier.running_threshold.fill_(0.1) - sparsifier.num_batches_tracked.fill_(100) + sparsifier = ActivationsSparsifier(threshold=0.1).to(device) for _ in range(2): # The sparsifier does not change in the following forwards input_tensor = torch.rand([2, 10], device=device, dtype=dtype) ref_output = torch.where(input_tensor.abs() <= 0.1, 0.0, input_tensor) output_tensor = sparsifier(ref_output) - assert sparsifier.num_batches_tracked == 100 - torch.testing.assert_close( - sparsifier.running_threshold, torch.tensor(0.1, device=device), rtol=1e-4, atol=1e-4 - ) + torch.testing.assert_close(sparsifier.threshold, torch.tensor(0.1, device=device), rtol=1e-4, atol=1e-4) torch.testing.assert_close(output_tensor, ref_output, rtol=1e-4, atol=1e-4) class TestPTSparsifyActivationsAlgoBackend: + @staticmethod + def get_sparsifiers(model: NNCFNetwork) -> List[ActivationsSparsifier]: + return [m for m in model.nncf.modules() if isinstance(m, ActivationsSparsifier)] + def test_get_sparsifiers(self): model, dataset = self.create_model_and_dataset() sparse_model = nncf.experimental.torch.sparsify_activations.sparsify_activations( model, dataset, target_sparsity_by_scope={TargetScope(patterns=[".*"]): 0.5} ) - backend = PTSparsifyActivationsAlgoBackend() - sparsifiers = backend.get_sparsifiers(sparse_model) + sparsifiers = self.get_sparsifiers(sparse_model) assert len(sparsifiers) == 3 @pytest.mark.parametrize("compress_weights", [False, True]) @@ -183,35 +141,15 @@ def test_insert_sparsifiers(self, compress_weights: bool): ref_output = model(example_input) graph = model.nncf.get_graph() - nodes = graph.get_nodes_by_metatypes(PTSparsifyActivationsAlgoBackend.SUPPORTED_METATYPES) backend = PTSparsifyActivationsAlgoBackend() - model_with_sparsifiers = backend.insert_sparsifiers(model, graph, {node: 0.9 for node in nodes}) - assert len(backend.get_sparsifiers(model_with_sparsifiers)) == len(nodes) + nodes = graph.get_nodes_by_metatypes(backend.supported_metatypes) + model_with_sparsifiers = backend.insert_sparsifiers(model, graph, {node: 0.0 for node in nodes}) + assert len(self.get_sparsifiers(model_with_sparsifiers)) == len(nodes) output = model_with_sparsifiers(example_input) torch.testing.assert_close( output, ref_output, rtol=1e-4, atol=1e-4 - ) # At this time the sparsifers do not change the output - - def test_calibrate_sparsifiers(self, mocker): - model, dataset = self.create_model_and_dataset() - graph = model.nncf.get_graph() - backend = PTSparsifyActivationsAlgoBackend() - mock_sparsifier = ActivationsSparsifier(0.5, 0.1) - mock_sparsifier.freeze = True - num_model_forward_calls = 0 - - def model_forward_pre_hook(model: NNCFNetwork, args): - nonlocal num_model_forward_calls - num_model_forward_calls += 1 - assert model.training is False - - model.register_forward_pre_hook(model_forward_pre_hook) - - with mocker.patch.object(backend, "get_sparsifiers", return_value=[mock_sparsifier]): - backend.calibrate_sparsifiers(model, graph, dataset) - assert mock_sparsifier.freeze is True - assert num_model_forward_calls == dataset.get_length() + ) # Since threshold is 0.0 sparsifiers do not change the output def create_model_and_dataset(self, compress_weights: bool = False): model = ThreeLinearModel()