diff --git a/model_compression_toolkit/core/pytorch/back2framework/instance_builder.py b/model_compression_toolkit/core/pytorch/back2framework/instance_builder.py index 376b920eb..1684a397d 100644 --- a/model_compression_toolkit/core/pytorch/back2framework/instance_builder.py +++ b/model_compression_toolkit/core/pytorch/back2framework/instance_builder.py @@ -33,7 +33,7 @@ def node_builder(n: BaseNode) -> Module: framework_attr = copy.copy(n.framework_attr) node_instance = n.type(**framework_attr) - node_instance.load_state_dict({k: torch.Tensor(v) for k, v in n.weights.items()}, strict=False) + node_instance.load_state_dict({k: torch.tensor(v) for k, v in n.weights.items()}, strict=False) set_model(node_instance) return node_instance diff --git a/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py b/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py index 7f0a51421..a9f1e313e 100644 --- a/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py +++ b/model_compression_toolkit/core/pytorch/back2framework/pytorch_model_builder.py @@ -233,7 +233,7 @@ def _add_modules(self): if node.type == BufferHolder: self.get_submodule(node.name). \ register_buffer(node.name, - torch.Tensor(node.get_weights_by_keys(BUFFER)).to(get_working_device())) + torch.tensor(node.get_weights_by_keys(BUFFER)).to(get_working_device())) # Add activation quantization modules if an activation holder is configured for this node if node.is_activation_quantization_enabled() and self.get_activation_quantizer_holder is not None: diff --git a/tests/pytorch_tests/model_tests/feature_models/scalar_tensor_test.py b/tests/pytorch_tests/model_tests/feature_models/scalar_tensor_test.py new file mode 100644 index 000000000..32e9cbab6 --- /dev/null +++ b/tests/pytorch_tests/model_tests/feature_models/scalar_tensor_test.py @@ -0,0 +1,44 @@ +# Copyright 2023 Sony Semiconductor Israel, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import torch +from tests.pytorch_tests.model_tests.base_pytorch_test import BasePytorchTest + + +class ScalarTensorNet(torch.nn.Module): + def __init__(self): + super(ScalarTensorNet, self).__init__() + self.conv1 = torch.nn.Conv2d(3, 3, kernel_size=1, stride=1) + self.scalars = [torch.tensor(i) for i in range(-5, 6)] + + def forward(self, x): + x = self.conv1(x) + for scalar in self.scalars: + x = x + scalar + return x + + +class ScalarTensorTest(BasePytorchTest): + """ + This test checks that we build a correct graph when the input graph contains a tensor with a single integer value. + """ + + def __init__(self, unit_test): + super().__init__(unit_test) + + def create_inputs_shape(self): + return [[self.val_batch_size, 3, 32, 32]] + + def create_feature_network(self, input_shape): + return ScalarTensorNet() diff --git a/tests/pytorch_tests/model_tests/test_feature_models_runner.py b/tests/pytorch_tests/model_tests/test_feature_models_runner.py index d01b06be9..b5d100471 100644 --- a/tests/pytorch_tests/model_tests/test_feature_models_runner.py +++ b/tests/pytorch_tests/model_tests/test_feature_models_runner.py @@ -39,6 +39,7 @@ MixedPercisionActivationSearch4BitFunctional, MixedPercisionActivationMultipleInputs from tests.pytorch_tests.model_tests.feature_models.relu_bound_test import ReLUBoundToPOTNetTest, \ HardtanhBoundToPOTNetTest +from tests.pytorch_tests.model_tests.feature_models.scalar_tensor_test import ScalarTensorTest from tests.pytorch_tests.model_tests.feature_models.second_moment_correction_test import ConvSecondMomentNetTest, \ ConvTSecondMomentNetTest, MultipleInputsConvSecondMomentNetTest, ValueSecondMomentTest from tests.pytorch_tests.model_tests.feature_models.symmetric_activation_test import SymmetricActivationTest @@ -253,6 +254,12 @@ def test_scale_equalization(self): # and with zero padding. ScaleEqualizationReluFuncConvTransposeWithZeroPadNetTest(self).run_test() + def test_scalar_tensor(self): + """ + This test checks that we support scalar tensors initialized as torch.tensor(x) where x is int + """ + ScalarTensorTest(self).run_test() + def test_layer_name(self): """ This test checks that we build a correct graph and correctly reconstruct the model