Add structured pruning support for PyTorch models. #938
Conversation
@@ -0,0 +1,14 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
Fix package name
else:
    Logger.error("Number of out channels are not the same for all outputs of the node")
else:
    num_oc = node.output_shape[channel_axis]
In case output_mask is None and fw_info.out_channel_axis_mapping also returns None (for BN, for example), how will this work? I'm missing something; let's go over this part offline.
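The case the reviewer raises can be made explicit with a guard. This is a hedged sketch, not the actual MCT implementation: `resolve_num_out_channels`, the dict-shaped `out_channel_axis_mapping`, and the stub `node`/`fw_info` objects are illustrative assumptions based on the names visible in the diff.

```python
import numpy as np

def resolve_num_out_channels(node, fw_info, output_mask=None):
    """Sketch: resolve a node's output-channel count, failing loudly when
    no channel axis is registered (e.g. BatchNorm)."""
    if output_mask is not None:
        # A mask is available: the remaining channels are its nonzero entries.
        return int(np.sum(output_mask))
    channel_axis = fw_info.out_channel_axis_mapping.get(node.type)
    if channel_axis is None:
        # No axis mapping registered for this node type; without a mask
        # there is no way to index output_shape, so raise instead of
        # indexing with None.
        raise ValueError(
            f"No output-channel axis registered for node type {node.type}")
    return node.output_shape[channel_axis]
```

With stubs, `resolve_num_out_channels` returns the masked count when a mask exists, the shape entry when an axis is mapped, and raises otherwise.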
attributes_with_axis[attr] = (None, None)
else:
    attributes_with_axis[attr] = (-1, None)
    # attributes_with_axis[attr] = (-1, None)
Please remove commented out code
fw_info (FrameworkInfo): Framework-specific information object.

"""
pruning_en = True
I don't like the pruning_en flag. If this node should not be pruned, we shouldn't get here, so I think it's better to apply this logic in the is_node_intermediate_pruning_section method.
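The reviewer's suggestion can be sketched as folding the prunability decision into the section-membership predicate, so no flag needs to be threaded through the editing logic. All names below are illustrative, not the actual MCT API.

```python
# Node types that support structured pruning in this sketch.
PRUNABLE_TYPES = {"Conv2d", "ConvTranspose2d", "Linear"}

def is_node_intermediate_pruning_section(node) -> bool:
    # A node only counts as part of an intermediate pruning section if it
    # is actually prunable; non-prunable nodes never reach the edit code.
    return node.type in PRUNABLE_TYPES

def edit_intermediate_node(node, mask):
    # No pruning_en flag needed: callers filter with the predicate above,
    # so reaching this function already implies pruning applies to the node.
    if not is_node_intermediate_pruning_section(node):
        raise ValueError(f"Node type {node.type} should not be edited")
    node.mask = mask
```

The design point is that "should this node be pruned?" is answered once, in one place, instead of being re-checked via a boolean inside the editing routine.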
# 3. The input channel axis is irrelevant since these attributes are pruned only by
#    their output channels.
for attr in list(node.weights.keys()):
    # if the number of float parameters is 1 or less
Please expand this comment and detail the PReLU example, so it'll be clearer.
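The PReLU case the reviewer asks to document can be shown concretely: `torch.nn.PReLU` defaults to `num_parameters=1` (one slope shared by all channels), so its `weight` cannot be sliced by an output-channel mask, while `PReLU(num_parameters=C)` carries one slope per channel and can be pruned. The helper name below is illustrative.

```python
import torch

shared = torch.nn.PReLU()          # weight shape: (1,), one shared slope
per_channel = torch.nn.PReLU(16)   # weight shape: (16,), one slope per channel

def prunable_by_output_channels(tensor: torch.Tensor) -> bool:
    # Attributes with one float parameter or fewer have no per-channel
    # structure, so they are skipped during output-channel pruning.
    return tensor.numel() > 1

print(prunable_by_output_channels(shared.weight))       # False
print(prunable_by_output_channels(per_channel.weight))  # True
```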
if node.type in [torch.nn.Conv2d, torch.nn.ConvTranspose2d]:
    node.framework_attr[OUT_CHANNELS] = int(np.sum(mask))
elif node.type == torch.nn.Linear:
    node.framework_attr[OUT_FEATURES] = int(np.sum(mask))
If there are no more valid cases, raise an exception...
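The suggested fail-fast shape could look like the sketch below. The string keys stand in for the diff's `OUT_CHANNELS`/`OUT_FEATURES` constants, and the function name is an assumption.

```python
import numpy as np
import torch

def update_out_channels(node, mask):
    """Sketch: update the node's output-size attribute after pruning,
    raising for any node type without a known attribute."""
    if node.type in (torch.nn.Conv2d, torch.nn.ConvTranspose2d):
        node.framework_attr["out_channels"] = int(np.sum(mask))
    elif node.type == torch.nn.Linear:
        node.framework_attr["out_features"] = int(np.sum(mask))
    else:
        # No silent fall-through: unsupported types fail immediately.
        raise NotImplementedError(
            f"Output-channel pruning is not supported for node type {node.type}")
```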
if node.type in [torch.nn.Conv2d, torch.nn.ConvTranspose2d]:
    node.framework_attr[IN_CHANNELS] = int(np.sum(mask))
elif node.type == torch.nn.Linear:
    node.framework_attr[IN_FEATURES] = int(np.sum(mask))
Same as above
# Function to generate an infinite stream of dummy images and labels
def dummy_data_generator():
If you're not using it, remove it. But I think it's better to add a flag that indicates whether to test it, disabled by default (and at least run it once locally to verify that the expected post-pruning retraining is possible).
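A minimal sketch of the suggested flag: the retraining check stays in the test suite but is gated behind a default-off switch, so the slow check only runs when explicitly enabled (e.g. locally). The environment-variable name, `maybe_test_retraining`, and `train_fn` are all illustrative, not the actual MCT test API.

```python
import os

# Off by default; enable locally with MCT_TEST_RETRAINING=1 (assumed name).
RUN_RETRAINING_TEST = os.environ.get("MCT_TEST_RETRAINING", "0") == "1"

def maybe_test_retraining(pruned_model, data_loader, train_fn,
                          enabled=RUN_RETRAINING_TEST):
    if not enabled:
        return None  # skipped in CI
    # One short epoch is enough to verify post-pruning retraining works.
    return train_fn(pruned_model, data_loader, epochs=1)
```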
I added retraining
Please don't forget to add a tutorial and adapt docs.
This commit introduces a new function 'pytorch_pruning_experimental' to perform structured pruning on PyTorch models
Force-pushed from 7ef112d to 1cab111
"source": [
"# Structured Pruning of a Fully-Connected PyTorch Model\n",
"\n",
"Welcome to this tutorial, where we will guide you through the process of training, pruning, and retraining a fully connected neural network model using the PyTorch framework. The tutorial is organized in the following sections: \n",
Please add a google colab link for this tutorial.
I need the file to be pushed first, or else the test fails on a missing link.
"execution_count": null,
"outputs": [],
"source": [
"pruned_model_retrained = train_model(pruned_model, train_loader, test_loader, device, epochs=1)"
Don't we need more epochs here? Does it get good results when you run this notebook?
Yes, I will fix it.
@@ -0,0 +1,412 @@
{
"cells": [
{
Add a reference to this tutorial in the notebooks README.
Please do not forget to update MCT README.
Please update the notebooks readme https://github.com/sony/model_optimization/blob/main/tutorials/notebooks/README.md