Don't create implicit input for outer scope value if there is a subgraph input with the same name. #1186

Merged
2 changes: 1 addition & 1 deletion include/onnxruntime/core/graph/graph.h
@@ -853,7 +853,7 @@ class Graph {

   // Build and verify node connection (edges).
   // Verify NodeArg name/type/shape matching correctly.
-  common::Status BuildConnections(std::vector<std::string>& outer_scope_node_args_consumed);
+  common::Status BuildConnections(std::unordered_set<std::string>& outer_scope_node_args_consumed);
 
   common::Status VerifyNoDuplicateName();

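The switch from std::vector to std::unordered_set for the consumed-names container presumably also de-duplicates: when several nodes in a subgraph consume the same outer-scope value, a vector records the name once per consumer, while a set records it once in total. A minimal standalone sketch of the difference (illustrative names only, not PR code):

#include <cassert>
#include <string>
#include <unordered_set>
#include <vector>

int main() {
  // Two consumers in a subgraph both reference the outer-scope value "b".
  std::vector<std::string> consumed_vec;         // old container
  std::unordered_set<std::string> consumed_set;  // new container
  for (const char* name : {"b", "b"}) {
    consumed_vec.push_back(name);  // records "b" twice
    consumed_set.insert(name);     // second insert is a no-op
  }
  assert(consumed_vec.size() == 2);
  assert(consumed_set.size() == 1);  // each consumed name is propagated upward once
  return 0;
}
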
50 changes: 28 additions & 22 deletions onnxruntime/core/graph/graph.cc
@@ -898,7 +898,7 @@ void Graph::RemoveEdge(NodeIndex src_node_index, NodeIndex dst_node_index, int s
 }
 
 GSL_SUPPRESS(es .84)  // ignoring return value from unordered_map::insert causes noisy complaint
-Status Graph::BuildConnections(std::vector<std::string>& outer_scope_node_args_consumed) {
+Status Graph::BuildConnections(std::unordered_set<std::string>& outer_scope_node_args_consumed) {
   const std::unordered_set<std::string>& outer_scope_node_args = resolve_context_.outer_scope_node_args;
   std::unordered_set<Node*> inner_nodes;
 
@@ -908,7 +908,7 @@ Status Graph::BuildConnections(std::vector<std::string>& outer_scope_node_args_consumed) {

   for (auto* node : resolve_context_.nodes_with_subgraphs) {
     for (auto& subgraph : node->MutableSubgraphs()) {
-      std::vector<std::string> node_args_consumed;
+      std::unordered_set<std::string> node_args_consumed;
       ORT_RETURN_IF_ERROR(subgraph->BuildConnections(node_args_consumed));
 
       for (auto& node_arg_name : node_args_consumed) {
@@ -918,7 +918,7 @@ Status Graph::BuildConnections(std::vector<std::string>& outer_scope_node_args_consumed) {
         // it's a node arg from outside this graph's scope, so add that to the list we return
         // so that we can add the dependency at the next level up. this happens if you have multiple
         // levels of subgraphs between the graph with the original NodeArg and the subgraph with implicit usage.
-        outer_scope_node_args_consumed.push_back(node_arg_name);
+        ORT_IGNORE_RETURN_VALUE(outer_scope_node_args_consumed.insert(node_arg_name));
 
         if (!parent_graph_) {
           return ORT_MAKE_STATUS(
@@ -987,25 +987,31 @@ Status Graph::BuildConnections(std::vector<std::string>& outer_scope_node_args_consumed) {
           continue;
         }
 
-        auto output_arg_iter = resolve_context_.output_args.find(input_arg->Name());
-        if (resolve_context_.output_args.end() == output_arg_iter) {
-          // No such output_arg matching this input_arg.
-          // This input arg should be fed when running evaluation.
-          // See if it's present in the outer scope. If so it will be 'fed' by the execution frame
-          // providing access to the OrtValue from the outer scope. Pass the name back up so nodes can
-          // be linked correctly at that level.
-          if (outer_scope_node_args.find(input_arg->Name()) != outer_scope_node_args.cend()) {
-            outer_scope_node_args_consumed.push_back(input_arg->Name());
-          }
+        const auto& input_arg_name = input_arg->Name();
+        auto output_arg_iter = resolve_context_.output_args.find(input_arg_name);
+        if (resolve_context_.output_args.end() != output_arg_iter) {
+          // The input to this node is an output from a previous node in this graph.
+          // Create relationship between this node (node), and the node providing the output (output_node).
+          Node& output_node = *output_arg_iter->second.first;
+          AddEdge(output_node.Index(), node.Index(), output_arg_iter->second.second, input_slot_index);
 
-          continue;
+          inner_nodes.insert(&output_node);
+        } else {
+          // the value is either an input, an initializer, or coming from outer scope. we only need to take action
+          // if coming from outer scope, so first check if this is a subgraph (otherwise there is no outer scope).
+          if (parent_graph_ != nullptr) {
+            // make sure it's not an input or initializer first as those override any outer scope values
+            if (resolve_context_.inputs_and_initializers.find(input_arg_name) ==
 
[Inline review comment from a Contributor on the line above]
Minor nit: Conceptually, this belongs together with the search in line 991 (search the locally-defined names, which is the union of output_args and inputs and initializers). While it has the same effect, it might be a bit easier to understand if we move this check up there (basically swapping the two checks in 999 and 1002 around).
 
+                resolve_context_.inputs_and_initializers.cend()) {
+              // If it is present in the outer scope it will be 'fed' by the execution frame
+              // providing access to the OrtValue from the outer scope. Pass the name back up so nodes can
+              // be linked correctly at that level.
+              if (outer_scope_node_args.find(input_arg_name) != outer_scope_node_args.cend()) {
+                ORT_IGNORE_RETURN_VALUE(outer_scope_node_args_consumed.insert(input_arg_name));
+              }
+            }
+          }
+        }
       }
-
-      // Create relationship between this node (node), and the node providing the output (output_node).
-      Node& output_node = *output_arg_iter->second.first;
-      AddEdge(output_node.Index(), node.Index(), output_arg_iter->second.second, input_slot_index);
-
-      inner_nodes.insert(&output_node);
     }
   } else if (node.OutputDefs().empty()) {
     // This is a useless node.
@@ -1015,7 +1021,7 @@ Status Graph::BuildConnections(std::vector<std::string>& outer_scope_node_args_consumed) {
   }
 
   return Status::OK();
-}
+}  // namespace onnxruntime
 
 void Graph::ReverseDFSFrom(const std::vector<NodeIndex>& from,
                            const std::function<void(const Node*)>& enter,
@@ -1869,7 +1875,7 @@ Status Graph::Resolve(bool no_proto_sync_required) {
   // recursively set the outer scope node args.
   ORT_RETURN_IF_ERROR(SetOuterScopeNodeArgs(resolve_context_.outer_scope_node_args));
 
-  std::vector<std::string> outer_scope_node_args_consumed;
+  std::unordered_set<std::string> outer_scope_node_args_consumed;
 
   // recursively build connections between nodes in this graph and all subgraphs
   ORT_RETURN_IF_ERROR(BuildConnections(outer_scope_node_args_consumed));
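Taken together, the BuildConnections changes make input-name resolution follow a strict precedence: an input is matched first against outputs of nodes in the same graph, then against the graph's own inputs and initializers, and only then against outer-scope values, so a subgraph input that shares a name with an outer-scope value now shadows it instead of triggering an implicit input. A simplified, hypothetical sketch of that precedence (ResolveInput and ValueSource are illustrative names, not onnxruntime APIs):

#include <string>
#include <unordered_set>

enum class ValueSource { kLocalNodeOutput, kGraphInputOrInitializer, kOuterScope, kFedAtRuntime };

// Mirrors the post-fix lookup order in Graph::BuildConnections.
ValueSource ResolveInput(const std::string& name,
                         const std::unordered_set<std::string>& local_node_outputs,
                         const std::unordered_set<std::string>& inputs_and_initializers,
                         const std::unordered_set<std::string>& outer_scope_values,
                         bool is_subgraph,
                         std::unordered_set<std::string>& outer_scope_consumed) {
  if (local_node_outputs.count(name) != 0) {
    return ValueSource::kLocalNodeOutput;  // add an edge from the producing node
  }
  // A graph input or initializer of the same name shadows any outer-scope value,
  // so it must win before the outer-scope lookup is attempted.
  if (inputs_and_initializers.count(name) != 0) {
    return ValueSource::kGraphInputOrInitializer;
  }
  if (is_subgraph && outer_scope_values.count(name) != 0) {
    outer_scope_consumed.insert(name);  // report upward so the parent graph can link it
    return ValueSource::kOuterScope;
  }
  return ValueSource::kFedAtRuntime;  // must be provided when the graph is evaluated
}
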
19 changes: 18 additions & 1 deletion onnxruntime/test/framework/test_utils.h
@@ -10,6 +10,8 @@
#include "core/providers/cpu/cpu_execution_provider.h"
#include "core/framework/ml_value.h"

#include "gsl/gsl_algorithm"

#ifdef USE_CUDA
#include "core/providers/cuda/cuda_execution_provider.h"
#endif
@@ -48,6 +50,20 @@ IExecutionProvider* TestOpenVINOExecutionProvider();
 IExecutionProvider* TestNnapiExecutionProvider();
 #endif
 
+template <typename T>
+inline void CopyVectorToTensor(const std::vector<T>& value, Tensor& tensor) {
+  gsl::copy(gsl::make_span(value), tensor.MutableDataAsSpan<T>());
+}
+
+// vector<bool> is specialized so we need to handle it separately
+template <>
+inline void CopyVectorToTensor<bool>(const std::vector<bool>& value, Tensor& tensor) {
+  auto output_span = tensor.MutableDataAsSpan<bool>();
+  for (size_t i = 0, end = value.size(); i < end; ++i) {
+    output_span[i] = value[i];
+  }
+}
+
 template <typename T>
 void CreateMLValue(AllocatorPtr alloc, const std::vector<int64_t>& dims, const std::vector<T>& value,
                    OrtValue* p_mlvalue) {
@@ -57,8 +73,9 @@ void CreateMLValue(AllocatorPtr alloc, const std::vector<int64_t>& dims, const std::vector<T>& value,
                                                                  shape,
                                                                  alloc);
   if (value.size() > 0) {
-    memcpy(p_tensor->MutableData<T>(), &value[0], element_type->Size() * shape.Size());
+    CopyVectorToTensor(value, *p_tensor);
   }
+
   p_mlvalue->Init(p_tensor.release(),
                   DataTypeImpl::GetType<Tensor>(),
                   DataTypeImpl::GetType<Tensor>()->GetDeleteFunc());
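The memcpy that CopyVectorToTensor replaces cannot work for std::vector<bool>, which is a bit-packed specialization with no contiguous array of bool behind it (its operator[] returns a proxy object, so &value[0] does not even yield a bool*); that is presumably why the helper gets a gsl::copy path for ordinary element types and an element-wise specialization for bool. A small standalone illustration of the pitfall:

#include <cstddef>
#include <vector>

int main() {
  std::vector<bool> bits = {true, false, true};
  // std::vector<bool> stores packed bits: there is no bool array to memcpy from,
  // and bits[0] yields a proxy object rather than a bool&. Copy element by element.
  bool unpacked[3];
  for (std::size_t i = 0; i < bits.size(); ++i) {
    unpacked[i] = bits[i];  // each read goes through the proxy and yields a real bool
  }
  return (unpacked[0] && !unpacked[1] && unpacked[2]) ? 0 : 1;
}
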
57 changes: 57 additions & 0 deletions onnxruntime/test/providers/cpu/controlflow/loop_test.cc
@@ -11,6 +11,7 @@

#include "test/providers/provider_test_utils.h"
#include "test/util/include/default_providers.h"
#include "test/framework/test_utils.h"

using namespace ONNX_NAMESPACE;

@@ -573,6 +574,62 @@ TEST(Loop, InfiniteLoopTermination) {
   terminator_thread.join();
 }
 
+// Regression test that a subgraph input overrides an outer scope value of the same name.
+// Replicate issue from https://github.com/onnx/onnx/issues/2082
+TEST(Loop, SubgraphInputShadowsOuterScopeValue) {
+  SessionOptions so;
+  so.session_logid = "SubgraphInputShadowsOuterScopeValue";
+
+  InferenceSession session_object{so, &DefaultLoggingManager()};
+  Status st;
+  ASSERT_TRUE((st = session_object.Load("testdata/subgraph_input_shadows_outer_scope_value.onnx")).IsOK()) << st;
+  ASSERT_TRUE((st = session_object.Initialize()).IsOK()) << st;
+
+  // prepare inputs
+  std::vector<int64_t> scalar = {1};
+  std::vector<float> a = {3.f}, b = {6.f};
+  std::vector<int64_t> trip_count = {10};
+  std::vector<bool> keep_going = {true};
+
+  NameMLValMap feeds;
+  OrtValue ml_value;
+
+  CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), scalar, a, &ml_value);
+  feeds.insert(std::make_pair("a", ml_value));
+  CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), scalar, b, &ml_value);
+  feeds.insert(std::make_pair("b", ml_value));
+  CreateMLValue<int64_t>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), scalar, trip_count, &ml_value);
+  feeds.insert(std::make_pair("max_trip_count", ml_value));
+  CreateMLValue<bool>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), scalar, keep_going, &ml_value);
+  feeds.insert(std::make_pair("keep_going_inp", ml_value));
+
+  // prepare outputs
+  std::vector<std::string> output_names{"b", "user_defined_vals"};
+  std::vector<OrtValue> fetches;
+
+  // Now run
+  onnxruntime::RunOptions run_options;
+  st = session_object.Run(run_options, feeds, output_names, &fetches);
+  ASSERT_TRUE(st.IsOK()) << st;
+  ASSERT_EQ(2, fetches.size());
+
+  // prepare expected outputs
+  float expected_value_b = 6.f;
+  std::vector<int64_t> expected_dims_user_defined_vals = {2, 1};
+  std::vector<float> expected_user_defined_vals = {-6.f, 12.f};
+
+  auto& b_out = fetches[0].Get<Tensor>();
+  TensorShape expected_shape(scalar);
+  ASSERT_EQ(expected_shape, b_out.Shape());
+  ASSERT_EQ(b_out.DataAsSpan<float>()[0], expected_value_b);
+
+  auto user_defined_vals_out = fetches[1].Get<Tensor>().DataAsSpan<float>();
+  ASSERT_EQ(expected_user_defined_vals.size(), static_cast<size_t>(user_defined_vals_out.size()));
+  for (size_t i = 0, end = expected_user_defined_vals.size(); i < end; ++i) {
+    ASSERT_THAT(user_defined_vals_out[i], testing::FloatEq(expected_user_defined_vals[i]));
+  }
+}
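
For reference, the expected outputs follow from hand-simulating the model's loop body, in which the body input b (the loop-carried value) shadows the outer-scope b fed as 6.f. The trace below is an illustrative standalone reconstruction of the model's arithmetic, not onnxruntime code:

#include <cstdio>
#include <vector>

int main() {
  const float a = 3.f;  // outer-scope value consumed implicitly by the loop body
  float b = 6.f;        // loop-carried state; the body input 'b' shadows outer 'b'
  bool keep_going = true;
  std::vector<float> user_defined_vals;  // scan output, one value per iteration

  for (int i = 0; i < 10 && keep_going; ++i) {
    float my_local = a + b;                        // Add(a, b)
    float b_loop = a - b;                          // Sub(a, b)
    keep_going = my_local > b_loop;                // Greater -> loop condition
    user_defined_vals.push_back(b_loop + b_loop);  // Add(b_loop, b_loop)
    b = b_loop;                                    // carried into the next iteration
  }

  // iteration 0: my_local = 9, b_loop = -3, keep_going = true,  scan = -6
  // iteration 1: my_local = 0, b_loop =  6, keep_going = false, scan = 12
  std::printf("b = %g, user_defined_vals = {%g, %g}\n",
              b, user_defined_vals[0], user_defined_vals[1]);  // 6, {-6, 12}
  return 0;
}
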

 #ifdef USE_CUDA
 // test that when part of the subgraph runs on CUDA it executes successfully
 TEST(Loop, MixedExecutionProviders) {
73 changes: 73 additions & 0 deletions onnxruntime/test/testdata/subgraph_input_shadows_outer_scope_value.onnx (path inferred from the test above; binary file)
@@ -0,0 +1,73 @@
[Serialized ONNX protobuf; binary content omitted. Recoverable structure: an outer graph named "outer" with inputs a, b, keep_going_inp, and max_trip_count and outputs b and user_defined_vals, containing a Loop node whose body graph declares inputs iteration_num, keep_going_inp, and b — the latter two shadowing the identically named outer-scope values — and computes my_local = Add(a, b), b_loop = Sub(a, b), keep_going = Greater(my_local, b_loop), and user_defined_vals = Add(b_loop, b_loop).]