Skip to content

Commit

Permalink
support and test scalars in const representation (#941)
Browse files Browse the repository at this point in the history
  • Loading branch information
elad-c authored Feb 5, 2024
1 parent 7a360d4 commit d46abb0
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 1 deletion.
2 changes: 2 additions & 0 deletions model_compression_toolkit/core/pytorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ def to_torch_tensor(tensor):
return (to_torch_tensor(t) for t in tensor)
elif isinstance(tensor, np.ndarray):
return torch.from_numpy(tensor.astype(np.float32)).to(working_device)
elif isinstance(tensor, (int, float)):
return torch.from_numpy(np.array(tensor).astype(np.float32)).to(working_device)
else:
raise Exception(f'Conversion of type {type(tensor)} to {type(torch.Tensor)} is not supported')

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,8 @@ def test_const_representation(self):
ConstRepresentationTest(self, func, c, input_reverse_order=True).run_test()
ConstRepresentationTest(self, func, c, input_reverse_order=True, use_kwrags=True).run_test()
ConstRepresentationTest(self, func, c, as_layer=True, use_kwrags=True).run_test()
ConstRepresentationTest(self, func, 2.45).run_test()
ConstRepresentationTest(self, func, 5.1, input_reverse_order=True).run_test()

c = (np.ones((16,)) + np.random.random((16,))).astype(np.float32).reshape((1, -1))
for func in [layers.Add(), layers.Multiply(), layers.Subtract()]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def test_conv2d_replacement(self):

def test_add_net(self):
"""
This tests check the addition and subtraction operations.
This test checks the addition and subtraction operations.
Both with different layers and with constants.
"""
AddNetTest(self).run_test()
Expand Down Expand Up @@ -226,6 +226,8 @@ def test_const_representation(self):
for func in [torch.add, torch.sub, torch.mul, torch.div]:
ConstRepresentationTest(self, func, c).run_test()
ConstRepresentationTest(self, func, c, input_reverse_order=True).run_test()
ConstRepresentationTest(self, func, 2.45).run_test()
ConstRepresentationTest(self, func, 5, input_reverse_order=True).run_test()

ConstRepresentationMultiInputTest(self).run_test()

Expand Down

0 comments on commit d46abb0

Please sign in to comment.