A few performance improvements:
- Make the iteration in NonZero more efficient by using a raw pointer and simplifying the increment logic
  - Add a unit test to check that the new logic works with a 3-dimensional tensor
  - Gains about 2% for ssd_mobilenet
- Avoid floating point operations on each iteration in Concat
  - About 0.5% for ssd_mobilenet and ssd_resnet34
- Put the common case first in ExecutionFrame::AllocateAsPerAllocationPlan to avoid an unnecessary call to IsSparseTensorType
  - About 0.05% for ssd_mobilenet
- Minor tweak to put some ctors in the TensorShape header so they can be inlined more easily
skottmckay committed Aug 7, 2019
1 parent 9a34089 commit 20335d1
Showing 6 changed files with 113 additions and 90 deletions.
9 changes: 5 additions & 4 deletions include/onnxruntime/core/framework/tensor_shape.h
@@ -34,12 +34,13 @@ class TensorShape : private std::vector<int64_t> {
TensorShape(TensorShape&& /*other*/) = default;
TensorShape& operator=(TensorShape&& /*other*/) = default;

TensorShape(const int64_t* dimension_sizes, size_t dimension_count);
TensorShape(const std::vector<int64_t>& dims) : std::vector<int64_t>(dims) {}

TensorShape(std::vector<int64_t>&& dims) : std::vector<int64_t>(std::move(dims)) {}

TensorShape(const std::vector<int64_t>& dims);
TensorShape(std::vector<int64_t>&& dims);
TensorShape(const std::initializer_list<int64_t>& dims) : std::vector<int64_t>(dims) {}

TensorShape(const std::initializer_list<int64_t>& dims);
TensorShape(const int64_t* dimension_sizes, size_t dimension_count);

TensorShape(const std::vector<int64_t>& dims, size_t start, size_t end);

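The hunk above moves the trivial TensorShape constructors out of tensor_shape.cc and into the header. A minimal sketch of the general pattern, using a hypothetical `Shape` class rather than the real ORT type: a constructor defined in the header is visible at every call site, so the compiler can inline it without relying on link-time optimization.

```cpp
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical illustration only (not the ORT TensorShape class).
class Shape {
 public:
  // Defined inline in the header: the body is visible at every call site,
  // so the compiler can inline the vector move directly into the caller.
  explicit Shape(std::vector<int64_t>&& dims) : dims_(std::move(dims)) {}

  // Declared here but defined in a .cc file: callers only see the
  // declaration, so inlining generally needs link-time optimization.
  Shape(const int64_t* dimension_sizes, size_t dimension_count);

 private:
  std::vector<int64_t> dims_;
};
```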
78 changes: 39 additions & 39 deletions onnxruntime/core/framework/execution_frame.cc
@@ -402,54 +402,54 @@ Status ExecutionFrame::AllocateAsPerAllocationPlan(OrtValue& ort_value, int ort_

const auto& alloc_info = per_alloc_plan.location;
const auto* ml_type = per_alloc_plan.value_type;
if (ml_type == nullptr)
if (ml_type == nullptr) {
return Status(
ONNXRUNTIME, INVALID_ARGUMENT,
"Tried to allocate without valid type information, ort_value index=" + std::to_string(ort_value_index));

if (ml_type->IsSparseTensorType()) {
return AllocateSparseTensor(ort_value, *ml_type, GetAllocator(alloc_info),
*shape, nnz, per_alloc_plan.create_fence_if_async, session_state_);
}
if (!ml_type->IsTensorType()) {
return AllocateTraditionalMLValue(ort_value, *static_cast<const NonTensorTypeBase*>(ml_type));
}

ORT_ENFORCE(shape, "Allocation of tensor types requires a shape.");
if (ml_type->IsTensorType()) {
ORT_ENFORCE(shape, "Allocation of tensor types requires a shape.");

// tensors
const auto* ml_data_type = static_cast<const TensorTypeBase*>(ml_type)->GetElementType();
// tensors
const auto* ml_data_type = static_cast<const TensorTypeBase*>(ml_type)->GetElementType();

AllocKind alloc_kind = per_alloc_plan.alloc_kind;
switch (alloc_kind) {
// Right now for kAllocate and kAllocateOutput we are using same approach.
// In the future we may want to have different way to handle it.
case AllocKind::kAllocateOutput:
case AllocKind::kAllocate: {
ORT_RETURN_IF_ERROR(AllocateMLValueTensorSelfOwnBuffer(ort_value, ort_value_index, ml_data_type, alloc_info,
*shape, per_alloc_plan.create_fence_if_async));
break;
}
case AllocKind::kReuse: {
int reuse_mlvalue_index = per_alloc_plan.reused_buffer;
ORT_RETURN_IF_ERROR(AllocateMLValueTensorPreAllocateBuffer(
ort_value, reuse_mlvalue_index, ml_data_type, alloc_info, *shape, per_alloc_plan.create_fence_if_async));
break;
}
case AllocKind::kShare: {
int reuse_mlvalue_index = per_alloc_plan.reused_buffer;
// copy at the OrtValue level so the shared_ptr for the data is shared between the two OrtValue instances
ort_value = GetMutableMLValue(reuse_mlvalue_index);
break;
}
default: {
std::ostringstream ostr;
ostr << "Invalid allocation kind: " << static_cast<std::underlying_type<AllocKind>::type>(alloc_kind);
return Status(ONNXRUNTIME, FAIL, ostr.str());
AllocKind alloc_kind = per_alloc_plan.alloc_kind;
switch (alloc_kind) {
// Right now for kAllocate and kAllocateOutput we are using same approach.
// In the future we may want to have different way to handle it.
case AllocKind::kAllocateOutput:
case AllocKind::kAllocate: {
ORT_RETURN_IF_ERROR(AllocateMLValueTensorSelfOwnBuffer(ort_value, ort_value_index, ml_data_type, alloc_info,
*shape, per_alloc_plan.create_fence_if_async));
break;
}
case AllocKind::kReuse: {
int reuse_mlvalue_index = per_alloc_plan.reused_buffer;
ORT_RETURN_IF_ERROR(AllocateMLValueTensorPreAllocateBuffer(
ort_value, reuse_mlvalue_index, ml_data_type, alloc_info, *shape, per_alloc_plan.create_fence_if_async));
break;
}
case AllocKind::kShare: {
int reuse_mlvalue_index = per_alloc_plan.reused_buffer;
// copy at the OrtValue level so the shared_ptr for the data is shared between the two OrtValue instances
ort_value = GetMutableMLValue(reuse_mlvalue_index);
break;
}
default: {
std::ostringstream ostr;
ostr << "Invalid allocation kind: " << static_cast<std::underlying_type<AllocKind>::type>(alloc_kind);
return Status(ONNXRUNTIME, FAIL, ostr.str());
}
}
}

return Status::OK();
return Status::OK();
} else if (ml_type->IsSparseTensorType()) {
return AllocateSparseTensor(ort_value, *ml_type, GetAllocator(alloc_info),
*shape, nnz, per_alloc_plan.create_fence_if_async, session_state_);
} else {
return AllocateTraditionalMLValue(ort_value, *static_cast<const NonTensorTypeBase*>(ml_type));
}
}

AllocatorPtr ExecutionFrame::GetAllocatorImpl(const OrtAllocatorInfo& info) const {
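A tiny standalone sketch of the "common case first" reordering above, using hypothetical types and function names rather than the ORT API: the dense-tensor branch is taken on almost every allocation, so it is tested first and the hot path pays for a single predicate instead of an extra sparse-type check.

```cpp
#include <iostream>

// Hypothetical types and names for illustration only.
enum class TypeKind { DenseTensor, SparseTensor, NonTensor };

const char* AllocateAsPerPlan(TypeKind kind) {
  if (kind == TypeKind::DenseTensor) {          // common case, checked first
    return "dense tensor allocation";
  } else if (kind == TypeKind::SparseTensor) {  // rare
    return "sparse tensor allocation";
  } else {                                      // maps, sequences, ...
    return "traditional ML value allocation";
  }
}

int main() {
  std::cout << AllocateAsPerPlan(TypeKind::DenseTensor) << "\n";
  return 0;
}
```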
12 changes: 2 additions & 10 deletions onnxruntime/core/framework/tensor_shape.cc
@@ -8,16 +8,8 @@

namespace onnxruntime {

TensorShape::TensorShape(const std::vector<int64_t>& dims) : std::vector<int64_t>(dims) {
}

TensorShape::TensorShape(std::vector<int64_t>&& dims) : std::vector<int64_t>(std::move(dims)) {
}

TensorShape::TensorShape(const std::initializer_list<int64_t>& dims) : std::vector<int64_t>(dims) {
}

TensorShape::TensorShape(const int64_t* dimension_sizes, size_t dimension_count) : std::vector<int64_t>(dimension_count) {
TensorShape::TensorShape(const int64_t* dimension_sizes, size_t dimension_count)
: std::vector<int64_t>(dimension_count) {
for (size_t i = 0; i < dimension_count; ++i) {
(*this)[i] = dimension_sizes[i];
}
43 changes: 27 additions & 16 deletions onnxruntime/core/providers/cpu/tensor/concat.cc
@@ -34,16 +34,17 @@ Status ConcatBase::PrepareForCompute(OpKernelContext* ctx, int input_count, Prep
auto& inputs_n = *tensor_pointer;
const auto& inputs_n_dims = inputs_n.Shape().GetDims();
const size_t inputs_n_rank = inputs_n_dims.size();
ORT_ENFORCE(inputs_n_rank == inputs_0_rank, "Ranks of input data are different, cannot concatenate them, "
"expected rank: ", std::to_string(inputs_0_rank), " got: ", std::to_string(inputs_n_rank));
ORT_ENFORCE(inputs_n_rank == inputs_0_rank,
"Ranks of input data are different, cannot concatenate them. expected rank: ",
inputs_0_rank, " got: ", inputs_n_rank);
// Ensure all the other (non-concat) axes match
for (size_t axis_index = 0; axis_index < inputs_0_rank; ++axis_index) {
num_elements *= inputs_n_dims[axis_index];
if (axis_index == p.axis)
continue;
ORT_RETURN_IF_NOT(inputs_n_dims[axis_index] == inputs_0_dims[axis_index],
"Non concat axis dimensions must match: Axis ",
axis_index, " has mismatched dimensions of ", inputs_n_dims[axis_index],
"Non concat axis dimensions must match: Axis ",
axis_index, " has mismatched dimensions of ", inputs_n_dims[axis_index],
" and ", inputs_0_dims[axis_index]);
}
tensor_num_elements[index] = num_elements;
@@ -58,15 +59,15 @@ Status ConcatBase::PrepareForCompute(OpKernelContext* ctx, int input_count, Prep

// Calculate the shape of the output tensor
std::vector<int64_t> dims(inputs_0_rank);
size_t num_elements = 1; // cache size of the first input along the way
size_t num_elements = 1; // cache size of the first input along the way
for (size_t dimension_index = 0; dimension_index < inputs_0_rank; dimension_index++) {
dims[dimension_index] = inputs_0_dims[dimension_index];
num_elements *= inputs_0_dims[dimension_index];
}
tensor_num_elements[0] = num_elements;
dims[p.axis] = concat_axis_size;
TensorShape output_shape(dims);

auto& concat_result = *ctx->Output(0, output_shape);
p.output_tensor = &concat_result;
p.output_num_elements = output_shape.Size();
@@ -75,7 +76,7 @@ Status ConcatBase::PrepareForCompute(OpKernelContext* ctx, int input_count, Prep
// there is no need to proceed further
if (p.output_num_elements == 0)
return Status::OK();

// The output_axis_pitch is the number of elements to add to move to the next split axis in the output
p.output_axis_pitch = 1;
for (size_t i = inputs_0_rank; i-- > p.axis;) p.output_axis_pitch *= dims[i];
@@ -110,7 +111,7 @@ Status Concat::Compute(OpKernelContext* ctx) const {

auto is_string_type = ctx->Input<Tensor>(0)->DataType() == DataTypeImpl::GetType<std::string>();

int64_t output_offset = 0;
int64_t initial_output_offset = 0; // initial offset for each input
auto element_bytes = p.output_tensor->DataType()->Size();
for (int input_index = 0; input_index < input_count; input_index++) {
const auto& prep = p.inputs[input_index];
Expand All @@ -124,19 +125,29 @@ Status Concat::Compute(OpKernelContext* ctx) const {

// Copy the data across. For every 'input_axis_pitch' values copied, we move over by the 'output_axis_pitch'
uint8_t* output = static_cast<uint8_t*>(p.output_tensor->MutableDataRaw());
for (size_t idxCopy = 0; idxCopy < input_size / input_axis_pitch; ++idxCopy) {
int64_t cur_out_offset = 0;
int64_t cur_in_offset = 0;
for (size_t idx_copy = 0, end = input_size / input_axis_pitch; idx_copy < end; ++idx_copy) {
if (is_string_type) {
for (int idxItem = 0; idxItem < input_axis_pitch; ++idxItem)
reinterpret_cast<std::string*>(output)[output_offset + idxCopy * p.output_axis_pitch + idxItem] =
reinterpret_cast<const std::string*>(input)[idxCopy * input_axis_pitch + idxItem];
} else
size_t out = initial_output_offset + cur_out_offset;
for (int idxItem = 0; idxItem < input_axis_pitch; ++idxItem) {
reinterpret_cast<std::string*>(output)[out + idxItem] =
reinterpret_cast<const std::string*>(input)[cur_in_offset + idxItem];
}
} else {
memcpy(
output + (output_offset + idxCopy * p.output_axis_pitch) * element_bytes,
input + idxCopy * input_axis_pitch * element_bytes,
output + (initial_output_offset + cur_out_offset) * element_bytes,
input + cur_in_offset * element_bytes,
input_axis_pitch * element_bytes);
}

cur_out_offset += p.output_axis_pitch;
cur_in_offset += input_axis_pitch;
}
output_offset += input_axis_pitch;

initial_output_offset += input_axis_pitch;
}

return Status::OK();
}

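A self-contained sketch of the indexing change above, using a hypothetical `StridedCopy` helper rather than the ORT kernel: instead of recomputing `idx_copy * pitch` for every copied chunk, the input and output offsets are carried forward and bumped by the pitch on each iteration.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

// Hypothetical sketch of the strided copy in Concat: offsets are accumulated
// instead of being recomputed as idx * pitch on every iteration.
void StridedCopy(const uint8_t* input, uint8_t* output,
                 size_t num_chunks, size_t input_axis_pitch,
                 size_t output_axis_pitch, size_t initial_output_offset,
                 size_t element_bytes) {
  size_t cur_in_offset = 0;
  size_t cur_out_offset = 0;
  for (size_t idx_copy = 0; idx_copy < num_chunks; ++idx_copy) {
    std::memcpy(output + (initial_output_offset + cur_out_offset) * element_bytes,
                input + cur_in_offset * element_bytes,
                input_axis_pitch * element_bytes);
    cur_in_offset += input_axis_pitch;    // replaces idx_copy * input_axis_pitch
    cur_out_offset += output_axis_pitch;  // replaces idx_copy * output_axis_pitch
  }
}

int main() {
  // Concatenate {1,2,3,4} and {5,6,7,8} (shape {2,2} each) along axis 1 into a
  // {2,4} output; each input contributes chunks of 2 elements per output row.
  std::vector<int32_t> a{1, 2, 3, 4}, b{5, 6, 7, 8}, out(8, 0);
  StridedCopy(reinterpret_cast<const uint8_t*>(a.data()),
              reinterpret_cast<uint8_t*>(out.data()),
              /*num_chunks=*/2, /*input_axis_pitch=*/2,
              /*output_axis_pitch=*/4, /*initial_output_offset=*/0,
              sizeof(int32_t));
  StridedCopy(reinterpret_cast<const uint8_t*>(b.data()),
              reinterpret_cast<uint8_t*>(out.data()),
              2, 2, 4, /*initial_output_offset=*/2, sizeof(int32_t));
  for (int32_t v : out) std::cout << v << ' ';  // 1 2 5 6 3 4 7 8
  std::cout << '\n';
  return 0;
}
```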
42 changes: 21 additions & 21 deletions onnxruntime/core/providers/cpu/tensor/nonzero_op.cc
@@ -40,24 +40,6 @@ NONZERO_TYPED_KERNEL(float)
#undef NONZERO_TYPED_KERNEL_WITH_TYPE_NAME
#undef NONZERO_TYPED_KERNEL

namespace {
void IncrementCoordinate(const TensorShape& shape, std::vector<int64_t>* coordinate) {
assert(coordinate->size() == shape.NumDimensions());

size_t i = 0;
const size_t i_end = coordinate->size();
for (; i < i_end; ++i) {
const size_t i_from_back = i_end - i - 1;
if ((*coordinate)[i_from_back] != shape[i_from_back] - 1) break;
(*coordinate)[i_from_back] = 0;
}

if (i < i_end) {
++(*coordinate)[i_end - i - 1];
}
}
} // namespace

template <typename T>
Status NonZero<T>::Compute(OpKernelContext* context) const {
const auto X = context->Input<Tensor>(0);
@@ -71,19 +53,37 @@ Status NonZero<T>::Compute(OpKernelContext* context) const {
// reserve enough space for indices for every element of X
non_zero_indices_buffer.reserve(X_shape.Size() * coordinate_size);

const T* data = X->Data<T>();

if (X_shape.IsScalar()) {
const T& value = *(X->Data<T>());
const T& value = *data;
if (value != T{}) {
non_zero_indices_buffer.push_back(0);
}
} else {
std::vector<int64_t> coordinate(coordinate_size, 0);
for (const T& value : X->DataAsSpan<T>()) {

// as we iterate the entries, increment the coordinate for the current entry
// e.g. if shape is {2,2}, we start with 0,0 increment to 0,1 increment to 1,0 and finally 1,1
auto increment_coordinate = [&coordinate, &coordinate_size, &X_shape]() {
for (int64_t idx = coordinate_size - 1; idx >= 0; --idx) {
int64_t& cur_coord = coordinate[idx];
if (cur_coord != X_shape[idx] - 1) {
++cur_coord;
break;
}
cur_coord = 0;
}
};

for (size_t i = 0, end = X_shape.Size(); i < end; ++i) {
const T& value = *data++;
if (value != T{}) {
non_zero_indices_buffer.insert(non_zero_indices_buffer.end(),
coordinate.begin(), coordinate.end());
}
IncrementCoordinate(X_shape, &coordinate);

increment_coordinate();
}
}

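A self-contained sketch of the simplified increment logic above, written as a hypothetical free function rather than the in-kernel lambda: the coordinate advances like an odometer, bumping the last dimension and carrying to the left whenever a dimension wraps back to zero.

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical sketch of odometer-style coordinate advancement: bump the last
// dimension; when a dimension wraps past its extent, reset it and carry left.
void IncrementCoordinate(const std::vector<int64_t>& shape,
                         std::vector<int64_t>& coordinate) {
  for (int64_t idx = static_cast<int64_t>(coordinate.size()) - 1; idx >= 0; --idx) {
    int64_t& cur = coordinate[idx];
    if (cur != shape[idx] - 1) {
      ++cur;
      break;
    }
    cur = 0;  // wrapped: carry into the next dimension to the left
  }
}

int main() {
  const std::vector<int64_t> shape{2, 2};
  std::vector<int64_t> coord(shape.size(), 0);
  for (int i = 0; i < 4; ++i) {
    std::cout << coord[0] << ',' << coord[1] << '\n';  // 0,0  0,1  1,0  1,1
    IncrementCoordinate(shape, coord);
  }
  return 0;
}
```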
19 changes: 19 additions & 0 deletions onnxruntime/test/providers/cpu/tensor/nonzero_op_test.cc
@@ -48,6 +48,25 @@ TEST(NonZeroOpTest, BasicBool) {
test.Run();
}

TEST(NonZeroOpTest, ThreeDims) {
OpTester test{kOpName, kOpVersion};

std::vector<int64_t> X_dims{2, 2, 2};
std::vector<int64_t> X{0, 1,
1, 0,

1, 0,
1, 0};
test.AddInput<int64_t>("X", X_dims, std::vector<int64_t>{X.begin(), X.end()});
test.AddOutput<int64_t>(
"Y", {3, 4},
{0, 0, 1, 1,
0, 1, 0, 1,
1, 0, 0, 0});

test.Run();
}

TEST(NonZeroOpTest, Scalar) {
{
OpTester test{kOpName, kOpVersion};