Skip to content

Commit

Permalink
fix data generation and pruning
Browse files Browse the repository at this point in the history
  • Loading branch information
reuvenp committed Mar 17, 2024
1 parent 51334f1 commit 4bff9c5
Show file tree
Hide file tree
Showing 10 changed files with 163 additions and 32 deletions.
60 changes: 59 additions & 1 deletion docsrc/source/api/api_docs/classes/DataGenerationConfig.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,66 @@
.. _ug-DataGenerationConfig:

================================================
Data Generation Config
Data Generation Configuration
================================================

.. autoclass:: model_compression_toolkit.data_generation.DataGenerationConfig
:members:


ImageGranularity
================================================

.. autoclass:: model_compression_toolkit.data_generation.ImageGranularity
:members:


SchedulerType
================================================

.. autoclass:: model_compression_toolkit.data_generation.SchedulerType
:members:



BatchNormAlignemntLossType
================================================

.. autoclass:: model_compression_toolkit.data_generation.BatchNormAlignemntLossType
:members:


OutputLossType
================================================

.. autoclass:: model_compression_toolkit.data_generation.OutputLossType
:members:


DataInitType
================================================

.. autoclass:: model_compression_toolkit.data_generation.DataInitType
:members:


BNLayerWeightingType
================================================

.. autoclass:: model_compression_toolkit.data_generation.BNLayerWeightingType
:members:


ImagePipelineType
================================================

.. autoclass:: model_compression_toolkit.data_generation.ImagePipelineType
:members:


ImageNormalizationType
================================================

.. autoclass:: model_compression_toolkit.data_generation.ImageNormalizationType
:members:

12 changes: 12 additions & 0 deletions docsrc/source/api/api_docs/classes/MpDistanceWeighting.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
:orphan:

.. _ug-MpDistanceWeighting:


=================================
MpDistanceWeighting
=================================

.. autoclass:: model_compression_toolkit.core.MpDistanceWeighting
:members:

20 changes: 20 additions & 0 deletions docsrc/source/api/api_docs/classes/PruningConfig.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,23 @@ Pruning Configuration

.. autoclass:: model_compression_toolkit.pruning.PruningConfig
:members:


ImportanceMetric
================================================

.. autoclass:: model_compression_toolkit.pruning.ImportanceMetric
:members:


ChannelsFilteringStrategy
================================================

.. autoclass:: model_compression_toolkit.pruning.ChannelsFilteringStrategy
:members:






3 changes: 2 additions & 1 deletion docsrc/source/api/api_docs/classes/PruningInfo.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@
Pruning Information
================================================

.. autofunction:: model_compression_toolkit.pruning.PruningInfo
.. autoclass:: model_compression_toolkit.pruning.PruningInfo
:members:

1 change: 1 addition & 0 deletions docsrc/source/api/api_docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ core
- :ref:`QuantizationErrorMethod<ug-QuantizationErrorMethod>`: Select a method for quantization parameters' selection.
- :ref:`MixedPrecisionQuantizationConfig<ug-MixedPrecisionQuantizationConfig>`: Module to configure the quantization process when using mixed-precision PTQ.
- :ref:`KPI<ug-KPI>`: Module to configure resources to use when searching for a configuration for the optimized model.
- :ref:`MpDistanceWeighting<ug-MpDistanceWeighting>`: Mixed precision distance metric weighting methods.
- :ref:`network_editor<ug-network_editor>`: Module to modify the optimization process for troubleshooting.
- :ref:`FolderImageLoader<ug-FolderImageLoader>`: Class to use an images directory as a representative dataset.
- :ref:`pytorch_kpi_data<ug-pytorch_kpi_data>`: A function to compute KPI data that can be used to calculate the desired target KPI for PyTorch models.
Expand Down
2 changes: 2 additions & 0 deletions model_compression_toolkit/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,5 @@
from model_compression_toolkit.core.common.mixed_precision.mixed_precision_quantization_config import MixedPrecisionQuantizationConfig
from model_compression_toolkit.core.keras.kpi_data_facade import keras_kpi_data
from model_compression_toolkit.core.pytorch.kpi_data_facade import pytorch_kpi_data
from model_compression_toolkit.core.common.mixed_precision.distance_weighting import MpDistanceWeighting

10 changes: 8 additions & 2 deletions model_compression_toolkit/core/common/pruning/pruning_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,20 @@

class ImportanceMetric(Enum):
"""
Enum for specifying the metric used to determine the importance of channels when pruning.
Enum for specifying the metric used to determine the importance of channels when pruning:
LFH - Label-Free Hessian uses hessian info for measuring each channel's sensitivity.
"""
LFH = 0 # Score based on the Hessian matrix w.r.t. layers weights, to determine channel importance without labels.


class ChannelsFilteringStrategy(Enum):
"""
Enum for specifying the strategy used for filtering (pruning) channels.
Enum for specifying the strategy used for filtering (pruning) channels:
GREEDY - Prune the least important channel groups up to allowed resources in the KPI (for now, only weights_memory is considered).
"""
GREEDY = 0 # Greedy strategy for pruning channels based on importance metrics.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class PruningInfo:
and importance scores for each layer. This class acts as a container for accessing
pruning-related metadata.
Attributes:
Args:
pruning_masks (Dict[BaseNode, np.ndarray]): Stores the pruning masks for each layer.
A pruning mask is an array where each element indicates whether the corresponding
channel or neuron has been pruned (0) or kept (1).
Expand Down
2 changes: 2 additions & 0 deletions model_compression_toolkit/data_generation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from model_compression_toolkit.constants import FOUND_TORCH, FOUND_TF
from model_compression_toolkit.data_generation.common.data_generation_config import DataGenerationConfig
from model_compression_toolkit.data_generation.common.enums import ImageGranularity, DataInitType, SchedulerType, BNLayerWeightingType, OutputLossType, BatchNormAlignemntLossType, ImagePipelineType, ImageNormalizationType

