Make GetTensorShapeFromTensorShapeProto return TensorShape and not its internal representation. #1353

Merged: 1 commit, Jul 8, 2019
1 change: 1 addition & 0 deletions include/onnxruntime/core/framework/tensor_shape.h
@@ -37,6 +37,7 @@ class TensorShape : private std::vector<int64_t> {
   TensorShape(const int64_t* dimension_sizes, size_t dimension_count);
 
   TensorShape(const std::vector<int64_t>& dims);
+  TensorShape(std::vector<int64_t>&& dims);
 
   TensorShape(const std::initializer_list<int64_t>& dims);

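For orientation, a minimal sketch (not part of the PR) of how overload resolution picks the new rvalue constructor; it assumes the TensorShape declared above:

#include <cstdint>
#include <utility>
#include <vector>

void Example() {
  std::vector<int64_t> dims = {2, 3, 4};
  TensorShape copied(dims);            // lvalue argument selects the const& overload (copies)
  TensorShape moved(std::move(dims));  // rvalue argument selects the new && overload
}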
14 changes: 8 additions & 6 deletions onnxruntime/core/framework/tensor_shape.cc
@@ -11,6 +11,9 @@ namespace onnxruntime {
 TensorShape::TensorShape(const std::vector<int64_t>& dims) : std::vector<int64_t>(dims) {
 }
 
+TensorShape::TensorShape(std::vector<int64_t>&& dims) : std::vector<int64_t>(dims) {
+}
+
 TensorShape::TensorShape(const std::initializer_list<int64_t>& dims) : std::vector<int64_t>(dims) {
 }
 
@@ -20,7 +23,6 @@ TensorShape::TensorShape(const int64_t* dimension_sizes, size_t dimension_count)
   }
 }
 
-
 TensorShape::TensorShape(const std::vector<int64_t>& dims, size_t start, size_t end) {
   assign(dims.begin() + start, dims.begin() + end);
 }
 
@@ -38,8 +40,8 @@ int64_t TensorShape::Size() const {
 int64_t TensorShape::SizeToDimension(size_t dimension) const {
   const size_t num_dims = size();
   ORT_ENFORCE(dimension <= num_dims,
-          "Invalid dimension of ", dimension, " for SizeFromDimension. Tensor has ",
-          num_dims, " dimensions.");
+              "Invalid dimension of ", dimension, " for SizeFromDimension. Tensor has ",
+              num_dims, " dimensions.");
 
   int64_t size = SizeHelper(0, dimension);
   return size;
@@ -48,16 +50,16 @@ int64_t TensorShape::SizeToDimension(size_t dimension) const {
 int64_t TensorShape::SizeFromDimension(size_t dimension) const {
   const size_t num_dims = size();
   ORT_ENFORCE(dimension <= num_dims,
-          "Invalid dimension of ", dimension, " for SizeFromDimension. Tensor has ",
-          num_dims, " dimensions.");
+              "Invalid dimension of ", dimension, " for SizeFromDimension. Tensor has ",
+              num_dims, " dimensions.");
 
   int64_t size = SizeHelper(dimension, num_dims);
   return size;
 }
 
 TensorShape TensorShape::Slice(size_t dimstart, size_t dimend) const {
   ORT_ENFORCE(dimstart <= dimend && dimend <= size(),
-          "Invalid tensor shape slice argument.");
+              "Invalid tensor shape slice argument.");
   return TensorShape(*this, dimstart, dimend);
 }
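One C++ subtlety, noted here as an observation rather than anything the diff claims: inside the new rvalue constructor the parameter dims is a named variable and therefore an lvalue, so initializing the base with std::vector<int64_t>(dims) still copies; actually stealing the buffer would require std::move(dims). A self-contained sketch of the difference:

#include <cstdint>
#include <utility>
#include <vector>

struct CopyingShape : private std::vector<int64_t> {
  // A named rvalue-reference parameter is an lvalue in the body: this copies.
  explicit CopyingShape(std::vector<int64_t>&& dims) : std::vector<int64_t>(dims) {}
};

struct MovingShape : private std::vector<int64_t> {
  // std::move casts the parameter back to an rvalue: this steals the buffer.
  explicit MovingShape(std::vector<int64_t>&& dims) : std::vector<int64_t>(std::move(dims)) {}
};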
4 changes: 2 additions & 2 deletions onnxruntime/core/framework/tensorprotoutils.cc
@@ -274,14 +274,14 @@ common::Status GetSizeInBytesFromTensorProto(const ONNX_NAMESPACE::TensorProto&
   return Status::OK();
 }
 
-std::vector<int64_t> GetTensorShapeFromTensorShapeProto(const ONNX_NAMESPACE::TensorShapeProto& tensor_shape_proto) {
+TensorShape GetTensorShapeFromTensorShapeProto(const ONNX_NAMESPACE::TensorShapeProto& tensor_shape_proto) {
   const auto& dims = tensor_shape_proto.dim();
   std::vector<int64_t> tensor_shape_vec(static_cast<size_t>(dims.size()));
   for (int i = 0; i < dims.size(); ++i) {
     tensor_shape_vec[i] = dims[i].has_dim_param() ? -1 /* symbolic dimensions are represented as -1 in onnxruntime*/
                                                   : dims[i].dim_value();
   }
-  return tensor_shape_vec;
+  return TensorShape(std::move(tensor_shape_vec));
 }
 
 struct UnInitializeParam {
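To show the conversion logic without the ONNX protobuf dependency, here is a self-contained sketch; Dim is a hypothetical stand-in for TensorShapeProto::Dimension:

#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for ONNX_NAMESPACE::TensorShapeProto::Dimension.
struct Dim {
  bool has_dim_param;  // true when the dimension is symbolic (named rather than fixed)
  int64_t dim_value;   // concrete extent, meaningful when has_dim_param is false
};

// Mirrors the function above: symbolic dimensions map to -1.
std::vector<int64_t> ToDims(const std::vector<Dim>& dims) {
  std::vector<int64_t> out(dims.size());
  for (size_t i = 0; i < dims.size(); ++i) {
    out[i] = dims[i].has_dim_param ? -1 : dims[i].dim_value;
  }
  return out;  // moved (or elided) on return, like TensorShape(std::move(tensor_shape_vec))
}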
2 changes: 1 addition & 1 deletion onnxruntime/core/framework/tensorprotoutils.h
@@ -24,7 +24,7 @@ class TensorShapeProto;
 namespace onnxruntime {
 class Tensor;
 namespace utils {
-std::vector<int64_t> GetTensorShapeFromTensorShapeProto(const ONNX_NAMESPACE::TensorShapeProto& tensor_shape_proto);
+TensorShape GetTensorShapeFromTensorShapeProto(const ONNX_NAMESPACE::TensorShapeProto& tensor_shape_proto);
 /**
  * deserialize a TensorProto into a preallocated memory buffer.
  * \param tensor_proto_path A local file path of where the 'input' was loaded from. Can be NULL if the tensor proto doesn't
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cpu/controlflow/if.cc
@@ -156,7 +156,7 @@ Status IfImpl::AllocateOutputTensors() {
                               graph_output->Name(), " did not.");
     }
 
-    TensorShape output_shape{onnxruntime::utils::GetTensorShapeFromTensorShapeProto(*graph_output_shape)};
+    TensorShape output_shape = onnxruntime::utils::GetTensorShapeFromTensorShapeProto(*graph_output_shape);
 
     // if size < 0 we have a symbolic dimension and need to use a temporary OrtValue in the subgraph execution
     if (output_shape.Size() < 0) {
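The Size() < 0 test works because symbolic dimensions are stored as -1 and push the element-count product negative. A sketch of an equivalent per-dimension check (a hypothetical helper, not part of ORT), which also covers the corner case where an even number of -1 dimensions would make the product positive again:

#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical helper: true when every dimension is concrete (no -1 placeholders).
bool HasConcreteShape(const std::vector<int64_t>& dims) {
  return std::all_of(dims.begin(), dims.end(),
                     [](int64_t d) { return d >= 0; });
}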
3 changes: 2 additions & 1 deletion onnxruntime/core/providers/cpu/controlflow/loop.cc
@@ -422,7 +422,8 @@ Status LoopImpl::Execute(FeedsFetchesManager* ffm, const FeedsFetchesManager* ca
     if (graph_output_shape) {
       output_dims.reserve(graph_output_shape->dim_size() + 1);
 
-      auto dims = onnxruntime::utils::GetTensorShapeFromTensorShapeProto(*graph_output_shape);
+      const auto& tensor_shape = onnxruntime::utils::GetTensorShapeFromTensorShapeProto(*graph_output_shape);
+      const auto& dims = tensor_shape.GetDims();
       std::copy(dims.cbegin(), dims.cend(), std::back_inserter(output_dims));
     } else {
       // TODO: We could try and call ExecuteGraph to get the output shape from fetches so the rank is correct,
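Splitting the call into two statements matters now that the function returns a TensorShape by value: binding the temporary to const auto& extends its lifetime to match the reference, so the reference returned by GetDims() stays valid. Chaining both calls in one expression would dangle. A self-contained sketch of the hazard, using a hypothetical MakeShape:

#include <cstdint>
#include <vector>

struct Shape {
  std::vector<int64_t> dims_;
  const std::vector<int64_t>& GetDims() const { return dims_; }
};

Shape MakeShape() { return Shape{{2, 3}}; }  // returns a temporary by value

void Example() {
  const auto& shape = MakeShape();     // temporary's lifetime extended to this scope
  const auto& dims = shape.GetDims();  // OK: refers into the extended temporary

  // const auto& bad = MakeShape().GetDims();  // dangling: the temporary dies at the
  //                                           // end of the full-expression
}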
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cpu/controlflow/scan_utils.cc
@@ -61,7 +61,7 @@ Status AllocateOutput(OpKernelContextInternal& context, const GraphViewer& subgr
                             graph_output->Name(), " did not.");
   }
 
-  TensorShape output_shape{onnxruntime::utils::GetTensorShapeFromTensorShapeProto(*graph_output_shape)};
+  TensorShape output_shape = onnxruntime::utils::GetTensorShapeFromTensorShapeProto(*graph_output_shape);
   auto& graph_output_dims{output_shape.GetDims()};
 
   std::vector<int64_t> scan_output_dims;
4 changes: 2 additions & 2 deletions onnxruntime/test/framework/inference_session_test.cc
@@ -348,8 +348,8 @@ static bool Compare(const InputDefList& f_arg, const InputDefList& s_arg) {
     if (!x->Shape()) {
       continue;
     }
-    vector<int64_t> x_shape = utils::GetTensorShapeFromTensorShapeProto(*x->Shape());
-    vector<int64_t> y_shape = utils::GetTensorShapeFromTensorShapeProto(*y->Shape());
+    auto x_shape = utils::GetTensorShapeFromTensorShapeProto(*x->Shape());
+    auto y_shape = utils::GetTensorShapeFromTensorShapeProto(*y->Shape());
     if (x->Name() == y->Name() && x_shape == y_shape && *x->Type() == *y->Type()) {
       continue;
     }
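A side effect of switching x_shape and y_shape to TensorShape: the x_shape == y_shape comparison now requires an equality operator on TensorShape (the merged code compiles, so one must be declared alongside the class). A brief sketch of the property the test relies on, assuming that operator:

#include <cassert>

// Assumes the TensorShape from include/onnxruntime/core/framework/tensor_shape.h.
void CheckShapeEquality() {
  TensorShape x({2, 3});
  TensorShape y({2, 3});
  assert(x == y);  // assumed operator==: element-wise comparison of the dimensions
}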
3 changes: 2 additions & 1 deletion onnxruntime/test/providers/provider_test_utils.cc
@@ -338,7 +338,8 @@ void OpTester::ExecuteModel(Model& model, InferenceSession& session_object, Expe
       if (add_shape_to_tensor_data_) {
         auto out_shape_proto = expected_data.def_.Shape();
         EXPECT_TRUE(out_shape_proto != nullptr);
-        auto inferred_dims = utils::GetTensorShapeFromTensorShapeProto(*out_shape_proto);
+        const auto& tensor_shape = utils::GetTensorShapeFromTensorShapeProto(*out_shape_proto);
+        const auto& inferred_dims = tensor_shape.GetDims();
         const auto& expected_shape = expected_data.data_.Get<Tensor>().Shape();
         EXPECT_TRUE(inferred_dims.size() == expected_shape.NumDimensions());
         for (size_t d = 0; d < inferred_dims.size(); ++d) {