diff --git a/nncf/experimental/torch/sparsify_activations/sparsify_activations_impl.py b/nncf/experimental/torch/sparsify_activations/sparsify_activations_impl.py index a486dc4f34c..fa0fa9da534 100644 --- a/nncf/experimental/torch/sparsify_activations/sparsify_activations_impl.py +++ b/nncf/experimental/torch/sparsify_activations/sparsify_activations_impl.py @@ -128,45 +128,31 @@ def available_backends(self) -> List[BackendType]: """ return [BackendType.TORCH] - def _set_backend_entity(self, model: TModel) -> None: - """ - Creates a helper class with a backend-specific logic of the algorithm. - - :param model: Backend-specific input model. - """ - model_backend = get_backend(model) - if model_backend == BackendType.TORCH: - from nncf.experimental.torch.sparsify_activations.torch_backend import PTSparsifyActivationsAlgoBackend - - self._backend_entity = PTSparsifyActivationsAlgoBackend() - else: - raise nncf.UnsupportedBackendError( - f"{model_backend.value} backend is not supported for `sparsify_activations`." - ) - - def _get_target_sparsity_by_node(self, graph: NNCFGraph) -> Dict[NNCFNode, float]: + def apply( + self, + model: TModel, + graph: NNCFGraph, + dataset: Dataset, + ) -> TModel: """ - Collects nodes in the model's graph corresponding to the layers for sparsification. + Applies the algorithm to the given model. - :param graph: NNCFGraph instance. - :return: A dictionary with nodes and the corresponding target sparsity level. + :param model: The model to be sparsified. + :param graph: The model's NNCF graph. + :param dataset: The dataset to calibrate the activation sparsifiers. + :return: The sparsified model. """ - supported_metatypes = self._backend_entity.supported_metatypes - ignored_names = get_ignored_node_names_from_ignored_scope( - self._ignored_scope, graph, strict=self._ignored_scope.validate + self._set_backend_entity(model) + target_sparsity_by_node = self._get_target_sparsity_by_node(graph) + if not target_sparsity_by_node: + raise nncf.ValidationError("No layers matched for activation sparsification.") + sparse_model = self.do_sparsification( + model, + graph, + target_sparsity_by_node, + dataset, ) - target_sparsity_by_node = {} - for node in graph.topological_sort(): - if node.metatype not in supported_metatypes or not should_consider_scope(node.node_name, ignored_names): - continue - for scope, target_sparsity in self._target_sparsity_by_scope.items(): - if matches_any(node.node_name, scope): - if node.node_name in target_sparsity_by_node: - raise nncf.ValidationError( - f'"{node.node_name}" is matched by multiple items in `target_sparsity_by_scope`.' - ) - target_sparsity_by_node[node] = target_sparsity - return target_sparsity_by_node + return sparse_model def do_sparsification( self, @@ -198,31 +184,45 @@ def do_sparsification( model = self._backend_entity.freeze_sparsifiers(model, graph) return model - def apply( - self, - model: TModel, - graph: NNCFGraph, - dataset: Dataset, - ) -> TModel: + def _set_backend_entity(self, model: TModel) -> None: """ - Applies the algorithm to the given model. + Creates a helper class with a backend-specific logic of the algorithm. - :param model: The model to be sparsified. - :param graph: The model's NNCF graph. - :param dataset: The dataset to calibrate the activation sparsifiers. - :return: The sparsified model. + :param model: Backend-specific input model. """ - self._set_backend_entity(model) - target_sparsity_by_node = self._get_target_sparsity_by_node(graph) - if not target_sparsity_by_node: - raise nncf.ValidationError("No layers matched for activation sparsification.") - sparse_model = self.do_sparsification( - model, - graph, - target_sparsity_by_node, - dataset, + model_backend = get_backend(model) + if model_backend == BackendType.TORCH: + from nncf.experimental.torch.sparsify_activations.torch_backend import PTSparsifyActivationsAlgoBackend + + self._backend_entity = PTSparsifyActivationsAlgoBackend() + else: + raise nncf.UnsupportedBackendError( + f"{model_backend.value} backend is not supported for `sparsify_activations`." + ) + + def _get_target_sparsity_by_node(self, graph: NNCFGraph) -> Dict[NNCFNode, float]: + """ + Collects nodes in the model's graph corresponding to the layers for sparsification. + + :param graph: NNCFGraph instance. + :return: A dictionary with nodes and the corresponding target sparsity level. + """ + supported_metatypes = self._backend_entity.supported_metatypes + ignored_names = get_ignored_node_names_from_ignored_scope( + self._ignored_scope, graph, strict=self._ignored_scope.validate ) - return sparse_model + target_sparsity_by_node = {} + for node in graph.topological_sort(): + if node.metatype not in supported_metatypes or not should_consider_scope(node.node_name, ignored_names): + continue + for scope, target_sparsity in self._target_sparsity_by_scope.items(): + if matches_any(node.node_name, scope): + if node.node_name in target_sparsity_by_node: + raise nncf.ValidationError( + f'"{node.node_name}" is matched by multiple items in `target_sparsity_by_scope`.' + ) + target_sparsity_by_node[node] = target_sparsity + return target_sparsity_by_node def sparsify_activations(