-
Notifications
You must be signed in to change notification settings - Fork 58
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
reuvenp
committed
Nov 29, 2023
1 parent
b50dc1c
commit a3b3c92
Showing
5 changed files
with
268 additions
and
66 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
150 changes: 150 additions & 0 deletions
150
model_compression_toolkit/core/common/pruning/pruning_framework_implementation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
from abc import abstractmethod | ||
|
||
from model_compression_toolkit.core.common.framework_info import FrameworkInfo | ||
from model_compression_toolkit.core.common import BaseNode | ||
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation | ||
import numpy as np | ||
|
||
|
||
class PruningFrameworkImplementation(FrameworkImplementation): | ||
|
||
@abstractmethod | ||
def prune_entry_node(self, | ||
node: BaseNode, | ||
output_mask: np.ndarray, | ||
fw_info: FrameworkInfo): | ||
""" | ||
Abstract method to prune an entry node in the model. | ||
Args: | ||
node: The node to be pruned. | ||
output_mask: A numpy array representing the mask to be applied to the output channels. | ||
fw_info: Framework-specific information. | ||
Raises: | ||
NotImplemented: If the method is not implemented in the subclass. | ||
""" | ||
raise NotImplemented(f'{self.__class__.__name__} have to implement the ' | ||
f'framework\'s prune_entry_node method.') # pragma: no cover | ||
|
||
@abstractmethod | ||
def prune_intermediate_node(self, | ||
node: BaseNode, | ||
input_mask: np.ndarray, | ||
output_mask: np.ndarray, | ||
fw_info: FrameworkInfo): | ||
""" | ||
Abstract method to prune an intermediate node in the model. | ||
Args: | ||
node: The node to be pruned. | ||
input_mask: Mask to be applied to the input channels. | ||
output_mask: Mask to be applied to the output channels. | ||
fw_info: Framework-specific information. | ||
Raises: | ||
NotImplemented: If the method is not implemented in the subclass. | ||
""" | ||
raise NotImplemented(f'{self.__class__.__name__} have to implement the ' | ||
f'framework\'s prune_intermediate_node method.') # pragma: no cover | ||
|
||
@abstractmethod | ||
def prune_exit_node(self, | ||
node: BaseNode, | ||
input_mask: np.ndarray, | ||
fw_info: FrameworkInfo): | ||
""" | ||
Abstract method to prune an exit node in the model. | ||
Args: | ||
node: The node to be pruned. | ||
input_mask: Mask to be applied to the input channels. | ||
fw_info: Framework-specific information. | ||
Raises: | ||
NotImplemented: If the method is not implemented in the subclass. | ||
""" | ||
raise NotImplemented(f'{self.__class__.__name__} have to implement the ' | ||
f'framework\'s prune_exit_node method.') # pragma: no cover | ||
|
||
@abstractmethod | ||
def is_node_entry_node(self, | ||
node: BaseNode): | ||
""" | ||
Abstract method to determine if a given node is an entry node. | ||
Args: | ||
node: The node to be checked. | ||
Returns: | ||
bool: True if the node is an entry node, False otherwise. | ||
Raises: | ||
NotImplemented: If the method is not implemented in the subclass. | ||
""" | ||
raise NotImplemented(f'{self.__class__.__name__} have to implement the ' | ||
f'framework\'s is_node_entry_node method.') # pragma: no cover | ||
|
||
@abstractmethod | ||
def is_node_exit_node(self, | ||
node: BaseNode, | ||
dual_entry_node: BaseNode): | ||
""" | ||
Abstract method to determine if a given node is an exit node. | ||
Args: | ||
node: The node to be checked. | ||
dual_entry_node: Another node to be used in the determination process. | ||
Returns: | ||
bool: True if the node is an exit node, False otherwise. | ||
Raises: | ||
NotImplemented: If the method is not implemented in the subclass. | ||
""" | ||
raise NotImplemented(f'{self.__class__.__name__} have to implement the ' | ||
f'framework\'s is_node_exit_node method.') # pragma: no cover | ||
|
||
@abstractmethod | ||
def is_node_intermediate_pruning_section(self, | ||
node): | ||
""" | ||
Abstract method to determine if a given node is in the intermediate section of pruning. | ||
Args: | ||
node: The node to be checked. | ||
Returns: | ||
bool: True if the node is in the intermediate pruning section, False otherwise. | ||
Raises: | ||
NotImplemented: If the method is not implemented in the subclass. | ||
""" | ||
raise NotImplemented(f'{self.__class__.__name__} have to implement the ' | ||
f'framework\'s is_node_intermediate_pruning_section method.') # pragma: no cover | ||
|
||
@abstractmethod | ||
def get_pruned_node_num_params(self, | ||
node: BaseNode, | ||
input_mask: np.ndarray, | ||
output_mask: np.ndarray, | ||
fw_info: FrameworkInfo, | ||
include_null_channels: bool): | ||
""" | ||
Abstract method to get the number of parameters of a pruned node. | ||
Args: | ||
node: The node whose parameters are to be counted. | ||
input_mask: Mask to be applied to the input channels. | ||
output_mask: Mask to be applied to the output channels. | ||
fw_info: Framework-specific information. | ||
include_null_channels: Boolean flag to include or exclude null channels in the count. | ||
Returns: | ||
int: Number of parameters after pruning. | ||
Raises: | ||
NotImplemented: If the method is not implemented in the subclass. | ||
""" | ||
raise NotImplemented(f'{self.__class__.__name__} have to implement the ' | ||
f'framework\'s get_pruned_node_num_params method.') # pragma: no cover |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
116 changes: 116 additions & 0 deletions
116
model_compression_toolkit/core/keras/pruning/pruning_keras_implementation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
from model_compression_toolkit.core.common.framework_info import FrameworkInfo | ||
from model_compression_toolkit.core.common import BaseNode | ||
from model_compression_toolkit.core.common.pruning.pruning_framework_implementation import \ | ||
PruningFrameworkImplementation | ||
from model_compression_toolkit.core.keras.keras_implementation import KerasImplementation | ||
from model_compression_toolkit.core.keras.pruning.check_node_role import is_keras_node_intermediate_pruning_section, \ | ||
is_keras_entry_node, is_keras_exit_node | ||
from model_compression_toolkit.core.keras.pruning.count_node_params import get_keras_pruned_node_num_params | ||
from model_compression_toolkit.core.keras.pruning.prune_keras_node import (prune_keras_exit_node, | ||
prune_keras_entry_node, \ | ||
prune_keras_intermediate_node) | ||
import numpy as np | ||
|
||
class PruningKerasImplementation(KerasImplementation, PruningFrameworkImplementation): | ||
""" | ||
Implementation of the PruningFramework for the Keras framework. This class provides | ||
concrete implementations of the abstract methods defined in PruningFrameworkImplementation | ||
for the Keras framework. | ||
""" | ||
|
||
def prune_entry_node(self, node: BaseNode, output_mask: np.ndarray, fw_info: FrameworkInfo): | ||
""" | ||
Prunes the entry node of a model in Keras. | ||
Args: | ||
node: The entry node to be pruned. | ||
output_mask: A numpy array representing the mask to be applied to the output channels. | ||
fw_info: Framework-specific information object. | ||
Returns: | ||
The result from the pruning operation. | ||
""" | ||
return prune_keras_entry_node(node, output_mask, fw_info) | ||
|
||
def prune_intermediate_node(self, node: BaseNode, input_mask: np.ndarray, output_mask: np.ndarray, fw_info: FrameworkInfo): | ||
""" | ||
Prunes an intermediate node in a Keras model. | ||
Args: | ||
node: The intermediate node to be pruned. | ||
input_mask: A numpy array representing the mask to be applied to the input channels. | ||
output_mask: A numpy array representing the mask to be applied to the output channels. | ||
fw_info: Framework-specific information object. | ||
Returns: | ||
The result from the pruning operation. | ||
""" | ||
return prune_keras_intermediate_node(node, input_mask, output_mask, fw_info) | ||
|
||
def prune_exit_node(self, node: BaseNode, input_mask: np.ndarray, fw_info: FrameworkInfo): | ||
""" | ||
Prunes the exit node of a model in Keras. | ||
Args: | ||
node: The exit node to be pruned. | ||
input_mask: A numpy array representing the mask to be applied to the input channels. | ||
fw_info: Framework-specific information object. | ||
Returns: | ||
The result from the pruning operation. | ||
""" | ||
return prune_keras_exit_node(node, input_mask, fw_info) | ||
|
||
def is_node_entry_node(self, node: BaseNode): | ||
""" | ||
Determines whether a node is an entry node in a Keras model. | ||
Args: | ||
node: The node to be checked. | ||
Returns: | ||
Boolean indicating if the node is an entry node. | ||
""" | ||
return is_keras_entry_node(node) | ||
|
||
def is_node_exit_node(self, node: BaseNode, dual_entry_node: BaseNode): | ||
""" | ||
Determines whether a node is an exit node in a Keras model. | ||
Args: | ||
node: The node to be checked. | ||
dual_entry_node: A related entry node to assist in the determination. | ||
Returns: | ||
Boolean indicating if the node is an exit node. | ||
""" | ||
return is_keras_exit_node(node, dual_entry_node) | ||
|
||
def is_node_intermediate_pruning_section(self, node): | ||
""" | ||
Determines whether a node is part of the intermediate section in the pruning process of a Keras model. | ||
Args: | ||
node: The node to be checked. | ||
Returns: | ||
Boolean indicating if the node is part of the intermediate pruning section. | ||
""" | ||
return is_keras_node_intermediate_pruning_section(node) | ||
|
||
def get_pruned_node_num_params(self, node: BaseNode, input_mask: np.ndarray, output_mask: np.ndarray, fw_info: FrameworkInfo, include_null_channels: bool): | ||
""" | ||
Calculates the number of parameters in a pruned node of a Keras model. | ||
Args: | ||
node: The node whose parameters are to be counted. | ||
input_mask: Mask to be applied to the input channels. | ||
output_mask: Mask to be applied to the output channels. | ||
fw_info: Framework-specific information object. | ||
include_null_channels: Boolean flag to include or exclude null channels in the count. | ||
Returns: | ||
Integer representing the number of parameters in the pruned node. | ||
""" | ||
return get_keras_pruned_node_num_params(node, input_mask, output_mask, fw_info, include_null_channels) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters