Skip to content

Commit

Permalink
Add a error message when creating an empty Constant (#19674)
Browse files Browse the repository at this point in the history
Co-authored-by: Anastasia Kuporosova <[email protected]>
  • Loading branch information
siddhant-0707 and akuporos authored Sep 13, 2023
1 parent 2bf8d91 commit 4ca3d51
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 9 deletions.
2 changes: 2 additions & 0 deletions src/bindings/python/src/openvino/runtime/opset1/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,8 @@ def constant(
:param name: Optional name for output node.
:return: The Constant node initialized with provided data.
"""
if value is None or (isinstance(value, np.ndarray) and value.size == 0):
raise ValueError("Cannot create an empty Constant. Please provide valid data.")
return make_constant_node(value, dtype)


Expand Down
22 changes: 13 additions & 9 deletions src/bindings/python/tests/test_graph/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import numpy as np
import pytest
from contextlib import nullcontext as does_not_raise

import openvino.runtime.opset8 as ov
from openvino.runtime import AxisSet, Shape, Type
Expand Down Expand Up @@ -114,16 +115,19 @@ def test_broadcast():
assert node.get_output_element_type(0) == element_type


@pytest.mark.parametrize("node", [
Constant(Type.f32, Shape([3, 3]), list(range(9))),
ov.constant(np.arange(9).reshape(3, 3), Type.f32),
ov.constant(np.arange(9).reshape(3, 3), np.float32)
@pytest.mark.parametrize(("const", "args", "expectation"), [
(Constant, (Type.f32, Shape([3, 3]), list(range(9))), does_not_raise()),
(ov.constant, (np.arange(9).reshape(3, 3), Type.f32), does_not_raise()),
(ov.constant, (np.arange(9).reshape(3, 3), np.float32), does_not_raise()),
(ov.constant, [None], pytest.raises(ValueError)),
])
def test_constant(node):
assert node.get_type_name() == "Constant"
assert node.get_output_size() == 1
assert list(node.get_output_shape(0)) == [3, 3]
assert node.get_output_element_type(0) == Type.f32
def test_constant(const, args, expectation):
with expectation:
node = const(*args)
assert node.get_type_name() == "Constant"
assert node.get_output_size() == 1
assert list(node.get_output_shape(0)) == [3, 3]
assert node.get_output_element_type(0) == Type.f32


def test_concat():
Expand Down

0 comments on commit 4ca3d51

Please sign in to comment.