Add CUDA Expand operator (microsoft#1292)
* Add CUDA expand operator

* Reset counter variables when striding

* Reset counter variables when striding

* use fast_divmod and other PR comments

* Fix merge variable rename

* Fix indentation per PR comment

* Remove maxpool_argmax

* Reduce number of type templates for Expand operator

* removed all types

* Commit updated cuda_execution_provider.cc
jignparm authored Jun 27, 2019
1 parent a79ab5e commit 59de37a
Showing 8 changed files with 237 additions and 2 deletions.
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -344,6 +344,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, bool, Equal);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, int32_t, Equal);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, int64_t, Equal);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, Expand);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, int32_t, Greater);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, int64_t, Greater);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, uint32_t, Greater);
@@ -668,6 +669,7 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, bool, Equal)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, int32_t, Equal)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, int64_t, Equal)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, Expand)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, int32_t, Greater)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, int64_t, Greater)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, uint32_t, Greater)>,
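For context, the two hunks above register the new Expand kernel (opset 8) with the CUDA execution provider's kernel table. A minimal, generic sketch of the underlying pattern — a registry mapping an (op type, opset version) key to a kernel factory — is shown below; the names (Registry, SketchKernel, SketchExpandKernel) are illustrative stand-ins, not onnxruntime's actual KernelRegistry/BuildKernelCreateInfo machinery.

#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>

// Generic sketch of kernel registration: each entry maps an (op type,
// opset since-version) key to a factory that builds the kernel object.
struct SketchKernel {
  virtual ~SketchKernel() = default;
};
struct SketchExpandKernel : SketchKernel {};

using KernelFactory = std::function<std::unique_ptr<SketchKernel>()>;

std::map<std::pair<std::string, int>, KernelFactory>& Registry() {
  static std::map<std::pair<std::string, int>, KernelFactory> registry;
  return registry;
}

// The added lines amount to registering Expand for opset 8, roughly:
void RegisterExpand() {
  Registry()[{"Expand", 8}] = [] { return std::make_unique<SketchExpandKernel>(); };
}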
@@ -20,7 +20,7 @@ Status BinaryElementwise<ShouldNotBroadcast>::Prepare(OpKernelContext* context,
  return Status::OK();
}

static Status ComputeOutputShape(const std::string& node_name, const TensorShape& lhs_shape, const TensorShape& rhs_shape, TensorShape& out_shape) {
Status ComputeOutputShape(const std::string& node_name, const TensorShape& lhs_shape, const TensorShape& rhs_shape, TensorShape& out_shape) {
  size_t lhs_rank = lhs_shape.NumDimensions();
  size_t rhs_rank = rhs_shape.NumDimensions();
  size_t out_rank = std::max(lhs_rank, rhs_rank);
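The hunk above drops `static` from ComputeOutputShape so the new Expand kernel (declared in expand.h below) can reuse it. The function applies the standard multidirectional (NumPy-style) broadcasting rule; a minimal CPU-side sketch of that rule, with a hypothetical name and simplified error handling, looks roughly like this:

#include <algorithm>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Sketch of multidirectional broadcasting: align shapes from the trailing
// dimension, each pair of dims must be equal or one of them 1, and the
// output dim is the larger of the two.
std::vector<int64_t> BroadcastShape(const std::vector<int64_t>& lhs,
                                    const std::vector<int64_t>& rhs) {
  size_t out_rank = std::max(lhs.size(), rhs.size());
  std::vector<int64_t> out(out_rank, 1);
  for (size_t i = 0; i < out_rank; ++i) {
    // Missing leading dimensions behave as 1.
    int64_t l = i < lhs.size() ? lhs[lhs.size() - 1 - i] : 1;
    int64_t r = i < rhs.size() ? rhs[rhs.size() - 1 - i] : 1;
    if (l != r && l != 1 && r != 1)
      throw std::invalid_argument("incompatible dimensions");
    out[out_rank - 1 - i] = std::max(l, r);
  }
  return out;
}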
76 changes: 76 additions & 0 deletions onnxruntime/core/providers/cuda/tensor/expand.cc
@@ -0,0 +1,76 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "expand.h"
#include "expand_impl.h"
#include "core/providers/cpu/tensor/utils.h"

namespace onnxruntime {
namespace cuda {

Status Expand::ComputeInternal(OpKernelContext* ctx) const {
  const auto& input0 = *ctx->Input<Tensor>(0);
  const auto& input1 = *ctx->Input<Tensor>(1);
  int device_id = GetDeviceId();

  // new shape to be expanded to
  const auto* p_shape = input1.template Data<int64_t>();
  std::vector<int64_t> output_dims{p_shape, p_shape + input1.Shape().Size()};
  TensorShape output_shape(output_dims);

  ORT_RETURN_IF_ERROR(ComputeOutputShape(Node().Name(), input0.Shape(), output_dims, output_shape));
  auto rank = output_shape.NumDimensions();
  auto& output_tensor = *ctx->Output(0, output_shape);
  auto input_shape = input0.Shape().GetDims();

  // pad input_dims with 1 to make ranks match
  for (int i = 0; i < rank - input_shape.size(); i++) {
    input_shape.insert(input_shape.begin(), 1);
  }

  // create fast_divmod using dimension values
  CudaAsyncBuffer<fast_divmod> fdm_input_dims(this, device_id, rank);
  CudaAsyncBuffer<fast_divmod> fdm_output_dims(this, device_id, rank);
  CudaAsyncBuffer<fast_divmod> fdm_output_subdim_size(this, device_id, rank);
  {
    auto in_span = fdm_input_dims.CpuSpan();
    auto out_span = fdm_output_dims.CpuSpan();
    auto sdm_span = fdm_output_subdim_size.CpuSpan();
    auto subdim_size = output_shape.Size();
    for (auto i = 0; i < rank; i++) {
      in_span[i] = fast_divmod(static_cast<int>(input_shape[i]));
      out_span[i] = fast_divmod(static_cast<int>(output_shape[i]));
      subdim_size /= output_shape[i];
      sdm_span[i] = static_cast<int>(subdim_size);
    }
  }
  ORT_RETURN_IF_ERROR(fdm_input_dims.CopyToGpu());
  ORT_RETURN_IF_ERROR(fdm_output_dims.CopyToGpu());
  ORT_RETURN_IF_ERROR(fdm_output_subdim_size.CopyToGpu());

  ExpandImpl(
      input0.DataType()->Size(),
      output_shape.NumDimensions(),
      output_shape.Size(),
      input0.Shape().Size(),
      input0.DataRaw(),
      output_tensor.MutableDataRaw(),
      fdm_input_dims.GpuPtr(),
      fdm_output_dims.GpuPtr(),
      fdm_output_subdim_size.GpuPtr());

  return Status::OK();
}

ONNX_OPERATOR_KERNEL_EX(
    Expand,
    kOnnxDomain,
    8,
    kCudaExecutionProvider,
    KernelDefBuilder()
        .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
        .InputMemoryType<OrtMemTypeCPUInput>(1),
    Expand);

}  // namespace cuda
}  // namespace onnxruntime
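One detail worth spelling out: fdm_output_subdim_size holds, for each axis, the number of output elements spanned by one step along that axis (a row-major stride), obtained by repeatedly dividing the total output size. A small host-side sketch of that derivation (hypothetical helper name, plain integers instead of fast_divmod) is given below; for the {3, 3, 3, 3} output used in the test added later in this commit it yields {27, 9, 3, 1}.

#include <cstdint>
#include <vector>

// Sketch of how the per-axis subdim sizes (output strides) are derived,
// mirroring the loop in Expand::ComputeInternal above.
std::vector<int64_t> OutputSubdimSizes(const std::vector<int64_t>& output_dims) {
  int64_t subdim_size = 1;
  for (int64_t d : output_dims) subdim_size *= d;  // total output element count
  std::vector<int64_t> sizes(output_dims.size());
  for (size_t i = 0; i < output_dims.size(); ++i) {
    subdim_size /= output_dims[i];  // drop the current axis
    sizes[i] = subdim_size;         // elements covered by one step along axis i
  }
  return sizes;  // e.g. {3, 3, 3, 3} -> {27, 9, 3, 1}
}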
25 changes: 25 additions & 0 deletions onnxruntime/core/providers/cuda/tensor/expand.h
@@ -0,0 +1,25 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "core/providers/cuda/cuda_common.h"

namespace onnxruntime {
namespace cuda {

class Expand final : public CudaKernel {
 public:
  Expand(const OpKernelInfo& info) : CudaKernel(info) {}

  Status ComputeInternal(OpKernelContext* context) const override;
};

Status ComputeOutputShape(
    const std::string& node_name,
    const TensorShape& lhs_shape,
    const TensorShape& rhs_shape,
    TensorShape& out_shape);

} // namespace cuda
} // namespace onnxruntime
96 changes: 96 additions & 0 deletions onnxruntime/core/providers/cuda/tensor/expand_impl.cu
@@ -0,0 +1,96 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/cuda/cu_inc/common.cuh"
#include "expand_impl.h"
#include "core/providers/cuda/cuda_common.h"

namespace onnxruntime {
namespace cuda {

template <typename T>
__global__ void ExpandKernel(
    const size_t rank,
    const size_t N,
    const size_t N_input,
    const T* input_data,
    T* output_data,
    const fast_divmod* fdm_input_dims,
    const fast_divmod* fdm_output_dims,
    const fast_divmod* fdm_output_subdim_size) {
  CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);

  // initialize
  auto output_index = id;
  auto input_index = 0;
  auto input_subdim_size = N_input;
  auto out_coord = output_index;
  // use striding when tensor is larger than grid
  int stride = blockDim.x * gridDim.x;

  // translate indices to coordinates. copy expanded dims from source
  while (output_index < N) {
    for (int64_t i = 0; i < rank; i++) {
      input_subdim_size = fdm_input_dims[i].div(input_subdim_size);
      auto new_out_coord = fdm_output_subdim_size[i].div(out_coord);
      auto in_coord = (new_out_coord > (fdm_input_dims[i].d_ - 1)) ? fdm_input_dims[i].d_ - 1 : new_out_coord;
      input_index += input_subdim_size * in_coord;
      out_coord -= new_out_coord * fdm_output_subdim_size[i].d_;
    }
    output_data[output_index] = input_data[input_index];
    output_index += stride;
    out_coord = output_index;
    input_subdim_size = N_input;
    input_index = 0;
  }
}

Status ExpandImpl(
    const size_t element_size,
    const size_t rank,
    const size_t N,
    const size_t N_input,
    const void* input_data,
    void* output_data,
    const fast_divmod* fdm_input_dims,
    const fast_divmod* fdm_output_dims,
    const fast_divmod* fdm_output_subdim_size) {
  int blocksPerGrid = (int)(ceil(static_cast<float>(N) / GridDim::maxThreadsPerBlock));

  switch (element_size) {
    case sizeof(uint8_t):
      ExpandKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
          rank, N, N_input,
          reinterpret_cast<const ToCudaType<uint8_t>::MappedType*>(input_data),
          reinterpret_cast<ToCudaType<uint8_t>::MappedType*>(output_data),
          fdm_input_dims, fdm_output_dims, fdm_output_subdim_size);
      break;
    case sizeof(uint16_t):
      ExpandKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
          rank, N, N_input,
          reinterpret_cast<const ToCudaType<uint16_t>::MappedType*>(input_data),
          reinterpret_cast<ToCudaType<uint16_t>::MappedType*>(output_data),
          fdm_input_dims, fdm_output_dims, fdm_output_subdim_size);
      break;
    case sizeof(uint32_t):
      ExpandKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
          rank, N, N_input,
          reinterpret_cast<const ToCudaType<uint32_t>::MappedType*>(input_data),
          reinterpret_cast<ToCudaType<uint32_t>::MappedType*>(output_data),
          fdm_input_dims, fdm_output_dims, fdm_output_subdim_size);
      break;
    case sizeof(uint64_t):
      ExpandKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
          rank, N, N_input,
          reinterpret_cast<const ToCudaType<uint64_t>::MappedType*>(input_data),
          reinterpret_cast<ToCudaType<uint64_t>::MappedType*>(output_data),
          fdm_input_dims, fdm_output_dims, fdm_output_subdim_size);
      break;
    default:
      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for Expand operator");
  }
  return Status::OK();
}

} // namespace cuda
} // namespace onnxruntime
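To make the kernel's index arithmetic concrete: each thread takes an output index, recovers its coordinate along each axis (outermost first) using the precomputed output strides, clamps the coordinate to 0 on axes where the input dimension is 1 (the broadcast case), and accumulates the matching input offset. A hedged CPU reference of that mapping, using plain integer division in place of fast_divmod and hypothetical parameter names, is:

#include <cstdint>
#include <vector>

// CPU reference of the output-index -> input-index mapping in ExpandKernel.
// input_dims must already be padded with leading 1s to the output rank, and
// output_strides[i] is the subdim size for axis i (see the earlier sketch).
int64_t MapOutputToInputIndex(int64_t output_index,
                              const std::vector<int64_t>& input_dims,
                              const std::vector<int64_t>& output_strides) {
  int64_t input_subdim_size = 1;
  for (int64_t d : input_dims) input_subdim_size *= d;  // total input element count

  int64_t input_index = 0;
  int64_t out_coord = output_index;
  for (size_t i = 0; i < input_dims.size(); ++i) {
    input_subdim_size /= input_dims[i];             // input stride for axis i
    int64_t coord = out_coord / output_strides[i];  // output coordinate on axis i
    int64_t in_coord = coord < input_dims[i] ? coord : input_dims[i] - 1;  // clamp broadcast axes
    input_index += input_subdim_size * in_coord;
    out_coord -= coord * output_strides[i];         // remainder for the inner axes
  }
  return input_index;
}

For the {1, 3, 1, 3} input expanded to {3, 3, 3, 3} in the test below, output index 30 has coordinates (1, 0, 1, 0); axes 0 and 2 are broadcast, so it maps back to input index 0.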
25 changes: 25 additions & 0 deletions onnxruntime/core/providers/cuda/tensor/expand_impl.h
@@ -0,0 +1,25 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include <stdint.h>
#include "core/providers/cuda/shared_inc/cuda_utils.h"
#include "core/framework/data_types.h"
#include "core/common/common.h"

namespace onnxruntime {
namespace cuda {

Status ExpandImpl(
    const size_t element_size,
    const size_t shape_rank,
    const size_t N,
    const size_t N_input,
    const void* input_data,
    void* output_data,
    const fast_divmod* fdm_input_dims,
    const fast_divmod* fdm_output_dims,
    const fast_divmod* fdm_output_subdim_size);

} // namespace cuda
} // namespace onnxruntime
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cuda/tensor/slice_impl.cu
@@ -8,7 +8,7 @@
namespace onnxruntime {
namespace cuda {

template <typename T>
template <typename T >
__global__ void _SliceKernel(const int32_t dimension_count,
const int64_t* starts,
const int64_t* steps,
11 changes: 11 additions & 0 deletions onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc
@@ -950,6 +950,17 @@ TEST(MathOpTest, Expand_8_1x3_int64) {
  test.Run();
}

TEST(MathOpTest, Expand_8_3x1x3x1_int64) {
  OpTester test("Expand", 8);
  test.AddInput<int64_t>("data_0", {1, 3, 1, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
  test.AddInput<int64_t>("data_1", {4}, {3, 1, 3, 1});
  test.AddOutput<int64_t>("result", {3, 3, 3, 3},
                          {1, 2, 3, 1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6, 4, 5, 6, 7, 8, 9, 7, 8, 9, 7, 8, 9,
                           1, 2, 3, 1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6, 4, 5, 6, 7, 8, 9, 7, 8, 9, 7, 8, 9,
                           1, 2, 3, 1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6, 4, 5, 6, 7, 8, 9, 7, 8, 9, 7, 8, 9});
  test.Run();
}

TEST(MathOpTest, Expand_8_3x3_float16) {
  OpTester test("Expand", 8);
  test.AddInput<MLFloat16>("data_0", {1}, {MLFloat16(math::floatToHalf(1.0f))});