if FOUND_TF:
from model_compression_toolkit.data_generation.keras.keras_data_generation import (
Expand Down
83 changes: 56 additions & 27 deletions model_compression_toolkit/data_generation/common/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,14 @@ def get_values(cls):

class ImageGranularity(EnumBaseClass):
"""
An enum for choosing the image dependence granularity when generating images.
0. ImageWise
1. BatchWise
2. AllImages
An enum for choosing the image dependence granularity when generating images:
ImageWise
BatchWise
AllImages
"""

ImageWise = 0
Expand All @@ -42,19 +46,27 @@ class ImageGranularity(EnumBaseClass):

class DataInitType(EnumBaseClass):
"""
An enum for choosing the image dependence granularity when generating images.
0. Gaussian
1. Diverse
An enum for choosing the image dependence granularity when generating images:
Gaussian
Diverse
"""
Gaussian = 0
Diverse = 1


class ImagePipelineType(EnumBaseClass):
"""
An enum for choosing the image pipeline type for image manipulation.
RANDOM_CROP_FLIP: Crop and flip the images.
IDENTITY: Do not apply any manipulation (identity transformation).
An enum for choosing the image pipeline type for image manipulation:
RANDOM_CROP - Crop the images.
RANDOM_CROP_FLIP - Crop and flip the images.
IDENTITY - Do not apply any manipulation (identity transformation).
"""
RANDOM_CROP = 'random_crop'
RANDOM_CROP_FLIP = 'random_crop_flip'
Expand All @@ -63,10 +75,14 @@ class ImagePipelineType(EnumBaseClass):

class ImageNormalizationType(EnumBaseClass):
"""
An enum for choosing the image normalization type.
TORCHVISION: Normalize the images using torchvision normalization.
KERAS_APPLICATIONS: Normalize the images using keras_applications imagenet normalization.
NO_NORMALIZATION: Do not apply any normalization.
An enum for choosing the image normalization type:
TORCHVISION - Normalize the images using torchvision normalization.
KERAS_APPLICATIONS - Normalize the images using keras_applications imagenet normalization.
NO_NORMALIZATION - Do not apply any normalization.
"""
TORCHVISION = 'torchvision'
KERAS_APPLICATIONS = 'keras_applications'
Expand All @@ -75,10 +91,14 @@ class ImageNormalizationType(EnumBaseClass):

class BNLayerWeightingType(EnumBaseClass):
"""
An enum for choosing the layer weighting type.
AVERAGE: Use the same weight per layer.
FIRST_LAYER_MULTIPLIER: Use a multiplier for the first layer, all other layers with the same weight.
GRAD: Use gradient-based layer weighting.
An enum for choosing the layer weighting type:
AVERAGE - Use the same weight per layer.
FIRST_LAYER_MULTIPLIER - Use a multiplier for the first layer, all other layers with the same weight.
GRAD - Use gradient-based layer weighting.
"""
AVERAGE = 'average'
FIRST_LAYER_MULTIPLIER = 'first_layer_multiplier'
Expand All @@ -87,18 +107,24 @@ class BNLayerWeightingType(EnumBaseClass):

class BatchNormAlignemntLossType(EnumBaseClass):
"""
An enum for choosing the BatchNorm alignment loss type.
L2_SQUARE: Use L2 square loss for BatchNorm alignment.
An enum for choosing the BatchNorm alignment loss type:
L2_SQUARE - Use L2 square loss for BatchNorm alignment.
"""
L2_SQUARE = 'l2_square'


class OutputLossType(EnumBaseClass):
"""
An enum for choosing the output loss type.
NONE: No output loss is applied.
MIN_MAX_DIFF: Use min-max difference as the output loss.
REGULARIZED_MIN_MAX_DIFF: Use regularized min-max difference as the output loss.
An enum for choosing the output loss type:
NONE - No output loss is applied.
MIN_MAX_DIFF - Use min-max difference as the output loss.
REGULARIZED_MIN_MAX_DIFF - Use regularized min-max difference as the output loss.
"""
NONE = 'none'
MIN_MAX_DIFF = 'min_max_diff'
Expand All @@ -107,9 +133,12 @@ class OutputLossType(EnumBaseClass):

class SchedulerType(EnumBaseClass):
"""
An enum for choosing the scheduler type for the optimizer.
REDUCE_ON_PLATEAU: Use the ReduceOnPlateau scheduler.
STEP: Use the Step scheduler.
An enum for choosing the scheduler type for the optimizer:
REDUCE_ON_PLATEAU - Use the ReduceOnPlateau scheduler.
STEP - Use the Step scheduler.
"""
REDUCE_ON_PLATEAU = 'reduce_on_plateau'
STEP = 'step'

0 comments on commit 4bff9c5

Please sign in to comment.