Skip to content

Commit

Permalink
Add PruningFrameworkImplementation
Browse files Browse the repository at this point in the history
  • Loading branch information
reuvenp committed Nov 29, 2023
1 parent b50dc1c commit a3b3c92
Show file tree
Hide file tree
Showing 5 changed files with 268 additions and 66 deletions.
16 changes: 0 additions & 16 deletions model_compression_toolkit/core/common/framework_implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,22 +67,6 @@ def get_trace_hessian_calculator(self,
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
f'framework\'s get_trace_hessian_calculator method.') # pragma: no cover


@abstractmethod
def is_node_intermediate_pruning_section(self, node):
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
f'framework\'s is_node_intermediate_pruning_section method.') # pragma: no cover

@abstractmethod
def get_pruned_node_num_params(self,
node: BaseNode,
input_mask: np.ndarray,
output_mask: np.ndarray,
fw_info: FrameworkInfo,
include_null_channels: bool):

raise NotImplemented(f'{self.__class__.__name__} have to implement the '
f'framework\'s get_pruned_node_num_params method.') # pragma: no cover
@abstractmethod
def to_numpy(self, tensor: Any) -> np.ndarray:
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
from abc import abstractmethod

from model_compression_toolkit.core.common.framework_info import FrameworkInfo
from model_compression_toolkit.core.common import BaseNode
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
import numpy as np


class PruningFrameworkImplementation(FrameworkImplementation):

@abstractmethod
def prune_entry_node(self,
node: BaseNode,
output_mask: np.ndarray,
fw_info: FrameworkInfo):
"""
Abstract method to prune an entry node in the model.
Args:
node: The node to be pruned.
output_mask: A numpy array representing the mask to be applied to the output channels.
fw_info: Framework-specific information.
Raises:
NotImplemented: If the method is not implemented in the subclass.
"""
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
f'framework\'s prune_entry_node method.') # pragma: no cover

@abstractmethod
def prune_intermediate_node(self,
node: BaseNode,
input_mask: np.ndarray,
output_mask: np.ndarray,
fw_info: FrameworkInfo):
"""
Abstract method to prune an intermediate node in the model.
Args:
node: The node to be pruned.
input_mask: Mask to be applied to the input channels.
output_mask: Mask to be applied to the output channels.
fw_info: Framework-specific information.
Raises:
NotImplemented: If the method is not implemented in the subclass.
"""
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
f'framework\'s prune_intermediate_node method.') # pragma: no cover

@abstractmethod
def prune_exit_node(self,
node: BaseNode,
input_mask: np.ndarray,
fw_info: FrameworkInfo):
"""
Abstract method to prune an exit node in the model.
Args:
node: The node to be pruned.
input_mask: Mask to be applied to the input channels.
fw_info: Framework-specific information.
Raises:
NotImplemented: If the method is not implemented in the subclass.
"""
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
f'framework\'s prune_exit_node method.') # pragma: no cover

@abstractmethod
def is_node_entry_node(self,
node: BaseNode):
"""
Abstract method to determine if a given node is an entry node.
Args:
node: The node to be checked.
Returns:
bool: True if the node is an entry node, False otherwise.
Raises:
NotImplemented: If the method is not implemented in the subclass.
"""
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
f'framework\'s is_node_entry_node method.') # pragma: no cover

@abstractmethod
def is_node_exit_node(self,
node: BaseNode,
dual_entry_node: BaseNode):
"""
Abstract method to determine if a given node is an exit node.
Args:
node: The node to be checked.
dual_entry_node: Another node to be used in the determination process.
Returns:
bool: True if the node is an exit node, False otherwise.
Raises:
NotImplemented: If the method is not implemented in the subclass.
"""
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
f'framework\'s is_node_exit_node method.') # pragma: no cover

@abstractmethod
def is_node_intermediate_pruning_section(self,
node):
"""
Abstract method to determine if a given node is in the intermediate section of pruning.
Args:
node: The node to be checked.
Returns:
bool: True if the node is in the intermediate pruning section, False otherwise.
Raises:
NotImplemented: If the method is not implemented in the subclass.
"""
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
f'framework\'s is_node_intermediate_pruning_section method.') # pragma: no cover

