Skip to content

Commit

Permalink
change method order
Browse files Browse the repository at this point in the history
  • Loading branch information
yujiepan-work committed May 22, 2024
1 parent 066c51c commit f2a2cbc
Showing 1 changed file with 56 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -128,45 +128,31 @@ def available_backends(self) -> List[BackendType]:
"""
return [BackendType.TORCH]

def _set_backend_entity(self, model: TModel) -> None:
"""
Creates a helper class with a backend-specific logic of the algorithm.
:param model: Backend-specific input model.
"""
model_backend = get_backend(model)
if model_backend == BackendType.TORCH:
from nncf.experimental.torch.sparsify_activations.torch_backend import PTSparsifyActivationsAlgoBackend

self._backend_entity = PTSparsifyActivationsAlgoBackend()
else:
raise nncf.UnsupportedBackendError(
f"{model_backend.value} backend is not supported for `sparsify_activations`."
)

def _get_target_sparsity_by_node(self, graph: NNCFGraph) -> Dict[NNCFNode, float]:
def apply(
self,
model: TModel,
graph: NNCFGraph,
dataset: Dataset,
) -> TModel:
"""
Collects nodes in the model's graph corresponding to the layers for sparsification.
Applies the algorithm to the given model.
:param graph: NNCFGraph instance.
:return: A dictionary with nodes and the corresponding target sparsity level.
:param model: The model to be sparsified.
:param graph: The model's NNCF graph.
:param dataset: The dataset to calibrate the activation sparsifiers.
:return: The sparsified model.
"""
supported_metatypes = self._backend_entity.supported_metatypes
ignored_names = get_ignored_node_names_from_ignored_scope(
self._ignored_scope, graph, strict=self._ignored_scope.validate
self._set_backend_entity(model)
target_sparsity_by_node = self._get_target_sparsity_by_node(graph)
if not target_sparsity_by_node:
raise nncf.ValidationError("No layers matched for activation sparsification.")
sparse_model = self.do_sparsification(
model,
graph,
target_sparsity_by_node,
dataset,
)
target_sparsity_by_node = {}
for node in graph.topological_sort():
if node.metatype not in supported_metatypes or not should_consider_scope(node.node_name, ignored_names):
continue
for scope, target_sparsity in self._target_sparsity_by_scope.items():
if matches_any(node.node_name, scope):
if node.node_name in target_sparsity_by_node:
raise nncf.ValidationError(
f'"{node.node_name}" is matched by multiple items in `target_sparsity_by_scope`.'
)
target_sparsity_by_node[node] = target_sparsity
return target_sparsity_by_node
return sparse_model

def do_sparsification(
self,
Expand Down Expand Up @@ -198,31 +184,45 @@ def do_sparsification(
model = self._backend_entity.freeze_sparsifiers(model, graph)
return model

def apply(
self,
model: TModel,
graph: NNCFGraph,
dataset: Dataset,
) -> TModel:
def _set_backend_entity(self, model: TModel) -> None:
"""
Applies the algorithm to the given model.
Creates a helper class with a backend-specific logic of the algorithm.
:param model: The model to be sparsified.
:param graph: The model's NNCF graph.
:param dataset: The dataset to calibrate the activation sparsifiers.
:return: The sparsified model.
:param model: Backend-specific input model.
"""
self._set_backend_entity(model)
target_sparsity_by_node = self._get_target_sparsity_by_node(graph)
if not target_sparsity_by_node:
raise nncf.ValidationError("No layers matched for activation sparsification.")
sparse_model = self.do_sparsification(
model,
graph,
target_sparsity_by_node,
dataset,
model_backend = get_backend(model)
if model_backend == BackendType.TORCH:
from nncf.experimental.torch.sparsify_activations.torch_backend import PTSparsifyActivationsAlgoBackend

self._backend_entity = PTSparsifyActivationsAlgoBackend()
else:
raise nncf.UnsupportedBackendError(
f"{model_backend.value} backend is not supported for `sparsify_activations`."
)

def _get_target_sparsity_by_node(self, graph: NNCFGraph) -> Dict[NNCFNode, float]:
"""
Collects nodes in the model's graph corresponding to the layers for sparsification.
:param graph: NNCFGraph instance.
:return: A dictionary with nodes and the corresponding target sparsity level.
"""
supported_metatypes = self._backend_entity.supported_metatypes
ignored_names = get_ignored_node_names_from_ignored_scope(
self._ignored_scope, graph, strict=self._ignored_scope.validate
)
return sparse_model
target_sparsity_by_node = {}
for node in graph.topological_sort():
if node.metatype not in supported_metatypes or not should_consider_scope(node.node_name, ignored_names):
continue
for scope, target_sparsity in self._target_sparsity_by_scope.items():
if matches_any(node.node_name, scope):
if node.node_name in target_sparsity_by_node:
raise nncf.ValidationError(
f'"{node.node_name}" is matched by multiple items in `target_sparsity_by_scope`.'
)
target_sparsity_by_node[node] = target_sparsity
return target_sparsity_by_node


def sparsify_activations(
Expand Down

0 comments on commit f2a2cbc

Please sign in to comment.