Skip to content

Commit

Permalink
Merge pull request #24 from Microsoft/scmckay/Scan_SupportScalarInputs
Browse files Browse the repository at this point in the history
Support scalar inputs to the Scan subgraph
  • Loading branch information
skottmckay authored Nov 26, 2018
2 parents 84fa101 + 03d7d25 commit 194cc15
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 26 deletions.
4 changes: 2 additions & 2 deletions onnxruntime/core/framework/mlvalue_tensor_slicer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ MLValueTensorSlicer<T> MLValueTensorSlicer<T>::Create(T& mlvalue, int64_t slice_
ONNXRUNTIME_ENFORCE(mlvalue.IsAllocated(), "MLValue has not been allocated so can't be sliced.");

auto& tensor_shape{mlvalue.template Get<Tensor>().Shape()};
ONNXRUNTIME_ENFORCE(gsl::narrow_cast<int64_t>(tensor_shape.NumDimensions()) > slice_dimension,
"Insufficient dimensions to slice on ", slice_dimension, ". Shape:", tensor_shape);
ONNXRUNTIME_ENFORCE(gsl::narrow_cast<int64_t>(tensor_shape.NumDimensions()) >= slice_dimension,
"Insufficient dimensions to slice on ", slice_dimension, ". Shape:", tensor_shape);

auto dim0_size = tensor_shape[0];
ONNXRUNTIME_ENFORCE(dim0_offset < dim0_size, "Invalid dim0_offset of ", dim0_offset, ". Dimension 0 is ", dim0_size);
Expand Down
5 changes: 3 additions & 2 deletions onnxruntime/core/providers/cpu/controlflow/scan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -303,8 +303,9 @@ static const MLValue& GetSubgraphInputMLValue(const OpKernelContextInternal& con
// Validate that the subgraph input has valid shapes
Status ScanImpl::ValidateSubgraphInput(int start_input, int end_input, bool has_seq_len_dim,
const std::vector<const NodeArg*>& graph_inputs) {
// first dim is batch size. optional sequence dim. dim/s for the data
auto min_dims_required = has_seq_len_dim ? 3 : 2;
// first dim is batch size. optional sequence dim. dim/s for the data.
// if there is no dim for the data treat it as a scalar.
auto min_dims_required = has_seq_len_dim ? 2 : 1;

for (int i = start_input; i < end_input; ++i) {
auto& input_tensor = GetSubgraphInputTensor(context_, i);
Expand Down
68 changes: 46 additions & 22 deletions onnxruntime/test/providers/cpu/controlflow/scan_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ struct RunOptions {
bool include_dim_values_in_subgraph = true;
bool include_types_in_subgraph = true;
bool include_outer_scope_add = false;
bool scalar_loop_state_value = false;
bool add_bad_shape = false;
};

Expand All @@ -37,13 +38,13 @@ class ScanOpTester : public OpTester {
// add outer_scope_0 node. push the value through an extra Identity node as a Constant gets lifted into an
// initializer which results in different treatment by the allocation planner
{
TypeProto float_scalar;
float_scalar.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
auto mutable_dim = float_scalar.mutable_tensor_type()->mutable_shape()->add_dim();
TypeProto float_single_value;
float_single_value.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
auto mutable_dim = float_single_value.mutable_tensor_type()->mutable_shape()->add_dim();
mutable_dim->set_dim_value(1);

{
auto& outer_scope_constant = graph.GetOrCreateNodeArg("outer_scope_constant", &float_scalar);
auto& outer_scope_constant = graph.GetOrCreateNodeArg("outer_scope_constant", &float_single_value);
auto* constant = graph.AddNode("outer_scope_constant", "Constant", "Constant with value kOuterNodeAddValue",
{}, {&outer_scope_constant});

Expand All @@ -54,7 +55,7 @@ class ScanOpTester : public OpTester {

constant->AddAttribute("value", value_tensor);

auto& outer_scope_node_arg = graph.GetOrCreateNodeArg("outer_scope_0", &float_scalar);
auto& outer_scope_node_arg = graph.GetOrCreateNodeArg("outer_scope_0", &float_single_value);
graph.AddNode("outer_scope_id", "Identity", "Identity for outer_scope_0",
{&outer_scope_constant}, {&outer_scope_node_arg});
}
Expand All @@ -66,7 +67,7 @@ class ScanOpTester : public OpTester {
};

static void CreateSubgraph(Graph& graph, RunOptions& options, const std::string& failure_message) {
bool include_shapes = options.include_dim_values_in_subgraph;
bool include_dim_values = options.include_dim_values_in_subgraph;
bool include_types = options.include_types_in_subgraph;

std::vector<NodeArg*> inputs;
Expand Down Expand Up @@ -94,21 +95,27 @@ static void CreateSubgraph(Graph& graph, RunOptions& options, const std::string&
inputs = {};
outputs = {};

TypeProto float_scalar;
TypeProto float_input;
// inputs must have type information and a rank
float_scalar.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
auto mutable_dim = float_scalar.mutable_tensor_type()->mutable_shape()->add_dim();
if (include_shapes)
mutable_dim->set_dim_value(1);
float_input.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
auto mutable_shape = float_input.mutable_tensor_type()->mutable_shape();
if (options.scalar_loop_state_value) {
// no dims
} else {
auto mutable_dim = mutable_shape->add_dim(); // set rank
if (include_dim_values)
mutable_dim->set_dim_value(1);
}

{
auto& output_arg = graph.GetOrCreateNodeArg("constant_1", &float_scalar);
auto& output_arg = graph.GetOrCreateNodeArg("constant_1", &float_input);
outputs.push_back(&output_arg);

auto* constant = graph.AddNode("constant", "Constant", "Constant with value 1", inputs, outputs);

TensorProto value_tensor;
value_tensor.add_dims(1);
if (!options.scalar_loop_state_value)
value_tensor.add_dims(1);
value_tensor.add_float_data(1.f);
value_tensor.set_data_type(onnx::TensorProto_DataType_FLOAT);

Expand All @@ -118,7 +125,7 @@ static void CreateSubgraph(Graph& graph, RunOptions& options, const std::string&
inputs = outputs; // start with output from Constant node
outputs = {};

auto& input_arg = graph.GetOrCreateNodeArg("loop_state_in_1", &float_scalar);
auto& input_arg = graph.GetOrCreateNodeArg("loop_state_in_1", &float_input);
inputs.push_back(&input_arg);

TypeProto loop_state_output_tensor;
Expand All @@ -128,15 +135,17 @@ static void CreateSubgraph(Graph& graph, RunOptions& options, const std::string&
// it has to come from here.
bool type_and_shape_required = options.include_dim_values_in_main_graph == false;

if (include_shapes || type_and_shape_required)
loop_state_output_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
if (include_dim_values || type_and_shape_required) {
mutable_shape = loop_state_output_tensor.mutable_tensor_type()->mutable_shape();
if (!options.scalar_loop_state_value)
mutable_shape->add_dim()->set_dim_value(1);
}

TypeProto* type_proto = include_types || type_and_shape_required ? &loop_state_output_tensor : nullptr;
auto& output_arg = graph.GetOrCreateNodeArg("loop_state_out_1", type_proto);
outputs.push_back(&output_arg);

auto* add = graph.AddNode("add", "Add", "Add 1 to the loop state", inputs, outputs);
(void)add;
graph.AddNode("add", "Add", "Add 1 to the loop state", inputs, outputs);
}

// subgraph with multiple inputs and outputs to test variadic behaviour.
Expand All @@ -152,7 +161,7 @@ static void CreateSubgraph(Graph& graph, RunOptions& options, const std::string&
// inputs must have type information and rank, but dimension can have no value if we're not providing shape info.
concat_input_tensor.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
auto mutable_dim = concat_input_tensor.mutable_tensor_type()->mutable_shape()->add_dim();
if (include_shapes) {
if (include_dim_values) {
mutable_dim->set_dim_value(2);

if (options.add_bad_shape) {
Expand All @@ -168,7 +177,7 @@ static void CreateSubgraph(Graph& graph, RunOptions& options, const std::string&
// one output from concatenate of {4} tensor
TypeProto concat_output_tensor;
concat_output_tensor.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
if (include_shapes)
if (include_dim_values)
concat_output_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(4);

TypeProto* type_proto = include_types ? &concat_output_tensor : nullptr;
Expand Down Expand Up @@ -277,13 +286,18 @@ void RunTest(const std::string test_name, int64_t batch_size, int64_t max_sequen
test.AddInput<int64_t>("sequence_lens", sequence_lens_dims, *sequence_lens);
}

test.AddInput<float>("scan_loop_state_in_0", {batch_size, 1}, loop_state_in_0);
std::vector<int64_t> loop_state_shape{batch_size};
if (!options.scalar_loop_state_value) {
loop_state_shape.push_back(1);
}

test.AddInput<float>("scan_loop_state_in_0", loop_state_shape, loop_state_in_0);

std::vector<int64_t> input_shape{batch_size, max_sequence_len, input_size};
test.AddInput<float>("scan_input_0", input_shape, input_0);
test.AddInput<float>("scan_input_1", input_shape, input_1);

test.AddOutput<float>("scan_loop_state_out_0", {batch_size, 1}, loop_state_out_0);
test.AddOutput<float>("scan_loop_state_out_0", loop_state_shape, loop_state_out_0);

std::vector<int64_t> output_shape{batch_size, max_sequence_len, 1};
test.AddOutput<float>("scan_output_0", output_shape, output_0);
Expand Down Expand Up @@ -353,6 +367,16 @@ TEST(Scan, ShortSequenceOneInBatchOneLoopStateVar_NoShapeInMainGraph_NoTypeAndSh
ShortSequenceOneInBatchOneLoopStateVar(options);
}

TEST(Scan, OnnxScalarLoopState) {
RunOptions options{};
options.include_dim_values_in_main_graph = true;
options.include_types_in_subgraph = false;
options.include_dim_values_in_subgraph = false;
options.scalar_loop_state_value = true;

ShortSequenceOneInBatchOneLoopStateVar(options);
}

// test when there is an operator in the subgraph that uses a value coming from outer scope
TEST(Scan, OuterScopeAccess_NoShapeInMainGraph_TypeAndShapeInSubgraph) {
RunOptions options{};
Expand Down

0 comments on commit 194cc15

Please sign in to comment.