Skip to content

Commit

Permalink
Remove special casing of "None" as a dim_param (#1482)
Browse files Browse the repository at this point in the history
* Remove special casing of "None" as a dim_param
  • Loading branch information
skottmckay authored Jul 25, 2019
1 parent a8e3ff4 commit f052966
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 4 deletions.
3 changes: 1 addition & 2 deletions onnxruntime/core/graph/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,7 @@ static void RemoveInvalidValues(ONNX_NAMESPACE::TypeProto& type) {
for (int i = 0, end = shape->dim_size(); i < end; ++i) {
auto& dim = *shape->mutable_dim(i);
if (dim.has_dim_param()) {
auto dim_param = dim.dim_param();
if (dim_param.empty() || dim_param == "None") {
if (dim.dim_param().empty()) {
dim.clear_dim_param();
}
} else if (dim.has_dim_value()) {
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/test/python/onnxruntime_test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,12 @@ def testRunModelSymbolicInput(self):
self.assertEqual(input_name, "X")
input_shape = sess.get_inputs()[0].shape
# Input X has an unknown dimension.
self.assertEqual(input_shape, [None, 2])
self.assertEqual(input_shape, ['None', 2])
output_name = sess.get_outputs()[0].name
self.assertEqual(output_name, "Y")
output_shape = sess.get_outputs()[0].shape
# Output X has an unknown dimension.
self.assertEqual(output_shape, [None, 1])
self.assertEqual(output_shape, ['None', 1])
res = sess.run([output_name], {input_name: x})
output_expected = np.array([[5.0], [11.0], [17.0]], dtype=np.float32)
np.testing.assert_allclose(
Expand Down

0 comments on commit f052966

Please sign in to comment.