Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Activation Sparsity OV backend #2924

Draft
wants to merge 11 commits into
base: develop
Choose a base branch
from
39 changes: 39 additions & 0 deletions nncf/experimental/torch/sparsify_activations/ema_aggregator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

import nncf.tensor.functions as fns
from nncf.experimental.common.tensor_statistics.collectors import AggregationAxes
from nncf.experimental.common.tensor_statistics.collectors import OnlineAggregatorBase
from nncf.tensor import Tensor


# TODO: add tests
class EMAAggregator(OnlineAggregatorBase):
    """
    Online aggregator that maintains a bias-corrected exponential moving average (EMA)
    of the reduced statistics along the sample axis (axis 0).

    With ``beta = 1 - alpha`` and ``n`` previously aggregated samples, the update is::

        avg_{n+1} = (alpha * x_{n+1} + beta * (1 - beta**n) * avg_n) / (1 - beta**(n + 1))

    ``avg_n * (1 - beta**n)`` recovers the uncorrected EMA from the previous corrected one,
    and dividing the whole sum by ``1 - beta**(n + 1)`` makes the weights of the new and the
    old value sum to one.
    """

    def __init__(
        self,
        alpha: float,
        num_samples: Optional[int] = None,
        window_size: Optional[int] = None,
    ):
        """
        :param alpha: Weight of the newest sample in the moving average, in the range (0, 1].
        :param num_samples: Forwarded to `OnlineAggregatorBase` unchanged.
        :param window_size: Forwarded to `OnlineAggregatorBase` unchanged.
        :raises ValueError: If `alpha` is outside of (0, 1].
        """
        # alpha == 0 gives beta == 1, which turns the bias-correction denominator into zero.
        if not 0.0 < alpha <= 1.0:
            raise ValueError(f"`alpha` must be in the range (0, 1], but got {alpha}.")
        self._alpha = alpha
        super().__init__(aggregation_axes=(0,), num_samples=num_samples, window_size=window_size)

    def _aggregation_fn(self, stacked_value: Tensor, axis: AggregationAxes, keepdims: bool) -> Tensor:
        """
        Folds the newest sample into the running average.

        :param stacked_value: Tensor whose axis 0 stacks the newest sample (index 0) and, once
            at least one sample was aggregated, the previous average (index 1).
            NOTE(review): this ordering is what the indexing below relies on - confirm against
            `OnlineAggregatorBase`.
        :param axis: Not used; the aggregation always happens along axis 0.
        :param keepdims: Not used; the result always keeps axis 0 with size 1.
        :return: The updated average with a leading axis of size 1.
        """
        # NOTE(review): assumes the base class increments `_collected_samples` only after this
        # call, so 0 means "this is the very first sample" - confirm.
        if self._collected_samples == 0:
            return stacked_value

        beta = 1.0 - self._alpha
        new_value = fns.expand_dims(stacked_value[0], 0)
        old_value = fns.expand_dims(stacked_value[1], 0)
        # The whole numerator must be normalized; dividing only the old-value term leaves the
        # weights summing to less than one and biases the average towards zero.
        numerator = new_value * self._alpha + old_value * beta * (1 - beta**self._collected_samples)
        return numerator / (1 - beta ** (self._collected_samples + 1))
103 changes: 103 additions & 0 deletions nncf/experimental/torch/sparsify_activations/openvino_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Optional, Type, Union

import openvino.runtime
from openvino.runtime import opset13 as opset

from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.graph import NNCFNode
from nncf.common.graph.operator_metatypes import OperatorMetatype
from nncf.common.graph.transformations.commands import TargetType
from nncf.experimental.torch.sparsify_activations.sparsify_activations_impl import SparsifyActivationsAlgoBackend
from nncf.openvino.graph.metatypes import openvino_metatypes as om
from nncf.openvino.graph.model_transformer import OVModelTransformer
from nncf.openvino.graph.transformations.commands import OVTargetPoint
from nncf.openvino.statistics.collectors import OVAbsQuantileReducer
from nncf.torch.nncf_network import NNCFNetwork

ACTIVATIONS_SPARSIFIER_PREFIX = "activations_sparsifier"


