Add CUDA Expand operator (microsoft#1292)
* Add CUDA expand operator
* Reset counter variables when striding
* Reset counter variables when striding
* use fast_divmod and other PR comments
* Fix merge variable rename
* Fix indentation per PR comment
* Remove maxpool_argmax
* Reduce number of type templates for Expand operator
* removed all types
* Commit updated cuda_execution_provider.cc
Showing 8 changed files with 237 additions and 2 deletions.
@@ -0,0 +1,76 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "expand.h"
#include "expand_impl.h"
#include "core/providers/cpu/tensor/utils.h"

namespace onnxruntime {
namespace cuda {

Status Expand::ComputeInternal(OpKernelContext* ctx) const {
  const auto& input0 = *ctx->Input<Tensor>(0);
  const auto& input1 = *ctx->Input<Tensor>(1);
  int device_id = GetDeviceId();

  // new shape to be expanded to
  const auto* p_shape = input1.template Data<int64_t>();
  std::vector<int64_t> output_dims{p_shape, p_shape + input1.Shape().Size()};
  TensorShape output_shape(output_dims);

  ORT_RETURN_IF_ERROR(ComputeOutputShape(Node().Name(), input0.Shape(), output_dims, output_shape));
  auto rank = output_shape.NumDimensions();
  auto& output_tensor = *ctx->Output(0, output_shape);
  auto input_shape = input0.Shape().GetDims();

  // pad input_dims with 1 to make ranks match
  for (int i = 0; i < rank - input_shape.size(); i++) {
    input_shape.insert(input_shape.begin(), 1);
  }

  // create fast_divmod using dimension values
  CudaAsyncBuffer<fast_divmod> fdm_input_dims(this, device_id, rank);
  CudaAsyncBuffer<fast_divmod> fdm_output_dims(this, device_id, rank);
  CudaAsyncBuffer<fast_divmod> fdm_output_subdim_size(this, device_id, rank);
  {
    auto in_span = fdm_input_dims.CpuSpan();
    auto out_span = fdm_output_dims.CpuSpan();
    auto sdm_span = fdm_output_subdim_size.CpuSpan();
    auto subdim_size = output_shape.Size();
    for (auto i = 0; i < rank; i++) {
      in_span[i] = fast_divmod(static_cast<int>(input_shape[i]));
      out_span[i] = fast_divmod(static_cast<int>(output_shape[i]));
      subdim_size /= output_shape[i];
      sdm_span[i] = static_cast<int>(subdim_size);
    }
  }
  ORT_RETURN_IF_ERROR(fdm_input_dims.CopyToGpu());
  ORT_RETURN_IF_ERROR(fdm_output_dims.CopyToGpu());
  ORT_RETURN_IF_ERROR(fdm_output_subdim_size.CopyToGpu());

  ExpandImpl(
      input0.DataType()->Size(),
      output_shape.NumDimensions(),
      output_shape.Size(),
      input0.Shape().Size(),
      input0.DataRaw(),
      output_tensor.MutableDataRaw(),
      fdm_input_dims.GpuPtr(),
      fdm_output_dims.GpuPtr(),
      fdm_output_subdim_size.GpuPtr());

  return Status::OK();
}

ONNX_OPERATOR_KERNEL_EX(
    Expand,
    kOnnxDomain,
    8,
    kCudaExecutionProvider,
    KernelDefBuilder()
        .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
        .InputMemoryType<OrtMemTypeCPUInput>(1),
    Expand);

}  // namespace cuda
};  // namespace onnxruntime
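
The kernel arguments above are built on the host: the input shape is left-padded with 1s to match the output rank, and three per-dimension arrays are precomputed as fast_divmod values — the padded input dims, the output dims, and the output "subdim sizes" (the row-major stride of each output dimension). A minimal standalone sketch of that bookkeeping, using plain integer division in place of fast_divmod and made-up shapes for illustration:

// Standalone sketch (plain C++, no onnxruntime types) of the per-dimension
// metadata that ComputeInternal precomputes before launching the kernel.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Example: expand an input of shape [3, 1] to an output of shape [2, 3, 4].
  std::vector<int64_t> input_dims{3, 1};
  std::vector<int64_t> output_dims{2, 3, 4};

  // Pad the input shape with leading 1s so both shapes have the same rank.
  while (input_dims.size() < output_dims.size())
    input_dims.insert(input_dims.begin(), 1);

  // subdim_size[i] is the number of output elements covered by one step of
  // dimension i, i.e. the row-major stride of that dimension in the output.
  int64_t total = 1;
  for (auto d : output_dims) total *= d;
  std::vector<int64_t> subdim_size(output_dims.size());
  int64_t remaining = total;
  for (size_t i = 0; i < output_dims.size(); ++i) {
    remaining /= output_dims[i];
    subdim_size[i] = remaining;
  }

  for (size_t i = 0; i < output_dims.size(); ++i)
    std::cout << "dim " << i << ": input=" << input_dims[i]
              << " output=" << output_dims[i]
              << " subdim_size=" << subdim_size[i] << "\n";
  // Prints: dim 0: input=1 output=2 subdim_size=12
  //         dim 1: input=3 output=3 subdim_size=4
  //         dim 2: input=1 output=4 subdim_size=1
}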
@@ -0,0 +1,25 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "core/providers/cuda/cuda_common.h"

namespace onnxruntime {
namespace cuda {

class Expand final : public CudaKernel {
 public:
  Expand(const OpKernelInfo& info) : CudaKernel(info) {}

  Status ComputeInternal(OpKernelContext* context) const override;
};

Status ComputeOutputShape(
    const std::string& node_name,
    const TensorShape& lhs_shape,
    const TensorShape& rhs_shape,
    TensorShape& out_shape);

}  // namespace cuda
}  // namespace onnxruntime
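
The header also declares ComputeOutputShape, which derives the expanded shape from the input shape and the requested shape. ONNX Expand follows numpy-style broadcasting: shapes are aligned from the right, each dimension pair must be equal or contain a 1, and the result takes the larger value. A hedged sketch of that rule in plain C++ (the function name and signature here are illustrative, not the actual onnxruntime helper):

// Illustrative broadcasting-rule sketch; throws if the shapes are incompatible.
#include <algorithm>
#include <cstdint>
#include <stdexcept>
#include <vector>

std::vector<int64_t> BroadcastShapes(const std::vector<int64_t>& lhs,
                                     const std::vector<int64_t>& rhs) {
  const size_t rank = std::max(lhs.size(), rhs.size());
  std::vector<int64_t> out(rank, 1);
  for (size_t i = 0; i < rank; ++i) {
    // Read dimensions right-to-left; missing leading dims count as 1.
    int64_t l = i < lhs.size() ? lhs[lhs.size() - 1 - i] : 1;
    int64_t r = i < rhs.size() ? rhs[rhs.size() - 1 - i] : 1;
    if (l != r && l != 1 && r != 1)
      throw std::invalid_argument("shapes are not broadcastable");
    out[rank - 1 - i] = std::max(l, r);
  }
  return out;
}

// BroadcastShapes({3, 1}, {2, 3, 4})   == {2, 3, 4}
// BroadcastShapes({2, 1, 6}, {1, 5, 6}) == {2, 5, 6}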
@@ -0,0 +1,96 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/cuda/cu_inc/common.cuh"
#include "expand_impl.h"
#include "core/providers/cuda/cuda_common.h"

namespace onnxruntime {
namespace cuda {

template <typename T>
__global__ void ExpandKernel(
    const size_t rank,
    const size_t N,
    const size_t N_input,
    const T* input_data,
    T* output_data,
    const fast_divmod* fdm_input_dims,
    const fast_divmod* fdm_output_dims,
    const fast_divmod* fdm_output_subdim_size) {
  CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);

  // initialize
  auto output_index = id;
  auto input_index = 0;
  auto input_subdim_size = N_input;
  auto out_coord = output_index;
  // use striding when tensor is larger than grid
  int stride = blockDim.x * gridDim.x;

  // translate indices to coordinates. copy expanded dims from source
  while (output_index < N) {
    for (int64_t i = 0; i < rank; i++) {
      input_subdim_size = fdm_input_dims[i].div(input_subdim_size);
      auto new_out_coord = fdm_output_subdim_size[i].div(out_coord);
      auto in_coord = (new_out_coord > (fdm_input_dims[i].d_ - 1)) ? fdm_input_dims[i].d_ - 1 : new_out_coord;
      input_index += input_subdim_size * in_coord;
      out_coord -= new_out_coord * fdm_output_subdim_size[i].d_;
    }
    output_data[output_index] = input_data[input_index];
    output_index += stride;
    out_coord = output_index;
    input_subdim_size = N_input;
    input_index = 0;
  }
}

Status ExpandImpl(
    const size_t element_size,
    const size_t rank,
    const size_t N,
    const size_t N_input,
    const void* input_data,
    void* output_data,
    const fast_divmod* fdm_input_dims,
    const fast_divmod* fdm_output_dims,
    const fast_divmod* fdm_output_subdim_size) {
  int blocksPerGrid = (int)(ceil(static_cast<float>(N) / GridDim::maxThreadsPerBlock));

  switch (element_size) {
    case sizeof(uint8_t):
      ExpandKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
          rank, N, N_input,
          reinterpret_cast<const ToCudaType<uint8_t>::MappedType*>(input_data),
          reinterpret_cast<ToCudaType<uint8_t>::MappedType*>(output_data),
          fdm_input_dims, fdm_output_dims, fdm_output_subdim_size);
      break;
    case sizeof(uint16_t):
      ExpandKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
          rank, N, N_input,
          reinterpret_cast<const ToCudaType<uint16_t>::MappedType*>(input_data),
          reinterpret_cast<ToCudaType<uint16_t>::MappedType*>(output_data),
          fdm_input_dims, fdm_output_dims, fdm_output_subdim_size);
      break;
    case sizeof(uint32_t):
      ExpandKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
          rank, N, N_input,
          reinterpret_cast<const ToCudaType<uint32_t>::MappedType*>(input_data),
          reinterpret_cast<ToCudaType<uint32_t>::MappedType*>(output_data),
          fdm_input_dims, fdm_output_dims, fdm_output_subdim_size);
      break;
    case sizeof(uint64_t):
      ExpandKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
          rank, N, N_input,
          reinterpret_cast<const ToCudaType<uint64_t>::MappedType*>(input_data),
          reinterpret_cast<ToCudaType<uint64_t>::MappedType*>(output_data),
          fdm_input_dims, fdm_output_dims, fdm_output_subdim_size);
      break;
    default:
      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for Expand operator");
  }
  return Status::OK();
}

}  // namespace cuda
}  // namespace onnxruntime
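
Each thread of ExpandKernel starts from an output linear index, walks the dimensions from outermost to innermost, splits the index into per-dimension coordinates via the precomputed output strides, and clamps each coordinate into the input's (possibly size-1) dimension to build the matching input index; a grid-stride loop handles outputs larger than the launch grid, with the per-element counters reset at the top of each pass (the "Reset counter variables when striding" fix in the commit message). A host-side reference of the same mapping, using plain division in place of fast_divmod (shapes and names are illustrative, not taken from the commit):

// CPU reference of the broadcast index translation performed by ExpandKernel.
#include <cstdint>
#include <vector>

void ExpandReference(const std::vector<float>& input,
                     const std::vector<int64_t>& input_dims,   // already padded to output rank
                     std::vector<float>& output,
                     const std::vector<int64_t>& output_dims) {
  const size_t rank = output_dims.size();
  // Row-major strides ("subdim sizes") for both tensors.
  std::vector<int64_t> in_stride(rank, 1), out_stride(rank, 1);
  for (size_t i = rank - 1; i-- > 0;) {
    in_stride[i] = in_stride[i + 1] * input_dims[i + 1];
    out_stride[i] = out_stride[i + 1] * output_dims[i + 1];
  }
  for (int64_t out_idx = 0; out_idx < static_cast<int64_t>(output.size()); ++out_idx) {
    int64_t remainder = out_idx, in_idx = 0;
    for (size_t i = 0; i < rank; ++i) {
      int64_t coord = remainder / out_stride[i];  // coordinate along dimension i
      remainder -= coord * out_stride[i];
      // Broadcast: a size-1 input dimension always contributes coordinate 0.
      int64_t in_coord = coord < input_dims[i] ? coord : input_dims[i] - 1;
      in_idx += in_coord * in_stride[i];
    }
    output[out_idx] = input[in_idx];
  }
}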
@@ -0,0 +1,25 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include <stdint.h>
#include "core/providers/cuda/shared_inc/cuda_utils.h"
#include "core/framework/data_types.h"
#include "core/common/common.h"

namespace onnxruntime {
namespace cuda {

Status ExpandImpl(
    const size_t element_size,
    const size_t shape_rank,
    const size_t N,
    const size_t N_input,
    const void* input_data,
    void* output_data,
    const fast_divmod* fdm_input_dims,
    const fast_divmod* fdm_output_dims,
    const fast_divmod* fdm_output_subdim_size);

}  // namespace cuda
}  // namespace onnxruntime
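
Note that ExpandImpl takes raw void* buffers plus an element_size rather than a typed pointer: because Expand only copies elements, one kernel instantiation per byte width (1, 2, 4, 8) covers every fixed-size tensor type of that width, which is what the "Reduce number of type templates" / "removed all types" commits refer to. A standalone CUDA sketch of that dispatch pattern (illustrative only, not onnxruntime code; the names CopyAsWords and LaunchCopy are made up):

// One template instantiation per byte width serves float, int32_t, uint32_t, etc.
#include <cstdint>
#include <cstdio>

template <typename UIntT>
__global__ void CopyAsWords(const UIntT* src, UIntT* dst, size_t n) {
  size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) dst[i] = src[i];  // pure byte movement; the element type is irrelevant
}

void LaunchCopy(const void* src, void* dst, size_t n, size_t element_size) {
  const int threads = 256;
  const int blocks = static_cast<int>((n + threads - 1) / threads);
  switch (element_size) {
    case sizeof(uint8_t):
      CopyAsWords<<<blocks, threads>>>(static_cast<const uint8_t*>(src), static_cast<uint8_t*>(dst), n);
      break;
    case sizeof(uint16_t):
      CopyAsWords<<<blocks, threads>>>(static_cast<const uint16_t*>(src), static_cast<uint16_t*>(dst), n);
      break;
    case sizeof(uint32_t):
      CopyAsWords<<<blocks, threads>>>(static_cast<const uint32_t*>(src), static_cast<uint32_t*>(dst), n);
      break;
    case sizeof(uint64_t):
      CopyAsWords<<<blocks, threads>>>(static_cast<const uint64_t*>(src), static_cast<uint64_t*>(dst), n);
      break;
    default:
      std::fprintf(stderr, "unsupported element size %zu\n", element_size);
  }
}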