From 78928c88d8e370393b82e3b6d7e99024cc5ff325 Mon Sep 17 00:00:00 2001
From: reuvenp
Date: Tue, 16 Apr 2024 12:34:38 +0300
Subject: [PATCH 1/3] Add pytorch substitution to remove Identity layers

---
 .../substitutions/remove_identity.py          | 62 ++++++++++++++++++++
 .../core/pytorch/pytorch_implementation.py    |  4 +-
 2 files changed, 65 insertions(+), 1 deletion(-)
 create mode 100644 model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/remove_identity.py

diff --git a/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/remove_identity.py b/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/remove_identity.py
new file mode 100644
index 000000000..c942467f2
--- /dev/null
+++ b/model_compression_toolkit/core/pytorch/graph_substitutions/substitutions/remove_identity.py
@@ -0,0 +1,62 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import torch
+
+from model_compression_toolkit.core import common
+from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
+from model_compression_toolkit.core.common.graph.base_graph import Graph
+from model_compression_toolkit.core.common.graph.base_node import BaseNode
+
+
+class RemoveIdentity(common.BaseSubstitution):
+    """
+    Remove `torch.nn.Identity` layers from the graph.
+    """
+
+    def __init__(self):
+        nodes = NodeOperationMatcher(torch.nn.Identity)
+        super().__init__(matcher_instance=nodes)
+
+    def substitute(self,
+                   graph: Graph,
+                   node: BaseNode) -> Graph:
+        """
+        Remove a `torch.nn.Identity` node by reconnecting its input directly to
+        its output, effectively dropping the node from the graph.
+
+        Args:
+            graph: The current graph of operations where the node resides.
+            node: The `BaseNode` that was matched as an Identity operation.
+
+        Returns:
+            Graph: The updated graph after removing the identity node.
+        """
+
+        # Retrieve the predecessor nodes of the identity node.
+        prev_identity_nodes = graph.get_prev_nodes(node)
+        # Ensure there is exactly one predecessor; otherwise, do nothing.
+        if len(prev_identity_nodes) != 1:
+            return graph
+
+        # Reconnect the output edges of the identity node to its predecessor,
+        # effectively bypassing the identity node.
+        graph.reconnect_out_edges(current_node=node, new_node=prev_identity_nodes[0])
+        # Remove the edge from the predecessor to the identity node.
+        graph.remove_edge(prev_identity_nodes[0], node)
+        # Remove the identity node from the graph.
+        graph.remove_node(node_to_remove=node)
+
+        return graph
+
diff --git a/model_compression_toolkit/core/pytorch/pytorch_implementation.py b/model_compression_toolkit/core/pytorch/pytorch_implementation.py
index 3c95a3fa4..bef158851 100644
--- a/model_compression_toolkit/core/pytorch/pytorch_implementation.py
+++ b/model_compression_toolkit/core/pytorch/pytorch_implementation.py
@@ -58,6 +58,7 @@
     FunctionalConvSubstitution
 from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.relu_bound_to_power_of_2 import \
     ReLUBoundToPowerOfTwo
+from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.remove_identity import RemoveIdentity
 from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.reshape_with_static_shapes import \
     ReshapeWithStaticShapes
 from model_compression_toolkit.core.pytorch.graph_substitutions.substitutions.residual_collapsing import \
@@ -238,7 +239,8 @@ def get_substitutions_prepare_graph(self, fw_info: FrameworkInfo = None) -> List
                 PermuteCallMethod(),
                 FunctionalConvSubstitution(fw_info),
                 FunctionalBatchNorm(),
-                FunctionalLayerNorm()]
+                FunctionalLayerNorm(),
+                RemoveIdentity()]
 
     def get_substitutions_pre_statistics_collection(self,
                                                     quant_config: QuantizationConfig

From af4f8e0ab6cf00cb3b320a5bfd41a560828e527d Mon Sep 17 00:00:00 2001
From: reuvenp
Date: Tue, 16 Apr 2024 12:55:57 +0300
Subject: [PATCH 2/3] Add unit test for identity removal

---
 .../feature_models/remove_identity_test.py    | 50 +++++++++++++++++++
 .../model_tests/test_feature_models_runner.py |  6 +++
 2 files changed, 56 insertions(+)
 create mode 100644 tests/pytorch_tests/model_tests/feature_models/remove_identity_test.py

diff --git a/tests/pytorch_tests/model_tests/feature_models/remove_identity_test.py b/tests/pytorch_tests/model_tests/feature_models/remove_identity_test.py
new file mode 100644
index 000000000..74b1b8d65
--- /dev/null
+++ b/tests/pytorch_tests/model_tests/feature_models/remove_identity_test.py
@@ -0,0 +1,50 @@
+# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import torch
+
+from tests.pytorch_tests.model_tests.base_pytorch_feature_test import BasePytorchFeatureNetworkTest
+
+
+class RemoveIdentityNet(torch.nn.Module):
+    def __init__(self):
+        super(RemoveIdentityNet, self).__init__()
+        self.conv1 = torch.nn.Conv2d(3, 3, kernel_size=1, stride=1)
+        self.identity = torch.nn.Identity()
+        self.bn1 = torch.nn.BatchNorm2d(3)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.identity(x)
+        x = self.bn1(x)
+        return x
+
+
+class RemoveIdentityTest(BasePytorchFeatureNetworkTest):
+
+    def __init__(self, unit_test):
+        super().__init__(unit_test)
+
+    def create_networks(self):
+        return RemoveIdentityNet()
+
+    def compare(self,
+                quantized_model,
+                float_model,
+                input_x=None,
+                quantization_info=None):
+        for _, m in quantized_model.named_modules():
+            # Make sure the identity was removed and the batch norm was folded into the conv.
+            self.unit_test.assertFalse(isinstance(m, (torch.nn.Identity, torch.nn.BatchNorm2d)))
+
diff --git a/tests/pytorch_tests/model_tests/test_feature_models_runner.py b/tests/pytorch_tests/model_tests/test_feature_models_runner.py
index 75cb6bb56..7e339a8f4 100644
--- a/tests/pytorch_tests/model_tests/test_feature_models_runner.py
+++ b/tests/pytorch_tests/model_tests/test_feature_models_runner.py
@@ -88,9 +88,15 @@
 from tests.pytorch_tests.model_tests.feature_models.const_representation_test import ConstRepresentationTest, \
     ConstRepresentationMultiInputTest
 from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
+from tests.pytorch_tests.model_tests.feature_models.remove_identity_test import RemoveIdentityTest
 
 
 class FeatureModelsTestRunner(unittest.TestCase):
+    def test_remove_identity(self):
+        """
+        This test checks that identity layers are removed from the model.
+ """ + RemoveIdentityTest(self).run_test() def test_single_layer_replacement(self): """ From b0175943ae9858a5fbc9fe7542f3cc031182bae5 Mon Sep 17 00:00:00 2001 From: reuvenp Date: Wed, 17 Apr 2024 14:39:06 +0300 Subject: [PATCH 3/3] fix tests that used identity --- .../test_activation_quantization_holder_gptq.py | 15 ++++++++------- .../feature_models/output_in_the_middle_test.py | 2 +- .../model_tests/feature_models/relu_bound_test.py | 3 +-- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/pytorch_tests/function_tests/test_activation_quantization_holder_gptq.py b/tests/pytorch_tests/function_tests/test_activation_quantization_holder_gptq.py index c1c074213..874caa5a7 100644 --- a/tests/pytorch_tests/function_tests/test_activation_quantization_holder_gptq.py +++ b/tests/pytorch_tests/function_tests/test_activation_quantization_holder_gptq.py @@ -1,3 +1,5 @@ +import copy + import unittest import torch from mct_quantizers import PytorchActivationQuantizationHolder, PytorchQuantizationWrapper @@ -23,11 +25,9 @@ def __init__(self, num_channels=3, kernel_size=1): super(BasicModel, self).__init__() self.conv1 = Conv2d(num_channels, num_channels, kernel_size=kernel_size, bias=False) self.conv2 = Conv2d(num_channels, num_channels, kernel_size=kernel_size, bias=False) - self.identity = torch.nn.Identity() def forward(self, inp): x = self.conv1(inp) - x = self.identity(x) x = self.conv2(x) return x @@ -51,11 +51,9 @@ class ReuseModel(torch.nn.Module): def __init__(self, num_channels=3, kernel_size=1): super(ReuseModel, self).__init__() self.conv = Conv2d(num_channels, num_channels, kernel_size=kernel_size, bias=False) - self.identity = torch.nn.Identity() def forward(self, inp): x = self.conv(inp) - x = self.identity(x) x = self.conv(x) return x @@ -73,7 +71,7 @@ def test_adding_holder_instead_quantize_wrapper(self): # the last module should be an activation quantization holder self.assertTrue(isinstance(last_module, PytorchActivationQuantizationHolder)) # check that 4 activation quantization holders where generated - self.assertTrue(len(activation_quantization_holders_in_model) == 4) + self.assertTrue(len(activation_quantization_holders_in_model) == 3) for a in activation_quantization_holders_in_model: self.assertTrue(isinstance(a.activation_holder_quantizer, ActivationPOTInferableQuantizer)) for name, module in gptq_model.named_modules(): @@ -102,7 +100,7 @@ def test_adding_holders_after_reuse(self): # the last module should be an activation quantization holder self.assertTrue(isinstance(last_module, PytorchActivationQuantizationHolder)) # check that 4 activation quantization holders where generated - self.assertTrue(len(activation_quantization_holders_in_model) == 4) + self.assertTrue(len(activation_quantization_holders_in_model) == 3) for a in activation_quantization_holders_in_model: self.assertTrue(isinstance(a.activation_holder_quantizer, ActivationPOTInferableQuantizer)) for name, module in gptq_model.named_modules(): @@ -115,13 +113,16 @@ def test_adding_holders_after_reuse(self): def _get_gptq_model(self, input_shape, in_model): pytorch_impl = GPTQPytorchImplemantation() + qc = copy.deepcopy(mct.core.DEFAULTCONFIG) + qc.linear_collapsing = False graph = prepare_graph_with_quantization_parameters(in_model, pytorch_impl, DEFAULT_PYTORCH_INFO, representative_dataset, generate_pytorch_tpc, [1] + input_shape, - mixed_precision_enabled=False) + mixed_precision_enabled=False, + qc=qc) graph = set_bit_widths(mixed_precision_enable=False, graph=graph) trainer = 
diff --git a/tests/pytorch_tests/model_tests/feature_models/output_in_the_middle_test.py b/tests/pytorch_tests/model_tests/feature_models/output_in_the_middle_test.py
index 8bd6ecee9..0162a513e 100644
--- a/tests/pytorch_tests/model_tests/feature_models/output_in_the_middle_test.py
+++ b/tests/pytorch_tests/model_tests/feature_models/output_in_the_middle_test.py
@@ -28,7 +28,7 @@ def __init__(self):
 
     def forward(self, x):
         x1 = self.conv1(x)
-        x2 = self.identity(x1)
+        x2 = torch.relu(x1)
         x3 = self.conv2(x2)
         x4 = torch.relu(x3)
         return x, x1, x2, x3, x4
diff --git a/tests/pytorch_tests/model_tests/feature_models/relu_bound_test.py b/tests/pytorch_tests/model_tests/feature_models/relu_bound_test.py
index 730307dba..205a5df83 100644
--- a/tests/pytorch_tests/model_tests/feature_models/relu_bound_test.py
+++ b/tests/pytorch_tests/model_tests/feature_models/relu_bound_test.py
@@ -38,13 +38,12 @@ def __init__(self):
         self.conv3 = Conv2d(3, 3, kernel_size=1, stride=1)
         self.conv4 = Conv2d(3, 3, kernel_size=1, stride=1)
         self.relu2 = ReLU()
-        self.identity = torch.nn.Identity()
 
     def forward(self, inp):
         x = self.conv1(inp)
         x = self.relu1(x)
         x = self.conv2(x)
-        x = self.identity(x)
+        x = relu(x)
         x = self.conv3(x)
         x = relu6(x)
         x = self.conv4(x)
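
For context, a minimal end-to-end sketch of the behavior this series adds. This is a hypothetical usage example, not part of the patches: it assumes the public `mct.ptq.pytorch_post_training_quantization` entry point of the installed model_compression_toolkit version and a representative-dataset generator yielding one input batch per call; adjust names to your MCT release.

    import torch
    import model_compression_toolkit as mct


    class NetWithIdentity(torch.nn.Module):
        """Toy model whose Identity layer should be removed by the new substitution."""

        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, kernel_size=1)
            self.identity = torch.nn.Identity()

        def forward(self, x):
            return self.identity(self.conv(x))


    def representative_dataset():
        # Yield one random calibration batch per call (shape is illustrative).
        yield [torch.randn(1, 3, 8, 8)]


    quantized_model, _ = mct.ptq.pytorch_post_training_quantization(
        NetWithIdentity(), representative_dataset)

    # After the RemoveIdentity substitution, no Identity modules should survive.
    assert not any(isinstance(m, torch.nn.Identity) for m in quantized_model.modules())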