Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pytorch substitution to remove Identity layers #1043

Merged
merged 4 commits into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from torch import reshape
import torch

from model_compression_toolkit.logger import Logger
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
from model_compression_toolkit.core import common
from model_compression_toolkit.core.common.graph.base_graph import Graph
from model_compression_toolkit.core.common.graph.base_node import BaseNode
from model_compression_toolkit.core.pytorch.constants import BATCH_DIM_VALUE


class RemoveIdentity(common.BaseSubstitution):
"""
Remove `torch.nn.Identity` layers from the graph.
"""

def __init__(self):
nodes = NodeOperationMatcher(torch.nn.Identity)
super().__init__(matcher_instance=nodes)

def substitute(self,
graph: Graph,
node: BaseNode) -> Graph:
"""
The method to perform the substitution of the `torch.nn.Identity` node by
reconnecting its input directly to its output, effectively removing the node
from the graph.

Args:
graph: The current graph of operations where the node resides.
node: The specific `BaseNode` that is matched to be an Identity operation.

Returns:
Graph: The updated graph after removing the identity node.
"""

# Retrieve the predecessor nodes of the identity node.
prev_identity_nodes = graph.get_prev_nodes(node)
# Ensure there is exactly one predecessor; otherwise, do nothing.
if len(prev_identity_nodes) != 1:
return graph

# Reconnect the output edges of the identity node to its predecessor,
# effectively bypassing the identity node.
graph.reconnect_out_edges(current_node=node, new_node=prev_identity_nodes[0])
# Remove the edge from the predecessor to the identity node.
graph.remove_edge(prev_identity_nodes[0], node)
# Remove the identity node from the graph.
graph.remove_node(node_to_remove=node)

return graph

Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
FunctionalConvSubstitution
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.relu_bound_to_power_of_2 import \
ReLUBoundToPowerOfTwo
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.remove_identity import RemoveIdentity
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.reshape_with_static_shapes import \
ReshapeWithStaticShapes
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.residual_collapsing import \
Expand Down Expand Up @@ -238,7 +239,8 @@ def get_substitutions_prepare_graph(self, fw_info: FrameworkInfo = None) -> List
PermuteCallMethod(),
FunctionalConvSubstitution(fw_info),
FunctionalBatchNorm(),
FunctionalLayerNorm()]
FunctionalLayerNorm(),
RemoveIdentity()]

