Keras structured SIMD pruning #871

Merged
merged 96 commits into from
Dec 28, 2023
b276815
Init pruning support
Nov 28, 2023
96f6df5
Separate is_entry and is_exit for keras nodes functions
Nov 28, 2023
3a8ffbb
split intermediate section mask to 2 masks
Nov 28, 2023
58eeedd
fixed pruned model to be trainable
Nov 28, 2023
c0fbf11
split keras functions into multiple files
Nov 28, 2023
877cf45
Add l2norm and params count to lfh scores
Nov 29, 2023
957133c
Add check for exit nodes #IC to match the #OC of their entry node
Nov 29, 2023
6a2c099
Consider null channels in graph memory computation
Nov 29, 2023
b50dc1c
Remove debug code
Nov 29, 2023
a3b3c92
Add PruningFrameworkImplementation
Nov 29, 2023
c1a3bc1
add small sections tests
Nov 29, 2023
476fcc8
Fix tf imports when tf was not found
Nov 29, 2023
73b9971
Run pruning tests in keras workflow
Nov 29, 2023
54e6fb3
add memory calculator test
Dec 3, 2023
0e69771
add simd padding to tpc
Dec 3, 2023
0433744
Take score computation out to a new LFH importance score calculator
Dec 3, 2023
8c7f0ee
split memory count from params count in memory calculator
Dec 3, 2023
6fd3fa3
rename pruning section attributes
Dec 3, 2023
92cdaf5
move has_matching_channel_count to common
Dec 3, 2023
74e4d86
rename pruner api
Dec 3, 2023
e64047e
move node count params to common memory calc
Dec 3, 2023
8dcbf85
Refactor pruned node params count
Dec 3, 2023
a5a636e
Add comments to memory calc
Dec 3, 2023
521910a
Add comments to LFH importance metric
Dec 3, 2023
c75791a
revert constant value that was changed during debug
Dec 4, 2023
522f93a
Add comments to importance metric factory
Dec 4, 2023
ac545c9
Use ChannelGrouping object to select the groups indices of each node
Dec 10, 2023
d7dc395
Add licenses to new files
Dec 10, 2023
631bce0
Undo change in qat tutorial
Dec 10, 2023
c452884
merge with main
Dec 10, 2023
5a4d5b4
Add tests to test suite for coverage inclusion
Dec 11, 2023
280803b
revert typehint fix
Dec 11, 2023
cb4331c
fix bad import of ImportanceMetric
Dec 11, 2023
7942276
Remove unneeded comments and spaces
Dec 11, 2023
044f785
add todo for moving KPI out of MP package
Dec 11, 2023
0e1bd39
Create PerChannelMask and PerSIMDGroupMask for holding and updating t…
Dec 12, 2023
a41e77f
Remove summing of scores of channels in an simd group
Dec 12, 2023
c62f458
Add documentation to mask calculator
Dec 12, 2023
4de024a
Move keras implementation functions into PruningKerasImplementation
Dec 12, 2023
5f9367d
remove fw_info from functions in base graph
Dec 12, 2023
c04a5ad
Refactor base graph pruning functions
Dec 12, 2023
94b0d4e
Add constant FP32_BYTES_PER_PARAMETER to use when computing memory fo…
Dec 12, 2023
0f47b8d
Use enum for mask indicator values
Dec 14, 2023
51dfe4e
Tests fixes
Dec 18, 2023
94828e5
add comments and assertions to get_simd in base_node
Dec 18, 2023
0c5bbcb
Add examples in the facade
Dec 18, 2023
199cd8c
Use 2 approximations in feature tests
Dec 18, 2023
2768d7c
Remove new function that was added to hessian service
Dec 19, 2023
e660dff
Add fn to calculate the num of out channels in a node in per-channel …
Dec 19, 2023
861d7ad
change remainder check in ChannelsGrouping
Dec 19, 2023
ac382e9
Add type hints and comments to ChannelsGrouping and LFHImportanceScore
Dec 20, 2023
64a7ca5
Move scores unrolling to PrunerInfo and test it
Dec 20, 2023
613869d
Extend feature tests with more simd tests
Dec 20, 2023
228865a
Remove old unused file
Dec 20, 2023
45a14a3
Rename trace occurrences in lfh metric computation
Dec 20, 2023
badd233
add _get_kernel_node_oc_info to LFH metric
Dec 20, 2023
9d4c5c3
Add unit tests that use LFH and not only constant score
Dec 20, 2023
3c8c020
Add type hints to memory calculator
Dec 20, 2023
28319e7
Replace None with 1s mask when input mask can not be found
Dec 21, 2023
a917377
Add assertion for non-negative integer num oc in memory calc
Dec 21, 2023
1295c3c
Add comments to pruner and rename _create_pruned_graph to _prune_graph
Dec 21, 2023
81b8e45
Remove empty lines in hessian service info
Dec 21, 2023
0b531bc
Fix comment in pruning config
Dec 21, 2023
2fefe84
fix typehints
Dec 21, 2023
193671f
reformat pruning info
Dec 21, 2023
c59096f
rename pruning section method to get all section nodes
Dec 21, 2023
f17ab09
Add todo to rename feature tests
Dec 21, 2023
983a013
rename prune_graph fn in pruner
Dec 21, 2023
97ef899
organize imports of pruning facade
Dec 21, 2023
81edb8a
fix tpc comments and use properties where needed
Dec 21, 2023
e892522
use a single check for simd legal values instead of 2
Dec 24, 2023
97738c4
Use logger when asserting
Dec 24, 2023
3518e64
Use unittests assert when needed
Dec 24, 2023
fd5ebb2
use only one cr for testing pretrained models
Dec 24, 2023
04b3652
use keras kernel constant in random importance metric
Dec 24, 2023
0613e95
Add assertion in get_Attributes_info before iterating on kernel attri…
Dec 24, 2023
bf6c00f
rename get_node_attributes_with_oi_axis to attrs_oi_channels_info_for…
Dec 24, 2023
7f124e8
Merged from main
Dec 24, 2023
19100b0
remove todos
Dec 24, 2023
aff8068
use property in channels_grouping
Dec 24, 2023
dfb194e
remove todo
Dec 24, 2023
a30ab5d
add comments to mask files
Dec 24, 2023
4ffc45c
remove commented out code from pruner
Dec 24, 2023
becc499
remove pytorch init file
Dec 24, 2023
78e8d29
remove todos
Dec 24, 2023
5eab6d4
add example usage in notebooks
Dec 24, 2023
b492be6
Add notebook
Dec 25, 2023
6b305f1
fixes to notebook
Dec 25, 2023
38d61e1
add docs for pruning API
Dec 25, 2023
7513006
Rename importance metric factory
Dec 28, 2023
5cf8994
use mask indicator enum when initializing masks
Dec 28, 2023
449bd2e
replace a list of params count in LFH with a np array
Dec 28, 2023
60a24ea
Add fw_impl typehints
Dec 28, 2023
fa709b4
create pruning config separately in pruning tutorial
Dec 28, 2023
294f203
rename function name in memory calculator for specifying the input ma…
Dec 28, 2023
0c2290c
Update readme
Dec 28, 2023
1 change: 1 addition & 0 deletions .github/workflows/run_keras_tests.yml
@@ -29,6 +29,7 @@ jobs:
# CPU environment (https://github.com/tensorflow/tensorflow/issues/41718).
# For this reason, if we run them in such an environment, we need to run them first non-parallel separately.
run: |
python -m unittest discover tests/keras_tests/pruning_tests -v
python -m unittest discover tests/keras_tests/non_parallel_tests -v
for script in tests/keras_tests/exporter_tests tests/keras_tests/feature_networks_tests tests/keras_tests/graph_tests tests/keras_tests/layer_tests; do python -m unittest discover $script -v & pids+=($!); done; for pid in ${pids[@]}; do wait $pid || exit 1; done

25 changes: 24 additions & 1 deletion README.md
@@ -140,6 +140,29 @@ In the following table we present the ImageNet validation results for these mode

For more results, please refer to [quick start](https://github.com/sony/model_optimization/tree/main/tutorials/quick_start).

### Structured Pruning
MCT introduces structured, hardware-aware model pruning.
This pruning technique is designed to compress models for specific hardware architectures,
taking into account the target platform's Single Instruction, Multiple Data (SIMD) capabilities.
By pruning groups of channels (SIMD groups), our approach not only reduces model size
and complexity, but also keeps the remaining channels aligned with the SIMD architecture,
given a target KPI of weights memory footprint.


<u>_Note: Currently, pruning is supported only for Keras models._</u>
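
To make the idea of pruning whole SIMD groups concrete, here is a minimal sketch in plain Python. It is not MCT's implementation: the function name `simd_group_keep_mask`, the choice to form groups from score-sorted channels, and the use of a summed group score are all assumptions made for illustration. It only shows why the number of removed channels is always a multiple of the SIMD size.

```python
from typing import List


def simd_group_keep_mask(scores: List[float], simd: int, num_groups_to_prune: int) -> List[int]:
    """Return a 0/1 keep-mask over output channels, removing whole SIMD groups.

    Hypothetical helper for illustration only (not part of MCT).
    """
    if simd <= 0 or len(scores) % simd != 0:
        raise ValueError("this sketch assumes the channel count is a multiple of simd")
    # Rank channels from most to least important.
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    # Cut the ranking into consecutive groups of `simd` channels (an assumption).
    groups = [order[i:i + simd] for i in range(0, len(order), simd)]
    # Rank groups by their total score, lowest first, and drop the weakest ones.
    ranked = sorted(groups, key=lambda g: sum(scores[i] for i in g))
    mask = [1] * len(scores)
    for group in ranked[:num_groups_to_prune]:
        for channel in group:
            mask[channel] = 0
    return mask


scores = [0.9, 0.1, 0.8, 0.2, 0.7, 0.3, 0.6, 0.4]
mask = simd_group_keep_mask(scores, simd=4, num_groups_to_prune=1)
```

With a SIMD size of 4 and one group pruned, exactly four channels are removed, so the remaining channel count stays a multiple of the SIMD width.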

#### Results

Results for applying pruning to reduce the parameters of the following models by 50%:

| Model | Dense Model Accuracy | Pruned Model Accuracy |
|-----------------|----------------------|-----------------------|
| ResNet50 [2] | 75.1 | 72.4 |
| DenseNet121 [2] | 75.0 | 71.15 |




## Contributions
MCT aims at keeping a more up-to-date fork and welcomes contributions from anyone.

@@ -153,7 +176,7 @@ MCT aims at keeping a more up-to-date fork and welcomes contributions from anyon

[1] Habi, H.V., Peretz, R., Cohen, E., Dikstein, L., Dror, O., Diamant, I., Jennings, R.H. and Netzer, A., 2021. [HPTQ: Hardware-Friendly Post Training Quantization. arXiv preprint](https://arxiv.org/abs/2109.09113).

[2] [MobilNet](https://keras.io/api/applications/mobilenet/#mobilenet-function) from Keras applications.
[2] [Keras Applications](https://keras.io/api/applications/)

[3] [TORCHVISION.MODELS](https://pytorch.org/vision/stable/models.html)

1 change: 1 addition & 0 deletions docsrc/source/api/experimental_api_docs/index.rst
@@ -38,6 +38,7 @@ Functions
- :ref:`get_tensorflow_data_generation_config<ug-get_tensorflow_data_generation_config>`: A function to generate a DataGenerationConfig for Tensorflow data generation(experimental).
- :ref:`pytorch_data_generation_experimental<ug-pytorch_data_generation_experimental>`: A function to generate data for a Pytorch model (experimental).
- :ref:`get_pytorch_data_generation_config<ug-get_pytorch_data_generation_config>`: A function to load a DataGenerationConfig for Pytorch data generation (experimental).
- :ref:`keras_pruning_experimental<ug-keras_pruning_experimental>`: A function to apply structured pruning for Keras models (experimental).


Modules
@@ -0,0 +1,25 @@
:orphan:

.. _ug-keras_pruning_experimental:


================================================
Keras Structured Pruning
================================================

.. autofunction:: model_compression_toolkit.pruning.keras_pruning_experimental

================================================
Pruning Configuration
================================================

.. autofunction:: model_compression_toolkit.pruning.PruningConfig



================================================
Pruning Information
================================================

.. autofunction:: model_compression_toolkit.pruning.PruningInfo

1 change: 1 addition & 0 deletions docsrc/source/index.rst
@@ -57,6 +57,7 @@ Keras:
* :ref:`Mixed-precision post training quantization<ug-keras_post_training_quantization_mixed_precision>`
* :ref:`Init model for Quantization Aware Training<ug-keras_quantization_aware_training_init>` (Experimental)
* :ref:`Finalize model after Quantization Aware Training<ug-keras_quantization_aware_training_finalize>` (Experimental)
* :ref:`Structured Pruning<ug-keras_pruning_experimental>` (Experimental)

Pytorch:

1 change: 1 addition & 0 deletions model_compression_toolkit/__init__.py
@@ -25,6 +25,7 @@
from model_compression_toolkit import exporter
from model_compression_toolkit import gptq
from model_compression_toolkit import data_generation
from model_compression_toolkit import pruning
from model_compression_toolkit.trainable_infrastructure.keras.load_model import keras_load_quantized_model


4 changes: 4 additions & 0 deletions model_compression_toolkit/constants.py
@@ -29,6 +29,7 @@
MIN_THRESHOLD = (2 ** -16)
EPS = 1e-8
LUT_VALUES_BITWIDTH = 8
FP32_BYTES_PER_PARAMETER = 4.

# Quantization attributes:
OUTPUT_SCALE = 'output_scale'
@@ -127,3 +128,6 @@
HESSIAN_OUTPUT_ALPHA = 0.3
HESSIAN_NUM_ITERATIONS = 50
HESSIAN_EPS = 1e-6

# Pruning constants
PRUNING_NUM_SCORE_APPROXIMATIONS = 32
114 changes: 114 additions & 0 deletions model_compression_toolkit/core/common/graph/base_graph.py
@@ -29,6 +29,7 @@
from model_compression_toolkit.core.common.graph.base_node import BaseNode
from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector
from model_compression_toolkit.core.common.collectors.statistics_collector import scale_statistics, shift_statistics
from model_compression_toolkit.core.common.pruning.pruning_section import PruningSection
from model_compression_toolkit.core.common.user_info import UserInformation
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import \
@@ -726,3 +727,116 @@ def replace_node(self, node_to_replace: BaseNode, new_node: BaseNode):
self.replace_output_node(node_to_replace, new_node)
self.replace_input_node(node_to_replace, new_node)
self.remove_node(node_to_replace)

    def get_pruning_sections(self,
                             fw_impl: Any) -> List[PruningSection]:
        """
        Constructs pruning sections for a given computational graph.
        Each section is created starting from an entry node and includes intermediate and exit nodes.

        Args:
            fw_impl (PruningFrameworkImplementation): Implementation of specific framework methods required for pruning.

        Returns: List of PruningSection in the graph.
        """
        entry_nodes = self.get_pruning_sections_entry_nodes(fw_impl)
        return [self._create_pruning_section(entry_node, fw_impl) for entry_node in entry_nodes]

    def get_pruning_sections_entry_nodes(self, fw_impl: Any) -> List[BaseNode]:
        """
        Identifies entry nodes for pruning sections within the graph.
        Traverses the graph in a topological order, checking each node for prunability criteria.
        Returns a list of nodes that mark the beginning of a prunable section in the graph.

        Args:
            fw_impl (PruningFrameworkImplementation): Implementation of specific framework methods required for pruning.

        Returns: List of nodes that are entry nodes in the pruning sections of the graph.

        """
        prunable_nodes = []
        for n in list(topological_sort(self)):
            if fw_impl.is_node_entry_node(n) and self._is_node_topology_prunable(n, fw_impl):
                prunable_nodes.append(n)
        return prunable_nodes

    def _is_node_topology_prunable(self, entry_node: BaseNode, fw_impl: Any) -> bool:
        """
        Determines if the topology starting from a given entry node is suitable for pruning.
        Iteratively examines the graph structure, focusing on node connectivity and pruning criteria.
        Returns True if the topology is prunable, False otherwise.

        Args:
            entry_node (BaseNode): The node to start the topology check from.
            fw_impl (PruningFrameworkImplementation): Implementation of specific framework methods required for pruning.

        Returns: Whether this node is a start of a pruning section according to the graph topology or not.

        """
        next_node = entry_node

        # Continue iterating until the conditions for prunability are no longer met
        while len(self.out_edges(next_node)) == 1:
            next_node = self.out_edges(next_node)[0].sink_node

            # If next_node is an exit node and has only one incoming edge, the topology is prunable.
            if fw_impl.is_node_exit_node(next_node, entry_node, self.fw_info) and len(self.in_edges(next_node)) == 1:
                return True

            # If the next node is not an intermediate node or has more than one incoming/outgoing edge,
            # stop the check.
            if not fw_impl.is_node_intermediate_pruning_section(next_node) or len(self.in_edges(next_node)) != 1 or len(self.out_edges(next_node)) != 1:
                return False

        # If the loop exits normally, it implies that the topology is not prunable
        return False


    def _create_pruning_section(self, entry_node: BaseNode, fw_impl: Any) -> PruningSection:
        """
        Creates a PruningSection object starting from a given entry node.
        Includes logic to find intermediate and exit nodes to complete the section.
        Ensures the provided entry node is a valid starting point for pruning.

        Args:
            entry_node (BaseNode): The entry node that starts the section to create.
            fw_impl (PruningFrameworkImplementation): Implementation of specific framework methods required for pruning.

        Returns: The pruning section that starts with node entry_node.

        """
        if not fw_impl.is_node_entry_node(entry_node):
            Logger.error(f"Expected to find an entry node to create its pruning section, "
                         f"but node {entry_node} is not an entry node.")

        intermediate_nodes, exit_node = self._find_intermediate_and_exit_nodes(entry_node, fw_impl)

        if not fw_impl.is_node_exit_node(exit_node, entry_node, self.fw_info):
            Logger.error(f"Expected to find exit node when creating a pruning section, "
                         f"but node {exit_node} is not an exit node.")

        return PruningSection(entry_node=entry_node,
                              intermediate_nodes=intermediate_nodes,
                              exit_node=exit_node)

    def _find_intermediate_and_exit_nodes(self, entry_node: BaseNode, fw_impl: Any) -> Tuple[List[BaseNode], BaseNode]:
        """
        Identifies intermediate and exit nodes for a pruning section starting from an entry node.
        Iterates through connected nodes to build the complete structure of the pruning section.

        Args:
            entry_node (BaseNode): An entry node to find the intermediate and exit nodes of its section.
            fw_impl (PruningFrameworkImplementation): Implementation of specific framework methods required for pruning.

        Returns: A tuple containing a list of intermediate nodes and the exit node.

        """
        intermediate_nodes = []
        next_node = self.out_edges(entry_node)[0].sink_node
        while not fw_impl.is_node_exit_node(next_node, entry_node, self.fw_info):
            intermediate_nodes.append(next_node)
            next_node = self.out_edges(next_node)[0].sink_node

        return intermediate_nodes, next_node
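
The section discovery added to `base_graph.py` can be sketched outside MCT on a toy graph. The example below is a simplified illustration, not the library code: the graph is a pair of plain adjacency dicts, and the node roles (`ENTRY_OR_EXIT_KINDS`, `INTERMEDIATE_KINDS`) are hypothetical stand-ins for the decisions that `fw_impl` makes in the real implementation (which also checks the exit node against its entry node).

```python
from typing import Dict, List, Optional, Tuple

# Hypothetical node roles for this sketch. MCT asks fw_impl instead.
ENTRY_OR_EXIT_KINDS = {"conv", "dense"}
INTERMEDIATE_KINDS = {"bn", "relu"}

Section = Tuple[List[str], str]


def find_section(succ: Dict[str, List[str]],
                 pred: Dict[str, List[str]],
                 kind: Dict[str, str],
                 entry: str) -> Optional[Section]:
    """Follow single edges from `entry`; return (intermediates, exit) or None."""
    intermediates: List[str] = []
    node = entry
    while len(succ[node]) == 1:
        node = succ[node][0]
        single_input = len(pred[node]) == 1
        # A prunable layer with exactly one input closes the section.
        if kind[node] in ENTRY_OR_EXIT_KINDS and single_input:
            return intermediates, node
        # Anything else must be a pass-through node on a simple chain.
        if kind[node] not in INTERMEDIATE_KINDS or not single_input or len(succ[node]) != 1:
            return None
        intermediates.append(node)
    return None


kind = {"conv1": "conv", "bn1": "bn", "relu1": "relu",
        "conv2": "conv", "conv3": "conv", "add": "add"}
edges = [("conv1", "bn1"), ("bn1", "relu1"), ("relu1", "conv2"),
         ("conv2", "conv3"), ("conv2", "add"), ("conv3", "add")]
succ: Dict[str, List[str]] = {name: [] for name in kind}
pred: Dict[str, List[str]] = {name: [] for name in kind}
for src, dst in edges:
    succ[src].append(dst)
    pred[dst].append(src)

# Collect a section for every candidate entry node that has a prunable topology.
sections: Dict[str, Section] = {}
for name in kind:
    if kind[name] in ENTRY_OR_EXIT_KINDS:
        found = find_section(succ, pred, kind, name)
        if found is not None:
            sections[name] = found
```

Only `conv1` starts a section here: `conv2` fans out to two consumers, and `conv3` feeds an `add` node with two inputs, so neither can have its output channels removed without affecting other branches.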


30 changes: 25 additions & 5 deletions model_compression_toolkit/core/common/graph/base_node.py
@@ -19,7 +19,7 @@
import numpy as np

from model_compression_toolkit.constants import WEIGHTS_NBITS_ATTRIBUTE, CORRECTED_BIAS_ATTRIBUTE, \
ACTIVATION_NBITS_ATTRIBUTE
ACTIVATION_NBITS_ATTRIBUTE, FP32_BYTES_PER_PARAMETER
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationConfigOptions, \
TargetPlatformCapabilities, LayerFilterParams
@@ -222,9 +222,9 @@ def get_memory_bytes(self, fw_info) -> float:
"""
q_params, f_params = self.get_num_parameters(fw_info)
if self.final_weights_quantization_cfg is None: # float coefficients
memory = (f_params+q_params) * 4
memory = (f_params+q_params) * FP32_BYTES_PER_PARAMETER
else:
memory = (f_params*4)+ (q_params * self.final_weights_quantization_cfg.weights_n_bits / 8) # in bytes
memory = (f_params * FP32_BYTES_PER_PARAMETER) + (q_params * self.final_weights_quantization_cfg.weights_n_bits / 8) # in bytes

return memory
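
As a worked example of the memory formula in `get_memory_bytes`, the standalone sketch below reproduces the arithmetic for a single layer. The function name `weights_memory_bytes` and the layer shape are illustrative assumptions. Only the formula and the constant value come from the diff.

```python
from typing import Optional

FP32_BYTES_PER_PARAMETER = 4.0  # same value as the constant added in constants.py


def weights_memory_bytes(q_params: int, f_params: int, weights_n_bits: Optional[int] = None) -> float:
    """Weights memory in bytes; hypothetical standalone version of the node method."""
    if weights_n_bits is None:
        # No final quantization config: every parameter is stored as float32.
        return (f_params + q_params) * FP32_BYTES_PER_PARAMETER
    # Float parameters stay in float32, quantized ones take weights_n_bits each.
    return f_params * FP32_BYTES_PER_PARAMETER + q_params * weights_n_bits / 8


# Assumed layer: 3x3 convolution, 64 input channels, 128 output channels, with bias.
kernel_params = 3 * 3 * 64 * 128   # treated as quantizable parameters in this sketch
bias_params = 128                  # treated as float parameters in this sketch

float_bytes = weights_memory_bytes(kernel_params, bias_params)
int8_bytes = weights_memory_bytes(kernel_params, bias_params, weights_n_bits=8)
```

For this layer the float footprint is (73728 + 128) * 4 bytes, while 8-bit kernel quantization leaves 128 * 4 bytes for the bias plus one byte per kernel weight.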

@@ -239,7 +239,7 @@ def get_float_memory_bytes(self, fw_info) -> float:

"""
q_params, f_params = self.get_num_parameters(fw_info)
return (f_params + q_params) * 32 / 8 # in bytes
return (f_params + q_params) * FP32_BYTES_PER_PARAMETER

def get_unified_weights_candidates_dict(self):
"""
@@ -499,4 +499,24 @@ def is_match_filter_params(self, layer_filter_params: LayerFilterParams) -> bool
if not c.match(layer_config):
return False

return True
return True

    def get_simd(self) -> int:
        """
        Retrieves the SIMD size used for this node. It collects the SIMD sizes from all candidate
        configurations and returns the minimum SIMD size.

        Returns:
            int: The node's SIMD size.

        """
        simd_list = [qc.weights_quantization_cfg.simd_size for qc in self.candidates_quantization_cfg]
        if len(simd_list) > 1:
            Logger.warning(f"More than one pruning SIMD option is available."
                           f" Min SIMD is used: {min(simd_list)}")
        if len(simd_list) == 0:
            Logger.error(f"No SIMD option is available for {self}")
        _simd = min(simd_list)
        if _simd <= 0 or int(_simd) != _simd:
            Logger.error(f"SIMD is expected to be a positive integer but found: {_simd}")
        return _simd
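
The selection rule in `get_simd` (take the minimum over all candidate configurations, then require a positive integer) can be sketched as a standalone function. `select_simd` is a hypothetical name, and raising `ValueError` stands in for the `Logger.error` calls, on the assumption that those abort execution.

```python
from typing import Sequence


def select_simd(simd_sizes: Sequence[float]) -> int:
    """Pick the smallest candidate SIMD size and validate it (illustrative only)."""
    if len(simd_sizes) == 0:
        raise ValueError("No SIMD option is available")
    simd = min(simd_sizes)
    # Same validity condition as in the diff: strictly positive and integral.
    if simd <= 0 or int(simd) != simd:
        raise ValueError(f"SIMD is expected to be a positive integer but found: {simd}")
    return int(simd)


chosen = select_simd([32, 8, 16])
```

Taking the minimum is the conservative choice when several candidates exist: a channel count aligned to the smallest group size is valid for that candidate, while alignment to a larger size is not guaranteed.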
16 changes: 16 additions & 0 deletions model_compression_toolkit/core/common/pruning/__init__.py
@@ -0,0 +1,16 @@
# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