@abstractmethod
def get_pruned_node_num_params(self,
node: BaseNode,
input_mask: np.ndarray,
output_mask: np.ndarray,
fw_info: FrameworkInfo,
include_null_channels: bool):
"""
Abstract method to get the number of parameters of a pruned node.
Args:
node: The node whose parameters are to be counted.
input_mask: Mask to be applied to the input channels.
output_mask: Mask to be applied to the output channels.
fw_info: Framework-specific information.
include_null_channels: Boolean flag to include or exclude null channels in the count.
Returns:
int: Number of parameters after pruning.
Raises:
NotImplemented: If the method is not implemented in the subclass.
"""
raise NotImplemented(f'{self.__class__.__name__} have to implement the '
f'framework\'s get_pruned_node_num_params method.') # pragma: no cover
49 changes: 0 additions & 49 deletions model_compression_toolkit/core/keras/keras_implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,6 @@
ActivationTraceHessianCalculatorKeras
from model_compression_toolkit.core.keras.hessian.trace_hessian_calculator_keras import TraceHessianCalculatorKeras
from model_compression_toolkit.core.keras.hessian.weights_trace_hessian_calculator_keras import WeightsTraceHessianCalculatorKeras
from model_compression_toolkit.core.keras.pruning.check_node_role import is_keras_node_intermediate_pruning_section, \
is_keras_entry_node, is_keras_exit_node
from model_compression_toolkit.core.keras.pruning.count_node_params import get_keras_pruned_node_num_params
from model_compression_toolkit.core.keras.pruning.prune_keras_node import prune_keras_exit_node, prune_keras_entry_node, prune_keras_intermediate_node

from model_compression_toolkit.logger import Logger
from model_compression_toolkit.trainable_infrastructure.keras.quantize_wrapper import KerasTrainableQuantizationWrapper
from model_compression_toolkit.core.common.mixed_precision.sensitivity_evaluation import SensitivityEvaluation
Expand Down Expand Up @@ -599,47 +594,3 @@ def sensitivity_eval_inference(self,

return model(inputs)


def prune_entry_node(self, node: BaseNode, output_mask: np.ndarray,
fw_info: FrameworkInfo):
return prune_keras_entry_node(node,
output_mask,
fw_info)

def prune_intermediate_node(self, node: BaseNode, input_mask: np.ndarray, output_mask: np.ndarray,
fw_info: FrameworkInfo):
return prune_keras_intermediate_node(node,
input_mask,
output_mask,
fw_info)

def prune_exit_node(self,
node: BaseNode,
input_mask: np.ndarray,
fw_info: FrameworkInfo):
return prune_keras_exit_node(node,
input_mask,
fw_info)


def is_node_entry_node(self, node:BaseNode): #TODO:Add to base class
return is_keras_entry_node(node)

def is_node_exit_node(self, node:BaseNode, dual_entry_node: BaseNode):
return is_keras_exit_node(node, dual_entry_node)

def is_node_intermediate_pruning_section(self, node):
return is_keras_node_intermediate_pruning_section(node)

def get_pruned_node_num_params(self,
node: BaseNode,
input_mask: np.ndarray,
output_mask: np.ndarray,
fw_info: FrameworkInfo,
include_null_channels: bool):

return get_keras_pruned_node_num_params(node,
input_mask,
output_mask,
fw_info,
include_null_channels)
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
from model_compression_toolkit.core.common.framework_info import FrameworkInfo
from model_compression_toolkit.core.common import BaseNode
from model_compression_toolkit.core.common.pruning.pruning_framework_implementation import \
PruningFrameworkImplementation
from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation
from model_compression_toolkit.core.keras.pruning.check_node_role import is_keras_node_intermediate_pruning_section, \
is_keras_entry_node, is_keras_exit_node
from model_compression_toolkit.core.keras.pruning.count_node_params import get_keras_pruned_node_num_params
from model_compression_toolkit.core.keras.pruning.prune_keras_node import (prune_keras_exit_node,
prune_keras_entry_node, \
prune_keras_intermediate_node)
import numpy as np

class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementation):
"""
Implementation of the PruningFramework for the Keras framework. This class provides
concrete implementations of the abstract methods defined in PruningFrameworkImplementation
for the Keras framework.
"""

