
Add structured pruning support for PyTorch models. #938

Merged
merged 8 commits into main on Mar 4, 2024

Conversation

@lior-dikstein lior-dikstein commented Jan 31, 2024

This commit introduces a new function, 'pytorch_pruning_experimental', to perform structured pruning on PyTorch models.
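For readers unfamiliar with the term: structured pruning removes whole output channels (entire rows of a weight matrix) rather than individual weights, so the pruned model stays dense and genuinely smaller. A minimal NumPy sketch of the idea — the function name and shapes here are illustrative, not MCT's API:

```python
import numpy as np

def prune_linear_out_channels(weight, bias, mask):
    """Keep only the output channels where mask is True.

    weight: (out_features, in_features), bias: (out_features,),
    mask: boolean array of shape (out_features,).
    """
    mask = np.asarray(mask, dtype=bool)
    return weight[mask, :], bias[mask]

w = np.arange(12, dtype=np.float32).reshape(4, 3)  # 4 output channels
b = np.zeros(4, dtype=np.float32)
keep = np.array([True, False, True, True])         # drop channel 1
w_p, b_p = prune_linear_out_channels(w, b, keep)
print(w_p.shape)  # (3, 3)
```

Pruning a layer's output channels this way also forces the *next* layer's input channels to shrink to match, which is why the PR tracks both output- and input-channel axes per node.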

Pull Request Description:

Checklist before requesting a review:

  • I set the appropriate labels on the pull request.
  • I have added/updated the release note draft (if necessary).
  • I have updated the documentation to reflect my changes (if necessary).
  • All functions and files are well documented.
  • All functions and classes have type hints.
  • There is a license header in all files.
  • The function and variable names are informative.
  • I have checked for code duplications.
  • I have added new unit tests (if necessary).

@@ -0,0 +1,14 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.

Fix package name

    else:
        Logger.error("Number of out channels are not the same for all outputs of the node")
else:
    num_oc = node.output_shape[channel_axis]

In case output_mask is None and fw_info.out_channel_axis_mapping returns None (in BN, for example), how will it work? I'm missing something; let's go over this part offline.
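One defensive pattern for the case this comment raises — failing loudly when no channel axis can be resolved, instead of indexing the shape with None. This is an illustrative sketch, not the code path MCT actually takes:

```python
def resolve_num_out_channels(output_shape, channel_axis):
    # If the framework cannot map an output-channel axis for this op
    # (e.g., some BatchNorm configurations), raise instead of letting
    # output_shape[None] fail with an obscure TypeError downstream.
    if channel_axis is None:
        raise ValueError(
            "Could not determine the output-channel axis for this node; "
            "cannot infer the number of output channels.")
    return output_shape[channel_axis]

print(resolve_num_out_channels((1, 64, 8, 8), 1))  # 64
```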

    attributes_with_axis[attr] = (None, None)
else:
    attributes_with_axis[attr] = (-1, None)
    # attributes_with_axis[attr] = (-1, None)

Please remove the commented-out code.

fw_info (FrameworkInfo): Framework-specific information object.

"""
pruning_en = True

I don't like the pruning_en flag. If this node should not be pruned, we shouldn't get here, so I think it's better to apply this logic in the is_node_intermediate_pruning_section method.
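The suggestion above — filter non-prunable nodes before entering this code, rather than carrying a flag through it — could look roughly like this. The surrounding names are illustrative; only the idea of moving the check upstream comes from the comment:

```python
def nodes_to_prune(nodes, is_prunable):
    # Filter up front instead of threading a pruning_en flag through
    # the per-node logic: nodes that must not be pruned never reach it.
    return [n for n in nodes if is_prunable(n)]

nodes = ["conv1", "bn1", "output_head"]
prunable = nodes_to_prune(nodes, lambda n: n != "output_head")
print(prunable)  # ['conv1', 'bn1']
```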

# 3. The input channel axis is irrelevant since these attributes are pruned only by
#    their output channels.
for attr in list(node.weights.keys()):
    # if the number of float parameters is 1 or less

Please expand this comment and detail the PReLU example, so it'll be clearer.
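On the PReLU example the reviewer asks about: torch.nn.PReLU defaults to a single shared slope parameter (num_parameters=1, initialized to 0.25), so there is no per-channel axis to prune; only when num_parameters equals the channel count should the attribute follow the output-channel mask. A small pure-NumPy sketch of that distinction (function name illustrative):

```python
import numpy as np

def prune_attr_by_mask(attr, mask):
    # Attributes with a single (shared) float parameter, like the
    # default PReLU slope, are left untouched; per-channel attributes
    # are pruned along their output-channel axis.
    if attr.size <= 1:
        return attr
    return attr[np.asarray(mask, dtype=bool)]

shared_slope = np.array([0.25])               # PReLU(num_parameters=1)
per_channel = np.array([0.1, 0.2, 0.3, 0.4])  # PReLU(num_parameters=4)
mask = [True, False, True, True]
print(prune_attr_by_mask(shared_slope, mask))  # [0.25]
print(prune_attr_by_mask(per_channel, mask))   # [0.1 0.3 0.4]
```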

if node.type in [torch.nn.Conv2d, torch.nn.ConvTranspose2d]:
    node.framework_attr[OUT_CHANNELS] = int(np.sum(mask))
elif node.type == torch.nn.Linear:
    node.framework_attr[OUT_FEATURES] = int(np.sum(mask))

If there are no more valid cases, raise an exception...
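The reviewer's point, sketched as plain Python (the type names mimic the snippet above; the else branch is the requested addition, and all names are illustrative):

```python
import numpy as np

def updated_out_attr(node_type, mask):
    # Map the mask to the right framework attribute, and raise on any
    # node type this code was not written for, instead of silently
    # falling through and leaving the attribute stale.
    if node_type in ("Conv2d", "ConvTranspose2d"):
        return {"out_channels": int(np.sum(mask))}
    elif node_type == "Linear":
        return {"out_features": int(np.sum(mask))}
    else:
        raise ValueError(f"Unexpected node type for pruning: {node_type}")

print(updated_out_attr("Linear", [1, 0, 1, 1]))  # {'out_features': 3}
```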

if node.type in [torch.nn.Conv2d, torch.nn.ConvTranspose2d]:
    node.framework_attr[IN_CHANNELS] = int(np.sum(mask))
elif node.type == torch.nn.Linear:
    node.framework_attr[IN_FEATURES] = int(np.sum(mask))

Same as above



# Function to generate an infinite stream of dummy images and labels
def dummy_data_generator():

If you're not using it, then remove it. But I think it's better to add a flag that indicates whether to test it or not, which is disabled by default (and at least test it once when you run it locally to see if the expected post-pruning retraining is possible).
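A self-contained version of what that test helper might look like, with the opt-in flag the reviewer suggests. All names here are illustrative sketches, not the PR's actual test code; NumPy stands in for real image tensors:

```python
import numpy as np

RUN_RETRAIN_TEST = False  # disabled by default, as suggested

def dummy_data_generator(batch=8, shape=(3, 32, 32), num_classes=10, seed=0):
    # Infinite stream of random "images" and integer labels for a
    # quick post-pruning retraining smoke test.
    rng = np.random.default_rng(seed)
    while True:
        images = rng.normal(size=(batch, *shape)).astype(np.float32)
        labels = rng.integers(0, num_classes, size=batch)
        yield images, labels

if RUN_RETRAIN_TEST:
    images, labels = next(dummy_data_generator())
    # ...feed into the retraining loop here...
```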

@lior-dikstein (Author)

I added retraining

@reuvenperetz commented Feb 8, 2024

Please don't forget to add a tutorial and adapt the docs.

liord added 2 commits February 26, 2024 11:32
This commit introduces a new function 'pytorch_pruning_experimental' to perform structured pruning on PyTorch models
"source": [
"# Structured Pruning of a Fully-Connected PyTorch Model\n",
"\n",
"Welcome to this tutorial, where we will guide you through the process of training, pruning, and retraining a fully connected neural network model using the PyTorch framework. The tutorial is organized in the following sections: \n",

Please add a Google Colab link for this tutorial.

@lior-dikstein (Author)

I need the file to be pushed first, or else the test fails on a missing link.

"execution_count": null,
"outputs": [],
"source": [
"pruned_model_retrained = train_model(pruned_model, train_loader, test_loader, device, epochs=1)"

Don't we need more epochs here? Does it get good results when you run this notebook?

@lior-dikstein (Author)

Yes, I will fix it.

@@ -0,0 +1,412 @@
{
"cells": [
{

Add a reference to this tutorial in the notebooks README.

@reuvenperetz

Please do not forget to update MCT README.


@lior-dikstein lior-dikstein merged commit 40ee30d into main Mar 4, 2024
23 of 24 checks passed
@lior-dikstein lior-dikstein deleted the pruning_pytorch branch March 4, 2024 08:52
3 participants