class OVSparsifyActivationsAlgoBackend(SparsifyActivationsAlgoBackend):
    """
    OpenVINO backend for the activation sparsification algorithm.

    Sparsifiers are inserted as a `Select(Abs(x) <= threshold, 0, x)` subgraph in front of
    the activation input of every targeted MatMul node.
    """

    @property
    def supported_metatypes(self) -> List[Type[OperatorMetatype]]:
        """
        :return: The OpenVINO metatypes of layers whose input activations can be sparsified.
        """
        return [om.OVMatMulMetatype]

    def abs_quantile_reducer(self, quantile: Optional[Union[float, List[float]]] = None) -> OVAbsQuantileReducer:
        """
        :param quantile: Quantile level(s) forwarded to the reducer.
        :return: The OpenVINO absolute-quantile reducer.
        """
        return OVAbsQuantileReducer(quantile=quantile)

    def target_point(self, target_type: TargetType, target_node_name: str, port_id: int) -> OVTargetPoint:
        """
        :param target_type: Type of the location relative to the target node.
        :param target_node_name: Name of the target node.
        :param port_id: Port of the target node the point refers to.
        :return: The OpenVINO target point built from the given arguments.
        """
        # Honor the requested target type instead of silently overriding it.
        return OVTargetPoint(target_type, target_node_name, port_id=port_id)

    def insert_sparsifiers(
        self,
        model: openvino.Model,
        graph: NNCFGraph,
        threshold_by_node: Dict[NNCFNode, float],
    ) -> openvino.Model:
        """
        Inserts a thresholding subgraph in front of the activation input of each given node.

        The model is modified in place and the same object is returned.

        :param model: The OpenVINO model to sparsify.
        :param graph: The model's NNCF graph.
        :param threshold_by_node: Magnitude threshold per target node; activations whose
            absolute value is less than or equal to it are replaced with zero.
        :return: The model with inserted activation sparsifiers.
        """
        name_to_node_mapping = OVModelTransformer._get_name_to_node_mapping(model)
        for nncf_node, threshold in threshold_by_node.items():
            activation_port_id = self.get_activation_port_id(nncf_node, graph)
            matmul_node = name_to_node_mapping[nncf_node.node_name]
            activation_input = matmul_node.input(activation_port_id)
            # Keep the exact producer output (not the producer node), so that producers with
            # several outputs are handled correctly.
            dense_activation = activation_input.get_source_output()

            dtype = dense_activation.get_element_type()
            threshold_const = opset.constant(threshold, dtype=dtype, name=f"{matmul_node.name}/sparsity_threshold")
            zero_const = opset.constant(0.0, dtype=dtype)

            less_mask = opset.less_equal(opset.abs(dense_activation), threshold_const)
            sparse_activation = opset.select(
                less_mask, zero_const, dense_activation, name=f"{matmul_node.name}/sparse_input"
            )
            activation_input.replace_source_output(sparse_activation.output(0))

        return model

    @staticmethod
    def get_activation_port_id(matmul_node: NNCFNode, nncf_graph: NNCFGraph) -> int:
        """
        Finds the input activation port id of a MatMul node, i.e. the port which is not fed
        by the (possibly compressed) weight constant.

        :param matmul_node: The node to find its activation port id.
        :param nncf_graph: The NNCF graph containing the node.
        :return: The activation port id.
        :raises RuntimeError: If the node does not have exactly two inputs or none of them
            can be recognized as a constant.
        """
        input_edges = nncf_graph.get_input_edges(matmul_node)
        n_inputs = len(input_edges)
        if n_inputs != 2:
            raise RuntimeError(f"Expected node to have two inputs, but found {n_inputs} for node {matmul_node}.")

        # NOTE(review): list position is used as the port id, as in the original draft; this
        # presumes `get_input_edges` returns the edges ordered by input port - confirm.
        is_const_node_on_port = [edge.from_node.node_type == "Constant" for edge in input_edges]
        if is_const_node_on_port[0] != is_const_node_on_port[1]:
            return 1 if is_const_node_on_port[0] else 0

        # No direct constant input: try to match a converted/compressed constant subgraph.
        for port_id, edge in enumerate(input_edges):
            if OVSparsifyActivationsAlgoBackend._is_decompressed_constant(edge.from_node, nncf_graph):
                return 1 - port_id

        raise RuntimeError(f"Could not find activation port id for node {matmul_node}.")

    @staticmethod
    def _is_decompressed_constant(node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
        """
        Checks whether the node is the end of a `Constant -> [Convert] -> [Subtract] ->
        [Multiply] -> [Reshape] -> [Convert]` chain, following the first input of each operation.

        This is a heuristic: only the first input of every skipped operation is inspected.
        """
        for skipped_type in ("Convert", "Reshape", "Multiply", "Subtract", "Convert"):
            if node.node_type == skipped_type:
                node = nncf_graph.get_input_edges(node)[0].from_node
        return node.node_type == "Constant"
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,25 @@

from abc import ABC
from abc import abstractmethod
from typing import Dict, List, Optional, Type, TypeVar
from typing import Dict, List, Optional, Type, TypeVar, Union

import nncf
from nncf.common import factory
from nncf.common.factory import NNCFGraphFactory
from nncf.common.factory import StatisticsAggregatorFactory
from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.graph import NNCFNode
from nncf.common.graph.operator_metatypes import OperatorMetatype
from nncf.common.logging.track_progress import track
from nncf.common.graph.transformations.commands import TargetPoint
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.scopes import should_consider_scope
from nncf.common.tensor_statistics.statistic_point import StatisticPoint
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
from nncf.common.utils.backend import BackendType
from nncf.common.utils.backend import get_backend
from nncf.data import Dataset
from nncf.experimental.common.tensor_statistics.collectors import AbsQuantileReducer
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.experimental.torch.sparsify_activations.ema_aggregator import EMAAggregator
from nncf.experimental.torch.sparsify_activations.target_scope import TargetScope
from nncf.experimental.torch.sparsify_activations.target_scope import get_target_node_names_from_target_scope
from nncf.scopes import IgnoredScope
Expand All @@ -32,44 +38,36 @@
from nncf.torch.model_creation import wrap_model

TModel = TypeVar("TModel")
STATISTIC_BRANCH_KEY = "abs_quantile"
ALGORITHM_KEY = "AS"


class SparsifyActivationsAlgoBackend(ABC):
"""
Abstract class for activation sparsification algorithm backend.
"""

CALIBRATION_TRACKING_DESC = "Conducting Activations Sparsifier Calibration"

@staticmethod
def do_inference(model: TModel, dataset: Dataset):
    """
    Runs the model over every item of the dataset so that the activation sparsifiers
    can be calibrated.

    :param model: The model with activation sparsifiers.
    :param dataset: The calibration dataset to update the sparsifiers.
    """
    inference_engine = factory.EngineFactory.create(model)
    # Wrap the data iterator to report the calibration progress.
    tracked_batches = track(
        dataset.get_inference_data(),
        total=dataset.get_length(),
        description=SparsifyActivationsAlgoBackend.CALIBRATION_TRACKING_DESC,
    )
    for batch in tracked_batches:
        inference_engine.infer(batch)

@property
@abstractmethod
def supported_metatypes(self) -> List[Type[OperatorMetatype]]:
    """
    Property for the backend-specific metatypes for supported layers.

    :return: The list of operator metatypes supported by the backend.
    """

@abstractmethod
def abs_quantile_reducer(self, quantile: Optional[Union[float, List[float]]] = None) -> AbsQuantileReducer:
    """
    Creates the backend-specific `AbsQuantileReducer`.

    :param quantile: Quantile level(s) to pass to the reducer.
    :return: The backend-specific reducer instance.
    """

@abstractmethod
def target_point(self, target_type: TargetType, target_node_name: str, port_id: int) -> TargetPoint:
    """
    Creates the backend-specific target point.

    :param target_type: Type of the location relative to the target node.
    :param target_node_name: Name of the target node.
    :param port_id: Port id of the target node.
    :return: The backend-specific target point.
    """

@abstractmethod
def insert_sparsifiers(
self,
model: TModel,
graph: NNCFGraph,
target_sparsity_by_node: Dict[NNCFNode, float],
threshold_by_node: Dict[NNCFNode, float],
) -> TModel:
"""
Inserts the activation sparsifiers to the model.
Expand All @@ -80,15 +78,15 @@ def insert_sparsifiers(
:return: The model with inserted activation sparsifiers.
"""

@staticmethod
@abstractmethod
def calibrate_sparsifiers(self, model: TModel, graph: NNCFGraph, dataset: Dataset) -> TModel:
def get_activation_port_id(node: NNCFNode, graph: NNCFGraph) -> int:
"""
Calibrates the thresholds in the activation sparsifiers.
Finds the input activation port id for the node.

:param model: The model with inserted activation sparsifiers.
:param graph: The model's NNCF graph.
:param dataset: The calibration dataset to update the thresholds in the sparsifiers.
:return: The model with calibrated activation sparsifiers.
:param node: The node to find its activation port id.
:param graph: The NNCF graph containing the node.
:return: The activation port id.
"""


Expand Down Expand Up @@ -116,7 +114,7 @@ def available_backends(self) -> List[BackendType]:
"""
Supported backends for this algorithm.
"""
return [BackendType.TORCH]
return [BackendType.TORCH, BackendType.OPENVINO]

def apply(
self,
Expand All @@ -134,30 +132,10 @@ def apply(
"""
self._set_backend_entity(model)
target_sparsity_by_node = self._get_target_sparsity_by_node(graph)
sparse_model = self.do_sparsification(model, graph, target_sparsity_by_node, dataset)
threshold_by_node = self._get_threshold_by_node(model, graph, target_sparsity_by_node, dataset)
sparse_model = self._backend_entity.insert_sparsifiers(model, graph, threshold_by_node)
return sparse_model

def do_sparsification(
    self,
    model: TModel,
    graph: NNCFGraph,
    target_sparsity_by_node: Dict[NNCFNode, float],
    dataset: Dataset,
):
    """
    Produces a sparsified model: sparsifiers are inserted for the requested nodes first
    and calibrated on the dataset afterwards.

    :param model: The model to be sparsified.
    :param graph: The model's NNCF graph.
    :param target_sparsity_by_node: A dictionary that defines the target sparsity level
        for specified node layers.
    :param dataset: The dataset to calibrate the activation sparsifiers.
    :return: The sparsified model.
    """
    backend = self._backend_entity
    model_with_sparsifiers = backend.insert_sparsifiers(model, graph, target_sparsity_by_node)
    return backend.calibrate_sparsifiers(model_with_sparsifiers, graph, dataset)

def _set_backend_entity(self, model: TModel) -> None:
"""
Creates a helper class with a backend-specific logic of the algorithm.
Expand All @@ -169,6 +147,10 @@ def _set_backend_entity(self, model: TModel) -> None:
from nncf.experimental.torch.sparsify_activations.torch_backend import PTSparsifyActivationsAlgoBackend

self._backend_entity = PTSparsifyActivationsAlgoBackend()
elif model_backend == BackendType.OPENVINO:
from nncf.experimental.torch.sparsify_activations.openvino_backend import OVSparsifyActivationsAlgoBackend

self._backend_entity = OVSparsifyActivationsAlgoBackend()
else:
raise nncf.UnsupportedBackendError(
f"{model_backend.value} backend is not supported for `sparsify_activations`."
Expand Down Expand Up @@ -203,6 +185,46 @@ def _get_target_sparsity_by_node(self, graph: NNCFGraph) -> Dict[NNCFNode, float
raise nncf.ValidationError("No layers to conduct activation sparsification.")
return target_sparsity_by_node

def _get_threshold_by_node(self, model, graph, target_sparsity_by_node, dataset: Dataset) -> Dict[NNCFNode, float]:
    """
    Collects activation statistics on the dataset and derives a threshold for every target node.

    For each node an absolute-quantile reducer (with the node's target sparsity as the quantile)
    and an `EMAAggregator` are registered at the node's activation input port.

    :param model: The model to collect statistics from.
    :param graph: The model's NNCF graph.
    :param target_sparsity_by_node: The target sparsity level per node.
    :param dataset: The dataset used for statistics collection.
    :return: The collected statistic value per node, used as the sparsification threshold.
    """
    backend = self._backend_entity
    points = StatisticPointsContainer()
    for target_node, target_sparsity in target_sparsity_by_node.items():
        reducer = backend.abs_quantile_reducer(quantile=[target_sparsity])
        collector = TensorCollector()
        collector.register_statistic_branch(
            container_key=STATISTIC_BRANCH_KEY,
            reducer=reducer,
            aggregator=EMAAggregator(alpha=0.2),
        )
        port_id = backend.get_activation_port_id(target_node, graph)
        location = backend.target_point(TargetType.PRE_LAYER_OPERATION, target_node.node_name, port_id=port_id)
        points.add_statistic_point(
            StatisticPoint(target_point=location, tensor_collector=collector, algorithm=ALGORITHM_KEY)
        )

    aggregator = StatisticsAggregatorFactory.create(model, dataset)
    aggregator.register_statistic_points(points)
    aggregator.collect_statistics(model, graph)

    thresholds = {}
    for target_node in target_sparsity_by_node:
        # Only one collector is registered per node above, so the first match is taken.
        node_collectors = points.get_algo_statistics_for_node(
            target_node.node_name, lambda args: True, ALGORITHM_KEY
        )
        first_collector = next(iter(node_collectors))
        thresholds[target_node] = first_collector.get_statistics()[STATISTIC_BRANCH_KEY].data

    return thresholds


def sparsify_activations(
model: TModel,
Expand Down
Loading