diff --git a/model_compression_toolkit/core/pytorch/utils.py b/model_compression_toolkit/core/pytorch/utils.py index 33b8e8d2c..1d40d0793 100644 --- a/model_compression_toolkit/core/pytorch/utils.py +++ b/model_compression_toolkit/core/pytorch/utils.py @@ -55,6 +55,8 @@ def to_torch_tensor(tensor): return (to_torch_tensor(t) for t in tensor) elif isinstance(tensor, np.ndarray): return torch.from_numpy(tensor.astype(np.float32)).to(working_device) + elif isinstance(tensor, (int, float)): + return torch.from_numpy(np.array(tensor).astype(np.float32)).to(working_device) else: raise Exception(f'Conversion of type {type(tensor)} to {type(torch.Tensor)} is not supported') diff --git a/tests/keras_tests/feature_networks_tests/test_features_runner.py b/tests/keras_tests/feature_networks_tests/test_features_runner.py index 92b948c77..4d047063d 100644 --- a/tests/keras_tests/feature_networks_tests/test_features_runner.py +++ b/tests/keras_tests/feature_networks_tests/test_features_runner.py @@ -547,6 +547,8 @@ def test_const_representation(self): ConstRepresentationTest(self, func, c, input_reverse_order=True).run_test() ConstRepresentationTest(self, func, c, input_reverse_order=True, use_kwrags=True).run_test() ConstRepresentationTest(self, func, c, as_layer=True, use_kwrags=True).run_test() + ConstRepresentationTest(self, func, 2.45).run_test() + ConstRepresentationTest(self, func, 5.1, input_reverse_order=True).run_test() c = (np.ones((16,)) + np.random.random((16,))).astype(np.float32).reshape((1, -1)) for func in [layers.Add(), layers.Multiply(), layers.Subtract()]: diff --git a/tests/pytorch_tests/model_tests/test_feature_models_runner.py b/tests/pytorch_tests/model_tests/test_feature_models_runner.py index 7c69af66f..c4e5ac7ea 100644 --- a/tests/pytorch_tests/model_tests/test_feature_models_runner.py +++ b/tests/pytorch_tests/model_tests/test_feature_models_runner.py @@ -120,7 +120,7 @@ def test_conv2d_replacement(self): def test_add_net(self): """ - This tests check the addition and subtraction operations. + This test checks the addition and subtraction operations. Both with different layers and with constants. """ AddNetTest(self).run_test() @@ -226,6 +226,8 @@ def test_const_representation(self): for func in [torch.add, torch.sub, torch.mul, torch.div]: ConstRepresentationTest(self, func, c).run_test() ConstRepresentationTest(self, func, c, input_reverse_order=True).run_test() + ConstRepresentationTest(self, func, 2.45).run_test() + ConstRepresentationTest(self, func, 5, input_reverse_order=True).run_test() ConstRepresentationMultiInputTest(self).run_test()