Don't create implicit input for outer scope value if there is a subgraph input with the same name. (#1186)
* If there is an outer scope value that matches a subgraph input, don't create an implicit input from the outer scope value.

  Minor unrelated change for an issue noticed while debugging: use unordered_set for the implicit inputs so we don't add them multiple times.

* Add unit test based on the ONNX issue.
skottmckay authored Aug 1, 2019
1 parent 1cf5ebc commit 9fb8867
Showing 5 changed files with 177 additions and 24 deletions.
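The fix enforces a simple name-resolution precedence for each node input: an output of another node in the same graph wins first, then the graph's own inputs and initializers, and only an otherwise-unresolved name becomes an implicit input fed from the outer scope. Below is a minimal standalone C++ sketch of that order; the function and set names are hypothetical stand-ins for the real resolve_context_ members, not the onnxruntime API.

#include <iostream>
#include <string>
#include <unordered_set>

// Sketch of the lookup order BuildConnections applies to each node input.
enum class Source { LocalNodeOutput, InputOrInitializer, OuterScope, Unresolved };

Source ResolveInput(const std::string& name,
                    const std::unordered_set<std::string>& local_node_outputs,
                    const std::unordered_set<std::string>& inputs_and_initializers,
                    const std::unordered_set<std::string>& outer_scope_values) {
  if (local_node_outputs.count(name) != 0) return Source::LocalNodeOutput;
  // A subgraph input or initializer overrides any outer scope value with the
  // same name, so no implicit input is created for it (the bug being fixed).
  if (inputs_and_initializers.count(name) != 0) return Source::InputOrInitializer;
  // Only now does the name become an implicit input from the outer scope.
  if (outer_scope_values.count(name) != 0) return Source::OuterScope;
  return Source::Unresolved;
}

int main() {
  const std::unordered_set<std::string> node_outputs{"my_local"};
  const std::unordered_set<std::string> graph_inputs{"b"};  // subgraph input 'b'
  const std::unordered_set<std::string> outer{"a", "b"};    // outer scope also has 'b'

  // 'b' resolves to the subgraph input, not to an implicit outer scope input.
  std::cout << (ResolveInput("b", node_outputs, graph_inputs, outer) ==
                Source::InputOrInitializer)
            << "\n";  // prints 1
  return 0;
}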
2 changes: 1 addition & 1 deletion include/onnxruntime/core/graph/graph.h
@@ -853,7 +853,7 @@ class Graph {
 
   // Build and verify node connection (edges).
   // Verify NodeArg name/type/shape matching correctly.
-  common::Status BuildConnections(std::vector<std::string>& outer_scope_node_args_consumed);
+  common::Status BuildConnections(std::unordered_set<std::string>& outer_scope_node_args_consumed);
 
   common::Status VerifyNoDuplicateName();
50 changes: 28 additions & 22 deletions onnxruntime/core/graph/graph.cc
@@ -898,7 +898,7 @@ void Graph::RemoveEdge(NodeIndex src_node_index, NodeIndex dst_node_index, int s
 }
 
 GSL_SUPPRESS(es .84) // ignoring return value from unordered_map::insert causes noisy complaint
-Status Graph::BuildConnections(std::vector<std::string>& outer_scope_node_args_consumed) {
+Status Graph::BuildConnections(std::unordered_set<std::string>& outer_scope_node_args_consumed) {
   const std::unordered_set<std::string>& outer_scope_node_args = resolve_context_.outer_scope_node_args;
   std::unordered_set<Node*> inner_nodes;
 
@@ -908,7 +908,7 @@ Status Graph::BuildConnections(std::vector<std::string>& outer_scope_node_args_consumed) {
 
   for (auto* node : resolve_context_.nodes_with_subgraphs) {
     for (auto& subgraph : node->MutableSubgraphs()) {
-      std::vector<std::string> node_args_consumed;
+      std::unordered_set<std::string> node_args_consumed;
       ORT_RETURN_IF_ERROR(subgraph->BuildConnections(node_args_consumed));
 
       for (auto& node_arg_name : node_args_consumed) {
@@ -918,7 +918,7 @@ Status Graph::BuildConnections(std::vector<std::string>& outer_scope_node_args_consumed) {
           // it's a node arg from outside this graph's scope, so add that to the list we return
           // so that we can add the dependency at the next level up. this happens if you have multiple
           // levels of subgraphs between the graph with the original NodeArg and the subgraph with implicit usage.
-          outer_scope_node_args_consumed.push_back(node_arg_name);
+          ORT_IGNORE_RETURN_VALUE(outer_scope_node_args_consumed.insert(node_arg_name));
 
           if (!parent_graph_) {
             return ORT_MAKE_STATUS(
@@ -987,25 +987,31 @@ Status Graph::BuildConnections(std::vector<std::string>& outer_scope_node_args_consumed) {
           continue;
         }
 
-        auto output_arg_iter = resolve_context_.output_args.find(input_arg->Name());
-        if (resolve_context_.output_args.end() == output_arg_iter) {
-          // No such output_arg matching this input_arg.
-          // This input arg should be fed when running evaluation.
-          // See if it's present in the outer scope. If so it will be 'fed' by the execution frame
-          // providing access to the OrtValue from the outer scope. Pass the name back up so nodes can
-          // be linked correctly at that level.
-          if (outer_scope_node_args.find(input_arg->Name()) != outer_scope_node_args.cend()) {
-            outer_scope_node_args_consumed.push_back(input_arg->Name());
-          }
-
-          continue;
-        }
-
-        // Create relationship between this node (node), and the node providing the output (output_node).
-        Node& output_node = *output_arg_iter->second.first;
-        AddEdge(output_node.Index(), node.Index(), output_arg_iter->second.second, input_slot_index);
-
-        inner_nodes.insert(&output_node);
+        const auto& input_arg_name = input_arg->Name();
+        auto output_arg_iter = resolve_context_.output_args.find(input_arg_name);
+        if (resolve_context_.output_args.end() != output_arg_iter) {
+          // The input to this node is an output from a previous node in this graph.
+          // Create relationship between this node (node), and the node providing the output (output_node).
+          Node& output_node = *output_arg_iter->second.first;
+          AddEdge(output_node.Index(), node.Index(), output_arg_iter->second.second, input_slot_index);
+
+          inner_nodes.insert(&output_node);
+        } else {
+          // the value is either an input, an initializer, or coming from outer scope. we only need to take action
+          // if coming from outer scope, so first check if this is a subgraph (otherwise there is no outer scope).
+          if (parent_graph_ != nullptr) {
+            // make sure it's not an input or initializer first as those override any outer scope values
+            if (resolve_context_.inputs_and_initializers.find(input_arg_name) ==
+                resolve_context_.inputs_and_initializers.cend()) {
+              // If it is present in the outer scope it will be 'fed' by the execution frame
+              // providing access to the OrtValue from the outer scope. Pass the name back up so nodes can
+              // be linked correctly at that level.
+              if (outer_scope_node_args.find(input_arg_name) != outer_scope_node_args.cend()) {
+                ORT_IGNORE_RETURN_VALUE(outer_scope_node_args_consumed.insert(input_arg_name));
+              }
+            }
+          }
+        }
       }
     } else if (node.OutputDefs().empty()) {
       // This is a useless node.
@@ -1015,7 +1021,7 @@ Status Graph::BuildConnections(std::vector<std::string>& outer_scope_node_args_consumed) {
   }
 
   return Status::OK();
 }
 
 void Graph::ReverseDFSFrom(const std::vector<NodeIndex>& from,
                            const std::function<void(const Node*)>& enter,
@@ -1869,7 +1875,7 @@ Status Graph::Resolve(bool no_proto_sync_required) {
   // recursively set the outer scope node args.
   ORT_RETURN_IF_ERROR(SetOuterScopeNodeArgs(resolve_context_.outer_scope_node_args));
 
-  std::vector<std::string> outer_scope_node_args_consumed;
+  std::unordered_set<std::string> outer_scope_node_args_consumed;
 
   // recursively build connections between nodes in this graph and all subgraphs
   ORT_RETURN_IF_ERROR(BuildConnections(outer_scope_node_args_consumed));
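One more piece of context for the subgraph loop at the top of BuildConnections: each subgraph reports the outer-scope names it consumed, and the enclosing graph either resolves them locally or inserts them into its own consumed set to hand up another level, which is why a set (no duplicates) is the right container. Here is a self-contained sketch of that propagation pattern, using a hypothetical MiniGraph type rather than the real Graph class:

#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical miniature of the recursive propagation in BuildConnections.
struct MiniGraph {
  std::unordered_set<std::string> locals;    // inputs, initializers, node outputs
  std::vector<MiniGraph> subgraphs;
  std::unordered_set<std::string> consumed;  // names this graph's nodes use
};

void CollectOuterScopeUses(MiniGraph& g, std::unordered_set<std::string>& outer_consumed) {
  for (auto& sub : g.subgraphs) {
    std::unordered_set<std::string> sub_consumed;
    CollectOuterScopeUses(sub, sub_consumed);
    for (const auto& name : sub_consumed) {
      // Not resolvable at this level either: pass it up. unordered_set means
      // a name reported by several subgraphs is only recorded once.
      if (g.locals.count(name) == 0) outer_consumed.insert(name);
    }
  }
  for (const auto& name : g.consumed) {
    if (g.locals.count(name) == 0) outer_consumed.insert(name);
  }
}

int main() {
  MiniGraph inner{{}, {}, {"x"}};        // uses 'x' it cannot resolve itself
  MiniGraph middle{{"y"}, {inner}, {}};  // does not define 'x' either
  MiniGraph top{{"x"}, {middle}, {}};    // defines 'x', so nothing escapes

  std::unordered_set<std::string> unresolved;
  CollectOuterScopeUses(top, unresolved);
  std::cout << unresolved.size() << "\n";  // prints 0
  return 0;
}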
19 changes: 18 additions & 1 deletion onnxruntime/test/framework/test_utils.h
@@ -10,6 +10,8 @@
 #include "core/providers/cpu/cpu_execution_provider.h"
 #include "core/framework/ml_value.h"
 
+#include "gsl/gsl_algorithm"
+
 #ifdef USE_CUDA
 #include "core/providers/cuda/cuda_execution_provider.h"
 #endif
@@ -48,6 +50,20 @@ IExecutionProvider* TestOpenVINOExecutionProvider();
 IExecutionProvider* TestNnapiExecutionProvider();
 #endif
 
+template <typename T>
+inline void CopyVectorToTensor(const std::vector<T>& value, Tensor& tensor) {
+  gsl::copy(gsl::make_span(value), tensor.MutableDataAsSpan<T>());
+}
+
+// vector<bool> is specialized so we need to handle it separately
+template <>
+inline void CopyVectorToTensor<bool>(const std::vector<bool>& value, Tensor& tensor) {
+  auto output_span = tensor.MutableDataAsSpan<bool>();
+  for (size_t i = 0, end = value.size(); i < end; ++i) {
+    output_span[i] = value[i];
+  }
+}
+
 template <typename T>
 void CreateMLValue(AllocatorPtr alloc, const std::vector<int64_t>& dims, const std::vector<T>& value,
                    OrtValue* p_mlvalue) {
@@ -57,8 +73,9 @@ void CreateMLValue(AllocatorPtr alloc, const std::vector<int64_t>& dims, const s
                                                               shape,
                                                               alloc);
   if (value.size() > 0) {
-    memcpy(p_tensor->MutableData<T>(), &value[0], element_type->Size() * shape.Size());
+    CopyVectorToTensor(value, *p_tensor);
   }
+
   p_mlvalue->Init(p_tensor.release(),
                   DataTypeImpl::GetType<Tensor>(),
                   DataTypeImpl::GetType<Tensor>()->GetDeleteFunc());
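A note on why the new CopyVectorToTensor needs the bool specialization: std::vector<bool> is the bit-packed specialization, so it has no contiguous bool storage and no data() member, which rules out both the old memcpy path and a gsl span copy; elements have to be copied one at a time. A standalone illustration in plain C++:

#include <cstring>
#include <iostream>
#include <vector>

int main() {
  // Contiguous storage: a raw byte copy works.
  std::vector<float> floats{1.f, 2.f, 3.f};
  float fdst[3];
  std::memcpy(fdst, floats.data(), floats.size() * sizeof(float));

  // Bit-packed storage: vector<bool> has no data() member, so a byte copy
  // does not compile; copy element by element instead, exactly as the
  // CopyVectorToTensor<bool> specialization above does.
  std::vector<bool> bools{true, false, true};
  bool bdst[3];
  for (size_t i = 0, end = bools.size(); i < end; ++i) bdst[i] = bools[i];

  std::cout << fdst[0] << " " << bdst[0] << "\n";  // prints "1 1"
  return 0;
}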
57 changes: 57 additions & 0 deletions onnxruntime/test/providers/cpu/controlflow/loop_test.cc
@@ -11,6 +11,7 @@
 
 #include "test/providers/provider_test_utils.h"
 #include "test/util/include/default_providers.h"
+#include "test/framework/test_utils.h"
 
 using namespace ONNX_NAMESPACE;
 
@@ -573,6 +574,62 @@ TEST(Loop, InfiniteLoopTermination) {
   terminator_thread.join();
 }
 
+// Regression test that a subgraph input overrides an outer scope value of the same name.
+// Replicate issue from https://github.com/onnx/onnx/issues/2082
+TEST(Loop, SubgraphInputShadowsOuterScopeValue) {
+  SessionOptions so;
+  so.session_logid = "SubgraphInputShadowsOuterScopeValue";
+
+  InferenceSession session_object{so, &DefaultLoggingManager()};
+  Status st;
+  ASSERT_TRUE((st = session_object.Load("testdata/subgraph_input_shadows_outer_scope_value.onnx")).IsOK()) << st;
+  ASSERT_TRUE((st = session_object.Initialize()).IsOK()) << st;
+
+  // prepare inputs
+  std::vector<int64_t> scalar = {1};
+  std::vector<float> a = {3.f}, b = {6.f};
+  std::vector<int64_t> trip_count = {10};
+  std::vector<bool> keep_going = {true};
+
+  NameMLValMap feeds;
+  OrtValue ml_value;
+
+  CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), scalar, a, &ml_value);
+  feeds.insert(std::make_pair("a", ml_value));
+  CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), scalar, b, &ml_value);
+  feeds.insert(std::make_pair("b", ml_value));
+  CreateMLValue<int64_t>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), scalar, trip_count, &ml_value);
+  feeds.insert(std::make_pair("max_trip_count", ml_value));
+  CreateMLValue<bool>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), scalar, keep_going, &ml_value);
+  feeds.insert(std::make_pair("keep_going_inp", ml_value));
+
+  // prepare outputs
+  std::vector<std::string> output_names{"b", "user_defined_vals"};
+  std::vector<OrtValue> fetches;
+
+  // Now run
+  onnxruntime::RunOptions run_options;
+  st = session_object.Run(run_options, feeds, output_names, &fetches);
+  ASSERT_TRUE(st.IsOK()) << st;
+  ASSERT_EQ(2, fetches.size());
+
+  // prepare expected outputs
+  float expected_value_b = 6.f;
+  std::vector<int64_t> expected_dims_user_defined_vals = {2, 1};
+  std::vector<float> expected_user_defined_vals = {-6.f, 12.f};
+
+  auto& b_out = fetches[0].Get<Tensor>();
+  TensorShape expected_shape(scalar);
+  ASSERT_EQ(expected_shape, b_out.Shape());
+  ASSERT_EQ(b_out.DataAsSpan<float>()[0], expected_value_b);
+
+  auto user_defined_vals_out = fetches[1].Get<Tensor>().DataAsSpan<float>();
+  ASSERT_EQ(expected_user_defined_vals.size(), static_cast<size_t>(user_defined_vals_out.size()));
+  for (size_t i = 0, end = expected_user_defined_vals.size(); i < end; ++i) {
+    ASSERT_THAT(user_defined_vals_out[i], testing::FloatEq(expected_user_defined_vals[i]));
+  }
+}
+
 #ifdef USE_CUDA
 // test that when part of the subgraph run on CUDA it executes successfully
 TEST(Loop, MixedExecutionProviders) {
onnxruntime/test/testdata/subgraph_input_shadows_outer_scope_value.onnx
@@ -0,0 +1,73 @@
[New binary ONNX model; the serialized protobuf does not render as text. The recoverable strings show a main graph named 'outer' with inputs a, b, keep_going_inp and max_trip_count, outputs b and user_defined_vals, and a Loop node whose 'body' subgraph (Add, Sub and Greater nodes producing my_local, b_loop, keep_going and user_defined_vals) declares inputs, notably 'b', that shadow same-named outer scope values.]
