Add CUDA Expand operator #1292

Merged · 11 commits · Jun 27, 2019
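For context, ONNX Expand (opset 8) broadcasts a tensor to a target shape using numpy-style rules: each input dimension must equal the corresponding output dimension or be 1. A minimal host-side sketch of these semantics, independent of the CUDA kernel in this PR (function name and signature are illustrative):

#include <cstdint>
#include <vector>

// Reference semantics of ONNX Expand: each output coordinate maps back to an
// input coordinate, reading coordinate 0 along broadcast (size-1) dimensions.
// Assumes in_dims has already been left-padded with 1s to out_dims.size().
std::vector<int64_t> ExpandReference(const std::vector<int64_t>& input,
                                     const std::vector<int64_t>& in_dims,
                                     const std::vector<int64_t>& out_dims) {
  int64_t out_size = 1;
  for (int64_t d : out_dims) out_size *= d;
  std::vector<int64_t> output(out_size);
  const size_t rank = out_dims.size();
  for (int64_t out_idx = 0; out_idx < out_size; ++out_idx) {
    int64_t remainder = out_idx, in_idx = 0, in_stride = 1;
    for (size_t i = rank; i-- > 0;) {  // innermost dimension first
      const int64_t coord = remainder % out_dims[i];
      remainder /= out_dims[i];
      in_idx += (in_dims[i] == 1 ? 0 : coord) * in_stride;
      in_stride *= in_dims[i];
    }
    output[out_idx] = input[in_idx];
  }
  return output;
}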
24 changes: 24 additions & 0 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -350,6 +350,18 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, bool, Equal);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, int32_t, Equal);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, int64_t, Equal);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, float, Expand);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, double, Expand);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, int8_t, Expand);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, int16_t, Expand);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, int32_t, Expand);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, int64_t, Expand);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, uint8_t, Expand);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, uint16_t, Expand);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, uint32_t, Expand);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, uint64_t, Expand);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, bool, Expand);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, MLFloat16, Expand);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, int32_t, Greater);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, int64_t, Greater);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, uint32_t, Greater);
@@ -693,6 +705,18 @@ static void RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, bool, Equal)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, int32_t, Equal)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, int64_t, Equal)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, float, Expand)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, double, Expand)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, int8_t, Expand)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, int16_t, Expand)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, int32_t, Expand)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, int64_t, Expand)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, uint8_t, Expand)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, uint16_t, Expand)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, uint32_t, Expand)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, uint64_t, Expand)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, bool, Expand)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 8, MLFloat16, Expand)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, int32_t, Greater)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, int64_t, Greater)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, uint32_t, Greater)>,
@@ -20,7 +20,7 @@ Status BinaryElementwise<ShouldNotBroadcast>::Prepare(OpKernelContext* context,
return Status::OK();
}

static Status ComputeOutputShape(const std::string& node_name, const TensorShape& lhs_shape, const TensorShape& rhs_shape, TensorShape& out_shape) {
Status ComputeOutputShape(const std::string& node_name, const TensorShape& lhs_shape, const TensorShape& rhs_shape, TensorShape& out_shape) {
size_t lhs_rank = lhs_shape.NumDimensions();
size_t rhs_rank = rhs_shape.NumDimensions();
size_t out_rank = std::max(lhs_rank, rhs_rank);
89 changes: 89 additions & 0 deletions onnxruntime/core/providers/cuda/tensor/expand.cc
@@ -0,0 +1,89 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "expand.h"
#include "expand_impl.h"
#include "core/providers/cpu/tensor/utils.h"

namespace onnxruntime {
namespace cuda {

#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
Expand, \
kOnnxDomain, \
8, \
T, \
kCudaExecutionProvider, \
KernelDefBuilder() \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
.InputMemoryType<OrtMemTypeCPUInput>(1), \
Expand<T>);
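// Note: .InputMemoryType<OrtMemTypeCPUInput>(1) pins input 1 (the target
// shape tensor) to CPU memory, since ComputeInternal reads it on the host.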

template <typename T>
Status Expand<T>::ComputeInternal(OpKernelContext* ctx) const {
const auto& input0 = *ctx->Input<Tensor>(0);
const auto& input1 = *ctx->Input<Tensor>(1);
int device_id = GetDeviceId();

// new shape to be expanded to
const auto* p_shape = input1.template Data<int64_t>();
std::vector<int64_t> output_dims{p_shape, p_shape + input1.Shape().Size()};
TensorShape output_shape(output_dims);

ORT_RETURN_IF_ERROR(ComputeOutputShape(Node().Name(), input0.Shape(), output_dims, output_shape));

// pad input_dims with 1 to make ranks match
auto rank = output_shape.NumDimensions();
auto& output_tensor = *ctx->Output(0, output_shape);
auto input_shape = input0.Shape().GetDims();
for (int i = 0; i < rank - input_shape.size(); i++) {
input_shape.insert(input_shape.begin(), 1);
}

// create fast_divmod using dimension values
CudaAsyncBuffer<fast_divmod> fdm_input_dims_gpu(this, device_id, rank);
CudaAsyncBuffer<fast_divmod> fdm_output_dims_gpu(this, device_id, rank);
{
auto in_span = fdm_input_dims_gpu.CpuSpan();
auto out_span = fdm_output_dims_gpu.CpuSpan();
for (auto i = 0; i < rank; i++) {
in_span[i] = fast_divmod(static_cast<int>(input_shape[i]));
out_span[i] = fast_divmod(static_cast<int>(output_shape[i]));
}
}

ORT_RETURN_IF_ERROR(fdm_input_dims_gpu.CopyToGpu());
ORT_RETURN_IF_ERROR(fdm_output_dims_gpu.CopyToGpu());

ExpandImpl(
output_tensor.Shape().NumDimensions(),
output_tensor.Shape().Size(),
input0.Shape().Size(),
reinterpret_cast<const typename ToCudaType<T>::MappedType*>(input0.template Data<T>()),
reinterpret_cast<typename ToCudaType<T>::MappedType*>(output_tensor.template MutableData<T>()),
fdm_input_dims_gpu.GpuPtr(),
fdm_output_dims_gpu.GpuPtr());

return Status::OK();
}

#define SPECIALIZED_COMPUTE(T) \
REGISTER_KERNEL_TYPED(T) \
template Status Expand<T>::ComputeInternal(OpKernelContext* ctx) const;

SPECIALIZED_COMPUTE(float)
SPECIALIZED_COMPUTE(double)
SPECIALIZED_COMPUTE(int8_t)
SPECIALIZED_COMPUTE(int16_t)
SPECIALIZED_COMPUTE(int32_t)
SPECIALIZED_COMPUTE(int64_t)
SPECIALIZED_COMPUTE(uint8_t)
SPECIALIZED_COMPUTE(uint16_t)
SPECIALIZED_COMPUTE(uint32_t)
SPECIALIZED_COMPUTE(uint64_t)
SPECIALIZED_COMPUTE(bool)
SPECIALIZED_COMPUTE(MLFloat16)

} // namespace cuda
}  // namespace onnxruntime
26 changes: 26 additions & 0 deletions onnxruntime/core/providers/cuda/tensor/expand.h
@@ -0,0 +1,26 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "core/providers/cuda/cuda_common.h"

namespace onnxruntime {
namespace cuda {

template <typename T>
class Expand final : public CudaKernel {
public:
Expand(const OpKernelInfo& info) : CudaKernel(info) {}

Status ComputeInternal(OpKernelContext* context) const override;
};

Status ComputeOutputShape(
const std::string& node_name,
const TensorShape& lhs_shape,
const TensorShape& rhs_shape,
TensorShape& out_shape);

} // namespace cuda
} // namespace onnxruntime
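The ComputeOutputShape declared here is the broadcast shape-inference helper from the binary elementwise provider code (its static qualifier is removed in the hunk above so Expand can reuse it). A sketch of the rule it implements, assuming standard numpy broadcasting as used elsewhere in ONNX Runtime (this helper function is illustrative, not the ORT implementation):

#include <algorithm>
#include <stdexcept>
#include <vector>

// Illustrative broadcast rule: right-align the two shapes; each aligned pair
// of dimensions must match, or one of them must be 1.
std::vector<int64_t> BroadcastShape(const std::vector<int64_t>& lhs,
                                    const std::vector<int64_t>& rhs) {
  const size_t out_rank = std::max(lhs.size(), rhs.size());
  std::vector<int64_t> out(out_rank);
  for (size_t i = 0; i < out_rank; ++i) {
    // Read from the back; missing leading dimensions are treated as 1.
    const int64_t l = i < lhs.size() ? lhs[lhs.size() - 1 - i] : 1;
    const int64_t r = i < rhs.size() ? rhs[rhs.size() - 1 - i] : 1;
    if (l != r && l != 1 && r != 1)
      throw std::invalid_argument("incompatible shapes for broadcast");
    out[out_rank - 1 - i] = std::max(l, r);
  }
  return out;
}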
80 changes: 80 additions & 0 deletions onnxruntime/core/providers/cuda/tensor/expand_impl.cu
@@ -0,0 +1,80 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/cuda/cu_inc/common.cuh"
#include "expand_impl.h"

namespace onnxruntime {
namespace cuda {

template <typename T>
__global__ void _ExpandKernel(
const size_t rank,
const size_t N,
const size_t N_input,
const T* input_data,
T* output_data,
const fast_divmod* fdm_input_dims,
const fast_divmod* fdm_output_dims) {
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
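// CALCULATE_ELEMENTWISE_INDEX_OR_EXIT derives this thread's flat element
// index `id` and returns early when id >= N (defined in cu_inc/common.cuh).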
// initialize
int64_t output_index = id;
int64_t input_index = 0;

// use striding when tensor is larger than grid
int stride = blockDim.x * gridDim.x;
auto outputSubDimSize = N;
auto inputSubDimSize = N_input;
auto out_coord = output_index;

// Translate the flat output index into per-dimension coordinates; along
// broadcast dimensions (input size 1) the coordinate is clamped to 0, so the
// single source value is reused.
while (output_index < N) {
for (int64_t i = 0; i < rank; i++) {
outputSubDimSize = fdm_output_dims[i].div(outputSubDimSize);
inputSubDimSize = fdm_input_dims[i].div(inputSubDimSize);
auto new_out_coord = out_coord / outputSubDimSize;
auto in_coord = (new_out_coord > (fdm_input_dims[i].d_ - 1)) ? fdm_input_dims[i].d_ - 1 : new_out_coord;
input_index += inputSubDimSize * in_coord;
out_coord -= new_out_coord * outputSubDimSize;
}
output_data[output_index] = input_data[input_index];
output_index += stride;
out_coord = output_index;
outputSubDimSize = N;
inputSubDimSize = N_input;
input_index = 0;
}
}

template <typename T>
void ExpandImpl(
const size_t rank,
const size_t N,
const size_t N_input,
const T* input_data,
T* output_data,
const fast_divmod* fdm_input_dims,
const fast_divmod* fdm_output_dims) {
int blocksPerGrid = (int)(ceil(static_cast<float>(N) / GridDim::maxThreadsPerBlock));
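// With ceil(N / maxThreadsPerBlock) blocks, one thread is launched per output
// element, so the stride loop in _ExpandKernel executes once per thread; the
// striding would only iterate if a smaller (capped) grid were ever used.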
_ExpandKernel<T><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
rank, N, N_input, input_data, output_data, fdm_input_dims, fdm_output_dims);
}

#define SPECIALIZED_IMPL(T) \
template void ExpandImpl<T>(const size_t rank, const size_t N, const size_t N_input, const T* input_data, T* output_data, const fast_divmod* fdm_input_dims, const fast_divmod* fdm_output_dims);

SPECIALIZED_IMPL(float)
SPECIALIZED_IMPL(double)
SPECIALIZED_IMPL(int8_t)
SPECIALIZED_IMPL(int16_t)
SPECIALIZED_IMPL(int32_t)
SPECIALIZED_IMPL(int64_t)
SPECIALIZED_IMPL(uint8_t)
SPECIALIZED_IMPL(uint16_t)
SPECIALIZED_IMPL(uint32_t)
SPECIALIZED_IMPL(uint64_t)
SPECIALIZED_IMPL(bool)
SPECIALIZED_IMPL(half)

} // namespace cuda
} // namespace onnxruntime
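A note on fast_divmod: hardware integer division is expensive on GPUs, so the divisors (the dimension sizes) are converted to precomputed multiplicative constants on the host and uploaded once. A generic sketch of the idea — not ORT's exact implementation; the struct and field names here are illustrative:

#include <cstdint>

// Division by a fixed divisor d via a precomputed magic multiplier:
// m = ceil(2^64 / d), then x / d == (x * m) >> 64 for 32-bit x and d >= 2.
// (GCC/Clang __uint128_t is used for the wide multiply; d == 1 is
// special-cased because its magic multiplier would not fit in 64 bits.)
struct DivByConst {
  uint32_t d;
  uint64_t m;
  explicit DivByConst(uint32_t divisor)
      : d(divisor),
        m(divisor <= 1
              ? 0
              : (uint64_t)((((__uint128_t)1 << 64) + divisor - 1) / divisor)) {}
  uint32_t div(uint32_t x) const {
    return d == 1 ? x : (uint32_t)(((__uint128_t)x * m) >> 64);
  }
  uint32_t mod(uint32_t x) const { return x - div(x) * d; }
};

The kernel's inner loop can then replace each division and modulo by a dimension size with a multiply and a shift.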
23 changes: 23 additions & 0 deletions onnxruntime/core/providers/cuda/tensor/expand_impl.h
@@ -0,0 +1,23 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include <stdint.h>
#include "core/providers/cuda/shared_inc/cuda_utils.h"
#include "core/framework/data_types.h"

namespace onnxruntime {
namespace cuda {

template <typename T>
void ExpandImpl(
const size_t shape_rank,
const size_t N,
const size_t N_input,
const T* input_data,
T* output_data,
const fast_divmod* fdm_input_dims,
const fast_divmod* fdm_output_dims);

} // namespace cuda
} // namespace onnxruntime
11 changes: 11 additions & 0 deletions onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc
@@ -970,6 +970,17 @@ TEST(MathOpTest, Expand_8_1x3_int64) {
test.Run();
}

TEST(MathOpTest, Expand_8_3x1x3x1_int64) {
OpTester test("Expand", 8);
test.AddInput<int64_t>("data_0", {1, 3, 1, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
test.AddInput<int64_t>("data_1", {4}, {3, 1, 3, 1});
test.AddOutput<int64_t>("result", {3, 3, 3, 3},
{1, 2, 3, 1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6, 4, 5, 6, 7, 8, 9, 7, 8, 9, 7, 8, 9,
1, 2, 3, 1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6, 4, 5, 6, 7, 8, 9, 7, 8, 9, 7, 8, 9,
1, 2, 3, 1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6, 4, 5, 6, 7, 8, 9, 7, 8, 9, 7, 8, 9,});
test.Run();
}

TEST(MathOpTest, Expand_8_3x3_float16) {
OpTester test("Expand", 8);
test.AddInput<MLFloat16>("data_0", {1}, {MLFloat16(math::floatToHalf(1.0f))});