def prune_entry_node(self, node: BaseNode, output_mask: np.ndarray, fw_info: FrameworkInfo):
"""
Prunes the entry node of a model in Keras.
Args:
node: The entry node to be pruned.
output_mask: A numpy array representing the mask to be applied to the output channels.
fw_info: Framework-specific information object.
Returns:
The result from the pruning operation.
"""
return prune_keras_entry_node(node, output_mask, fw_info)

def prune_intermediate_node(self, node: BaseNode, input_mask: np.ndarray, output_mask: np.ndarray, fw_info: FrameworkInfo):
"""
Prunes an intermediate node in a Keras model.
Args:
node: The intermediate node to be pruned.
input_mask: A numpy array representing the mask to be applied to the input channels.
output_mask: A numpy array representing the mask to be applied to the output channels.
fw_info: Framework-specific information object.
Returns:
The result from the pruning operation.
"""
return prune_keras_intermediate_node(node, input_mask, output_mask, fw_info)

def prune_exit_node(self, node: BaseNode, input_mask: np.ndarray, fw_info: FrameworkInfo):
"""
Prunes the exit node of a model in Keras.
Args:
node: The exit node to be pruned.
input_mask: A numpy array representing the mask to be applied to the input channels.
fw_info: Framework-specific information object.
Returns:
The result from the pruning operation.
"""
return prune_keras_exit_node(node, input_mask, fw_info)

def is_node_entry_node(self, node: BaseNode):
"""
Determines whether a node is an entry node in a Keras model.
Args:
node: The node to be checked.
Returns:
Boolean indicating if the node is an entry node.
"""
return is_keras_entry_node(node)

def is_node_exit_node(self, node: BaseNode, dual_entry_node: BaseNode):
"""
Determines whether a node is an exit node in a Keras model.
Args:
node: The node to be checked.
dual_entry_node: A related entry node to assist in the determination.
Returns:
Boolean indicating if the node is an exit node.
"""
return is_keras_exit_node(node, dual_entry_node)

def is_node_intermediate_pruning_section(self, node):
"""
Determines whether a node is part of the intermediate section in the pruning process of a Keras model.
Args:
node: The node to be checked.
Returns:
Boolean indicating if the node is part of the intermediate pruning section.
"""
return is_keras_node_intermediate_pruning_section(node)

def get_pruned_node_num_params(self, node: BaseNode, input_mask: np.ndarray, output_mask: np.ndarray, fw_info: FrameworkInfo, include_null_channels: bool):
"""
Calculates the number of parameters in a pruned node of a Keras model.
Args:
node: The node whose parameters are to be counted.
input_mask: Mask to be applied to the input channels.
output_mask: Mask to be applied to the output channels.
fw_info: Framework-specific information object.
include_null_channels: Boolean flag to include or exclude null channels in the count.
Returns:
Integer representing the number of parameters in the pruned node.
"""
return get_keras_pruned_node_num_params(node, input_mask, output_mask, fw_info, include_null_channels)

3 changes: 2 additions & 1 deletion model_compression_toolkit/pruning/keras/pruning_facade.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
set_quantization_configuration_to_graph
from model_compression_toolkit.core.graph_prep_runner import read_model_to_graph
from model_compression_toolkit.core.keras.back2framework.float_model_builder import FloatKerasModelBuilder
from model_compression_toolkit.core.keras.pruning.pruning_keras_implementation import PruningKerasImplementation
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import \
TargetPlatformCapabilities
Expand Down Expand Up @@ -62,7 +63,7 @@ def keras_pruning_experimental(model: Model,
"""

# Instantiate the Keras framework implementation.
fw_impl = KerasImplementation()
fw_impl = PruningKerasImplementation()

# Convert the original Keras model to an internal graph representation.
float_graph = read_model_to_graph(model,
Expand Down

0 comments on commit a3b3c92

Please sign in to comment.