Add pytorch substitution to remove Identity layers (#1043)
* Add pytorch substitution to remove Identity layers.
* Adapt tests that used identity layers.

---------

Co-authored-by: reuvenp <[email protected]>
reuvenperetz and reuvenp authored Apr 17, 2024
1 parent 6c61777 commit 9fa60d4
Showing 7 changed files with 135 additions and 11 deletions.
66 changes: 66 additions & 0 deletions model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/remove_identity.py
@@ -0,0 +1,66 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import torch

from model_compression_toolkit.core import common
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
from model_compression_toolkit.core.common.graph.base_graph import Graph
from model_compression_toolkit.core.common.graph.base_node import BaseNode


class RemoveIdentity(common.BaseSubstitution):
    """
    Remove `torch.nn.Identity` layers from the graph.
    """

    def __init__(self):
        nodes = NodeOperationMatcher(torch.nn.Identity)
        super().__init__(matcher_instance=nodes)

    def substitute(self,
                   graph: Graph,
                   node: BaseNode) -> Graph:
        """
        Substitute a `torch.nn.Identity` node by reconnecting its input directly
        to its output, effectively removing the node from the graph.

        Args:
            graph: The current graph of operations where the node resides.
            node: The specific `BaseNode` that was matched as an Identity operation.

        Returns:
            Graph: The updated graph after removing the identity node.
        """
        # Retrieve the predecessor nodes of the identity node.
        prev_identity_nodes = graph.get_prev_nodes(node)
        # Ensure there is exactly one predecessor; otherwise, do nothing.
        if len(prev_identity_nodes) != 1:
            return graph

        # Reconnect the output edges of the identity node to its predecessor,
        # effectively bypassing the identity node.
        graph.reconnect_out_edges(current_node=node, new_node=prev_identity_nodes[0])
        # Remove the edge from the predecessor to the identity node.
        graph.remove_edge(prev_identity_nodes[0], node)
        # Remove the identity node from the graph.
        graph.remove_node(node_to_remove=node)

        return graph
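For intuition, the substitution is just an edge rewiring: the identity node's single predecessor inherits its outgoing edges, and the node is dropped. A minimal standalone sketch of the same rewiring on a toy successor-list graph (illustrative only — not MCT's `Graph` API):

# Toy illustration of the bypass performed by RemoveIdentity.
# `edges` maps each node name to the list of its successors.
def bypass(edges, node):
    # Find the node's predecessors; bail out unless there is exactly one.
    preds = [p for p, succs in edges.items() if node in succs]
    if len(preds) != 1:
        return edges
    # The predecessor inherits the node's outgoing edges...
    edges[preds[0]] = [s for s in edges[preds[0]] if s != node] + edges[node]
    # ...and the node itself disappears from the graph.
    del edges[node]
    return edges

print(bypass({'conv': ['identity'], 'identity': ['bn'], 'bn': []}, 'identity'))
# {'conv': ['bn'], 'bn': []}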

4 changes: 3 additions & 1 deletion model_compression_toolkit/core/pytorch/pytorch_implementation.py
@@ -58,6 +58,7 @@
    FunctionalConvSubstitution
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.relu_bound_to_power_of_2 import \
    ReLUBoundToPowerOfTwo
+from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.remove_identity import RemoveIdentity
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.reshape_with_static_shapes import \
    ReshapeWithStaticShapes
from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.residual_collapsing import \
@@ -238,7 +239,8 @@ def get_substitutions_prepare_graph(self, fw_info: FrameworkInfo = None) -> List
                PermuteCallMethod(),
                FunctionalConvSubstitution(fw_info),
                FunctionalBatchNorm(),
-               FunctionalLayerNorm()]
+               FunctionalLayerNorm(),
+               RemoveIdentity()]

    def get_substitutions_pre_statistics_collection(self,
                                                    quant_config: QuantizationConfig
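With `RemoveIdentity()` registered among the graph-preparation substitutions, an `nn.Identity` in a user model is dropped before statistics collection and quantization. A hedged end-to-end sketch (PTQ entry point and representative-dataset shape assumed from MCT's public API around this release — verify against your installed version):

import torch
import model_compression_toolkit as mct

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 3, kernel_size=1),
    torch.nn.Identity(),  # removed by RemoveIdentity during graph preparation
    torch.nn.ReLU())

def representative_data_gen():
    # A couple of calibration batches is enough for this sketch.
    for _ in range(2):
        yield [torch.randn(1, 3, 8, 8)]

quantized_model, _ = mct.ptq.pytorch_post_training_quantization(model, representative_data_gen)
assert not any(isinstance(m, torch.nn.Identity) for m in quantized_model.modules())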
@@ -1,3 +1,5 @@
+import copy
+
import unittest
import torch
from mct_quantizers import PytorchActivationQuantizationHolder, PytorchQuantizationWrapper
@@ -23,11 +25,9 @@ def __init__(self, num_channels=3, kernel_size=1):
        super(BasicModel, self).__init__()
        self.conv1 = Conv2d(num_channels, num_channels, kernel_size=kernel_size, bias=False)
        self.conv2 = Conv2d(num_channels, num_channels, kernel_size=kernel_size, bias=False)
-        self.identity = torch.nn.Identity()

    def forward(self, inp):
        x = self.conv1(inp)
-        x = self.identity(x)
        x = self.conv2(x)
        return x

@@ -51,11 +51,9 @@ class ReuseModel(torch.nn.Module):
    def __init__(self, num_channels=3, kernel_size=1):
        super(ReuseModel, self).__init__()
        self.conv = Conv2d(num_channels, num_channels, kernel_size=kernel_size, bias=False)
-        self.identity = torch.nn.Identity()

    def forward(self, inp):
        x = self.conv(inp)
-        x = self.identity(x)
        x = self.conv(x)
        return x

@@ -73,7 +71,7 @@ def test_adding_holder_instead_quantize_wrapper(self):
        # the last module should be an activation quantization holder
        self.assertTrue(isinstance(last_module, PytorchActivationQuantizationHolder))
        # check that 3 activation quantization holders were generated
-        self.assertTrue(len(activation_quantization_holders_in_model) == 4)
+        self.assertTrue(len(activation_quantization_holders_in_model) == 3)
        for a in activation_quantization_holders_in_model:
            self.assertTrue(isinstance(a.activation_holder_quantizer, ActivationPOTInferableQuantizer))
        for name, module in gptq_model.named_modules():
@@ -102,7 +100,7 @@ def test_adding_holders_after_reuse(self):
        # the last module should be an activation quantization holder
        self.assertTrue(isinstance(last_module, PytorchActivationQuantizationHolder))
        # check that 3 activation quantization holders were generated
-        self.assertTrue(len(activation_quantization_holders_in_model) == 4)
+        self.assertTrue(len(activation_quantization_holders_in_model) == 3)
        for a in activation_quantization_holders_in_model:
            self.assertTrue(isinstance(a.activation_holder_quantizer, ActivationPOTInferableQuantizer))
        for name, module in gptq_model.named_modules():
@@ -115,13 +113,16 @@ def test_adding_holders_after_reuse(self):

    def _get_gptq_model(self, input_shape, in_model):
        pytorch_impl = GPTQPytorchImplemantation()
+        qc = copy.deepcopy(mct.core.DEFAULTCONFIG)
+        qc.linear_collapsing = False
        graph = prepare_graph_with_quantization_parameters(in_model,
                                                           pytorch_impl,
                                                           DEFAULT_PYTORCH_INFO,
                                                           representative_dataset,
                                                           generate_pytorch_tpc,
                                                           [1] + input_shape,
-                                                          mixed_precision_enabled=False)
+                                                          mixed_precision_enabled=False,
+                                                          qc=qc)
        graph = set_bit_widths(mixed_precision_enable=False,
                               graph=graph)
        trainer = PytorchGPTQTrainer(graph,
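The `linear_collapsing` flag matters in these tests because, once the `Identity` is gone, the two convolutions sit back to back, and MCT's linear-collapsing substitution could fuse them into a single layer, changing the module counts the assertions expect. The algebra being avoided, for two stacked 1x1 convolutions, as a standalone sketch:

import torch

conv1 = torch.nn.Conv2d(3, 3, kernel_size=1, bias=False)
conv2 = torch.nn.Conv2d(3, 3, kernel_size=1, bias=False)

# Two stacked 1x1 convs are one linear map per pixel: W = W2 @ W1.
fused = torch.nn.Conv2d(3, 3, kernel_size=1, bias=False)
with torch.no_grad():
    w1 = conv1.weight.squeeze(-1).squeeze(-1)  # (out, in)
    w2 = conv2.weight.squeeze(-1).squeeze(-1)
    fused.weight.copy_((w2 @ w1).unsqueeze(-1).unsqueeze(-1))

x = torch.randn(1, 3, 8, 8)
assert torch.allclose(conv2(conv1(x)), fused(x), atol=1e-5)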
@@ -28,7 +28,7 @@ def __init__(self):

    def forward(self, x):
        x1 = self.conv1(x)
-        x2 = self.identity(x1)
+        x2 = torch.relu(x1)
        x3 = self.conv2(x2)
        x4 = torch.relu(x3)
        return x, x1, x2, x3, x4
@@ -38,13 +38,12 @@ def __init__(self):
        self.conv3 = Conv2d(3, 3, kernel_size=1, stride=1)
        self.conv4 = Conv2d(3, 3, kernel_size=1, stride=1)
        self.relu2 = ReLU()
-        self.identity = torch.nn.Identity()

    def forward(self, inp):
        x = self.conv1(inp)
        x = self.relu1(x)
        x = self.conv2(x)
-        x = self.identity(x)
+        x = relu(x)
        x = self.conv3(x)
        x = relu6(x)
        x = self.conv4(x)
50 changes: 50 additions & 0 deletions tests/pytorch_tests/model_tests/feature_models/remove_identity_test.py
@@ -0,0 +1,50 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import torch

from tests.pytorch_tests.model_tests.base_pytorch_feature_test import BasePytorchFeatureNetworkTest


class RemoveIdentityNet(torch.nn.Module):
    def __init__(self):
        super(RemoveIdentityNet, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 3, kernel_size=1, stride=1)
        self.identity = torch.nn.Identity()
        self.bn1 = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        x = self.conv1(x)
        x = self.identity(x)
        x = self.bn1(x)
        return x


class RemoveIdentityTest(BasePytorchFeatureNetworkTest):

    def __init__(self, unit_test):
        super().__init__(unit_test)

    def create_networks(self):
        return RemoveIdentityNet()

    def compare(self,
                quantized_model,
                float_model,
                input_x=None,
                quantization_info=None):
        for n, m in quantized_model.named_modules():
            # make sure identity was removed and bn was folded into the conv
            self.unit_test.assertFalse(isinstance(m, (torch.nn.Identity, torch.nn.BatchNorm2d)))
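The batch-norm half of that assertion relies on MCT's standard conv-BN folding. As a refresher, the folding arithmetic can be verified by hand (standalone sketch using the net above, not MCT code):

import torch

def fold_bn_into_conv(conv: torch.nn.Conv2d, bn: torch.nn.BatchNorm2d) -> torch.nn.Conv2d:
    fused = torch.nn.Conv2d(conv.in_channels, conv.out_channels,
                            conv.kernel_size, conv.stride, conv.padding, bias=True)
    with torch.no_grad():
        # Per-output-channel scale: gamma / sqrt(running_var + eps).
        scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
        fused.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))
        bias = conv.bias if conv.bias is not None else torch.zeros(conv.out_channels)
        fused.bias.copy_((bias - bn.running_mean) * scale + bn.bias)
    return fused

net = RemoveIdentityNet().eval()  # eval() so bn1 uses its running statistics
x = torch.randn(1, 3, 8, 8)
assert torch.allclose(net(x), fold_bn_into_conv(net.conv1, net.bn1)(x), atol=1e-5)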

6 changes: 6 additions & 0 deletions tests/pytorch_tests/model_tests/test_feature_models_runner.py
@@ -89,9 +89,15 @@
from tests.pytorch_tests.model_tests.feature_models.const_representation_test import ConstRepresentationTest, \
    ConstRepresentationMultiInputTest
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
+from tests.pytorch_tests.model_tests.feature_models.remove_identity_test import RemoveIdentityTest


class FeatureModelsTestRunner(unittest.TestCase):
+    def test_remove_identity(self):
+        """
+        This test checks that identity layers are removed from the model.
+        """
+        RemoveIdentityTest(self).run_test()
+
    def test_single_layer_replacement(self):
        """
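To run just the new case (assuming the repository root is on `sys.path`), something like the following works with the standard unittest runner:

import unittest
from tests.pytorch_tests.model_tests.test_feature_models_runner import FeatureModelsTestRunner

# Select the single new test by name.
unittest.main(argv=['', 'FeatureModelsTestRunner.test_remove_identity'], exit=False)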

0 comments on commit 9fa60d4

Please sign in to comment.