def get_substitutions_pre_statistics_collection(self,
quant_config: QuantizationConfig
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import copy

import unittest
import torch
from mct_quantizers import PytorchActivationQuantizationHolder, PytorchQuantizationWrapper
Expand All @@ -23,11 +25,9 @@ def __init__(self, num_channels=3, kernel_size=1):
super(BasicModel, self).__init__()
self.conv1 = Conv2d(num_channels, num_channels, kernel_size=kernel_size, bias=False)
self.conv2 = Conv2d(num_channels, num_channels, kernel_size=kernel_size, bias=False)
self.identity = torch.nn.Identity()

def forward(self, inp):
x = self.conv1(inp)
x = self.identity(x)
x = self.conv2(x)
return x

Expand All @@ -51,11 +51,9 @@ class ReuseModel(torch.nn.Module):
def __init__(self, num_channels=3, kernel_size=1):
super(ReuseModel, self).__init__()
self.conv = Conv2d(num_channels, num_channels, kernel_size=kernel_size, bias=False)
self.identity = torch.nn.Identity()

def forward(self, inp):
x = self.conv(inp)
x = self.identity(x)
x = self.conv(x)
return x

Expand All @@ -73,7 +71,7 @@ def test_adding_holder_instead_quantize_wrapper(self):
# the last module should be an activation quantization holder
self.assertTrue(isinstance(last_module, PytorchActivationQuantizationHolder))
# check that 4 activation quantization holders where generated
self.assertTrue(len(activation_quantization_holders_in_model) == 4)
self.assertTrue(len(activation_quantization_holders_in_model) == 3)
for a in activation_quantization_holders_in_model:
self.assertTrue(isinstance(a.activation_holder_quantizer, ActivationPOTInferableQuantizer))
for name, module in gptq_model.named_modules():
Expand Down Expand Up @@ -102,7 +100,7 @@ def test_adding_holders_after_reuse(self):
# the last module should be an activation quantization holder
self.assertTrue(isinstance(last_module, PytorchActivationQuantizationHolder))
# check that 4 activation quantization holders where generated
self.assertTrue(len(activation_quantization_holders_in_model) == 4)
self.assertTrue(len(activation_quantization_holders_in_model) == 3)
for a in activation_quantization_holders_in_model:
self.assertTrue(isinstance(a.activation_holder_quantizer, ActivationPOTInferableQuantizer))
for name, module in gptq_model.named_modules():
Expand All @@ -115,13 +113,16 @@ def test_adding_holders_after_reuse(self):

def _get_gptq_model(self, input_shape, in_model):
pytorch_impl = GPTQPytorchImplemantation()
qc = copy.deepcopy(mct.core.DEFAULTCONFIG)
qc.linear_collapsing = False
graph = prepare_graph_with_quantization_parameters(in_model,
pytorch_impl,
DEFAULT_PYTORCH_INFO,
representative_dataset,
generate_pytorch_tpc,
[1] + input_shape,
mixed_precision_enabled=False)
mixed_precision_enabled=False,
qc=qc)
graph = set_bit_widths(mixed_precision_enable=False,
graph=graph)
trainer = PytorchGPTQTrainer(graph,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(self):

def forward(self, x):
x1 = self.conv1(x)
x2 = self.identity(x1)
x2 = torch.relu(x1)
x3 = self.conv2(x2)
x4 = torch.relu(x3)
return x, x1, x2, x3, x4
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,12 @@ def __init__(self):
self.conv3 = Conv2d(3, 3, kernel_size=1, stride=1)
self.conv4 = Conv2d(3, 3, kernel_size=1, stride=1)
self.relu2 = ReLU()
self.identity = torch.nn.Identity()

def forward(self, inp):
x = self.conv1(inp)
x = self.relu1(x)
x = self.conv2(x)
x = self.identity(x)
x = relu(x)
x = self.conv3(x)
x = relu6(x)
x = self.conv4(x)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import torch

from tests.pytorch_tests.model_tests.base_pytorch_feature_test import BasePytorchFeatureNetworkTest


class RemoveIdentityNet(torch.nn.Module):
def __init__(self):
super(RemoveIdentityNet, self).__init__()
self.conv1 = torch.nn.Conv2d(3, 3, kernel_size=1, stride=1)
self.identity = torch.nn.Identity()
self.bn1 = torch.nn.BatchNorm2d(3)

def forward(self, x):
x = self.conv1(x)
x = self.identity(x)
x = self.bn1(x)
return x


class RemoveIdentityTest(BasePytorchFeatureNetworkTest):

def __init__(self, unit_test):
super().__init__(unit_test)

def create_networks(self):
return RemoveIdentityNet()

def compare(self,
quantized_model,
float_model,
input_x=None,
quantization_info=None):
for n,m in quantized_model.named_modules():
# make sure identity was removed and bn was folded into the conv
self.unit_test.assertFalse(isinstance(m, torch.nn.Identity) or isinstance(m, torch.nn.BatchNorm2d))

6 changes: 6 additions & 0 deletions tests/pytorch_tests/model_tests/test_feature_models_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,15 @@
from tests.pytorch_tests.model_tests.feature_models.const_representation_test import ConstRepresentationTest, \
ConstRepresentationMultiInputTest
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
from tests.pytorch_tests.model_tests.feature_models.remove_identity_test import RemoveIdentityTest


class FeatureModelsTestRunner(unittest.TestCase):
def test_remove_identity(self):
"""
This test checks that identity layers are removed from the model.
"""
RemoveIdentityTest(self).run_test()

def test_single_layer_replacement(self):
"""
Expand Down
Loading