From e51189837bdfb1e1bc644796e4d731445820e520 Mon Sep 17 00:00:00 2001 From: edenlum Date: Thu, 30 Nov 2023 11:45:29 +0200 Subject: [PATCH 1/3] Fixing loading bug where torch.Tensor(array(-1)) would try to create a tensor of shape (-1) instead of a tensor with one entry at value -1. torch.tensor() has the expected behavior --- .../core/pytorch/back2framework/instance_builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model_compression_toolkit/core/pytorch/back2framework/instance_builder.py b/model_compression_toolkit/core/pytorch/back2framework/instance_builder.py index 376b920eb..1684a397d 100644 --- a/model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +++ b/model_compression_toolkit/core/pytorch/back2framework/instance_builder.py @@ -33,7 +33,7 @@ def node_builder(n: BaseNode) -> Module: framework_attr = copy.copy(n.framework_attr) node_instance = n.type(**framework_attr) - node_instance.load_state_dict({k: torch.Tensor(v) for k, v in n.weights.items()}, strict=False) + node_instance.load_state_dict({k: torch.tensor(v) for k, v in n.weights.items()}, strict=False) set_model(node_instance) return node_instance From 1e597be4f55f2d21c16a1387dc154a274eb4b1e3 Mon Sep 17 00:00:00 2001 From: edenlum Date: Tue, 5 Dec 2023 10:59:35 +0200 Subject: [PATCH 2/3] Adding test for the torch.Tensor bug --- .../feature_models/scalar_tensor_test.py | 44 +++++++++++++++++++ .../model_tests/test_feature_models_runner.py | 7 +++ 2 files changed, 51 insertions(+) create mode 100644 tests/pytorch_tests/model_tests/feature_models/scalar_tensor_test.py diff --git a/tests/pytorch_tests/model_tests/feature_models/scalar_tensor_test.py b/tests/pytorch_tests/model_tests/feature_models/scalar_tensor_test.py new file mode 100644 index 000000000..6a502626d --- /dev/null +++ b/tests/pytorch_tests/model_tests/feature_models/scalar_tensor_test.py @@ -0,0 +1,44 @@ +# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import torch +from tests.pytorch_tests.model_tests.base_pytorch_test import BasePytorchTest + + +class ScalarTensorNet(torch.nn.Module): + def __init__(self): + super(ScalarTensorNet, self).__init__() + self.conv1 = torch.nn.Conv2d(3, 3, kernel_size=1, stride=1) + self.scalars = [torch.tensor(i) for i in range(-5, 6)] + + def forward(self, x, y): + x = self.conv1(x) + for scalar in self.scalars: + x = x + scalar + return x + + +class ScalarTensorTest(BasePytorchTest): + """ + This test checks that we build a correct graph when the input graph contains a tensor with a single integer value. + """ + + def __init__(self, unit_test): + super().__init__(unit_test) + + def create_inputs_shape(self): + return [[self.val_batch_size, 3, 32, 32], [self.val_batch_size, 3, 32, 32]] + + def create_feature_network(self, input_shape): + return ScalarTensorNet() diff --git a/tests/pytorch_tests/model_tests/test_feature_models_runner.py b/tests/pytorch_tests/model_tests/test_feature_models_runner.py index d01b06be9..b5d100471 100644 --- a/tests/pytorch_tests/model_tests/test_feature_models_runner.py +++ b/tests/pytorch_tests/model_tests/test_feature_models_runner.py @@ -39,6 +39,7 @@ MixedPercisionActivationSearch4BitFunctional, MixedPercisionActivationMultipleInputs from tests.pytorch_tests.model_tests.feature_models.relu_bound_test import ReLUBoundToPOTNetTest, \ HardtanhBoundToPOTNetTest +from tests.pytorch_tests.model_tests.feature_models.scalar_tensor_test import ScalarTensorTest from tests.pytorch_tests.model_tests.feature_models.second_moment_correction_test import ConvSecondMomentNetTest, \ ConvTSecondMomentNetTest, MultipleInputsConvSecondMomentNetTest, ValueSecondMomentTest from tests.pytorch_tests.model_tests.feature_models.symmetric_activation_test import SymmetricActivationTest @@ -253,6 +254,12 @@ def test_scale_equalization(self): # and with zero padding. ScaleEqualizationReluFuncConvTransposeWithZeroPadNetTest(self).run_test() + def test_scalar_tensor(self): + """ + This test checks that we support scalar tensors initialized as torch.tensor(x) where x is int + """ + ScalarTensorTest(self).run_test() + def test_layer_name(self): """ This test checks that we build a correct graph and correctly reconstruct the model From 69ec6ac2197b53e8c9c878270b402785dacb6868 Mon Sep 17 00:00:00 2001 From: edenlum Date: Tue, 5 Dec 2023 11:55:43 +0200 Subject: [PATCH 3/3] Adding a test and fixing another torch.Tensor instance to torch.tensor. In general, torch.Tensor is for uninitialized tensors, because all of our tensors are with values and the datatype is inferred (not explicit) we should use torch.tensor. --- .../core/pytorch/back2framework/pytorch_model_builder.py | 2 +- .../model_tests/feature_models/scalar_tensor_test.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py b/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py index 7f0a51421..a9f1e313e 100644 --- a/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +++ b/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py @@ -233,7 +233,7 @@ def _add_modules(self): if node.type == BufferHolder: self.get_submodule(node.name). \ register_buffer(node.name, - torch.Tensor(node.get_weights_by_keys(BUFFER)).to(get_working_device())) + torch.tensor(node.get_weights_by_keys(BUFFER)).to(get_working_device())) # Add activation quantization modules if an activation holder is configured for this node if node.is_activation_quantization_enabled() and self.get_activation_quantizer_holder is not None: diff --git a/tests/pytorch_tests/model_tests/feature_models/scalar_tensor_test.py b/tests/pytorch_tests/model_tests/feature_models/scalar_tensor_test.py index 6a502626d..32e9cbab6 100644 --- a/tests/pytorch_tests/model_tests/feature_models/scalar_tensor_test.py +++ b/tests/pytorch_tests/model_tests/feature_models/scalar_tensor_test.py @@ -22,7 +22,7 @@ def __init__(self): self.conv1 = torch.nn.Conv2d(3, 3, kernel_size=1, stride=1) self.scalars = [torch.tensor(i) for i in range(-5, 6)] - def forward(self, x, y): + def forward(self, x): x = self.conv1(x) for scalar in self.scalars: x = x + scalar @@ -38,7 +38,7 @@ def __init__(self, unit_test): super().__init__(unit_test) def create_inputs_shape(self): - return [[self.val_batch_size, 3, 32, 32], [self.val_batch_size, 3, 32, 32]] + return [[self.val_batch_size, 3, 32, 32]] def create_feature_network(self, input_shape): return ScalarTensorNet()