Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support 'Bilinear' mode for 2D inputs in Resize and Upsample kernels #1679

Merged
merged 7 commits into from
Aug 29, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 45 additions & 24 deletions onnxruntime/core/providers/cpu/tensor/upsample.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include "core/providers/cpu/tensor/upsample.h"
#include <cmath>
#include <sstream>

using namespace onnxruntime::common;
using namespace std;
Expand Down Expand Up @@ -61,14 +62,18 @@ Status UpsampleNearest(const T* input,
T* output,
const TensorShape& input_shape,
const TensorShape& output_shape,
const vector<float>& scales) {
const vector<float>& scales,
bool is_resize) {
if (!input || !output)
return Status(ONNXRUNTIME, FAIL, "Upsample: input/output value is nullptr");
return Status(ONNXRUNTIME, FAIL, is_resize ? "Resize: input/output value is nullptr" :
"Upsample: input/output value is nullptr");
if (input_shape.NumDimensions() != output_shape.NumDimensions())
return Status(ONNXRUNTIME, FAIL, "Upsample: input/output value's dimension mismatch");
return Status(ONNXRUNTIME, FAIL, is_resize ? "Resize: input/output value's dimension mismatch" :
"Upsample: input/output value's dimension mismatch");
if (input_shape.NumDimensions() == 0) {
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
"Upsample: input shape needs to be at least a single dimension.");
is_resize ? "Resize: input shape needs to be at least a single dimension" :
"Upsample: input shape needs to be at least a single dimension.");
}

int64_t n_dim = static_cast<int64_t>(input_shape.NumDimensions());
Expand Down Expand Up @@ -192,11 +197,14 @@ Status upsampleLiner(const T* input,
T* output,
const TensorShape& input_shape,
const TensorShape& output_shape,
const vector<float>& scales) {
const vector<float>& scales,
bool is_resize) {
if (!input || !output)
return Status(ONNXRUNTIME, FAIL, "Upsample: input/output value is nullptr");
return Status(ONNXRUNTIME, FAIL, is_resize ? "Resize: input / output value is nullptr" :
"Upsample: input / output value is nullptr");
if (input_shape.NumDimensions() != output_shape.NumDimensions())
return Status(ONNXRUNTIME, FAIL, "Upsample: input/output value's dimension mismatch");
return Status(ONNXRUNTIME, FAIL, is_resize ? "Resize: input/output value's dimension mismatch" :
"Upsample: input/output value's dimension mismatch");
auto n_dim = input_shape.NumDimensions();
for (size_t i = 0, size = output_shape.Size(); i < size; i++) {
std::vector<int64_t> val1;
Expand Down Expand Up @@ -242,6 +250,11 @@ Status upsampleLiner(const T* input,
return Status::OK();
}

// The following method supports a 4-D input in 'Linear' mode
// that amounts to 'Bilinear' Upsampling/Resizing in the sense that it assumes
// the scale values for the outermost 2 dimensions are 1.
// This is the common use-case where the 4-D input (batched multi-channel images)
// is usually of shape [N, C, H, W] and the scales are [1.0, 1.0, height_scale, width_scale]
template <typename T>
void upsampleBilinear(
int64_t batch_size,
Expand Down Expand Up @@ -327,9 +340,10 @@ Status Upsample<T>::BaseCompute(OpKernelContext* context, const std::vector<floa
ORT_ENFORCE(X != nullptr);

const std::vector<int64_t>& dims = X->Shape().GetDims();
if (dims.size() != scales.size()) {
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Upsample: input tensor's dimension does not match the scales.");
}
if (dims.size() != scales.size())
return Status(ONNXRUNTIME, INVALID_ARGUMENT,
is_resize ? "Resize: input tensor's dimension does not match the scales." :
"Upsample: input tensor's dimension does not match the scales.");

bool no_scale = true;
std::vector<int64_t> Y_dims;
Expand All @@ -348,26 +362,33 @@ Status Upsample<T>::BaseCompute(OpKernelContext* context, const std::vector<floa

switch (mode_) {
case UpsampleMode::NN:
return UpsampleNearest<T>(X->template Data<T>(), Y->template MutableData<T>(), X->Shape(), Y->Shape(), scales);
return UpsampleNearest<T>(X->template Data<T>(), Y->template MutableData<T>(), X->Shape(), Y->Shape(), scales, is_resize);
case UpsampleMode::LINEAR: {
//What's the correct behavior of linear mode is not clear right now,
//Only support bilinear with 4D tensor to keep consistent with previous behavior
if (dims.size() != 4)
return Status(ONNXRUNTIME, FAIL, "Upsample: linear mode upsample only support 4-D tensor with NCHW layout");
//The correct behavior of 'linear' mode for an N-D input is not clear right now,
//so only support 'bilinear' with 2-D or 4-D input tensor with outermost 2 scales as 1 in the 4-D case
if (dims.size() != 2 && dims.size() != 4) {
std::ostringstream oss;
oss << "'Linear' mode only support 2-D inputs ('Bilinear') or 4-D inputs "
"with the corresponding outermost 2 scale values being 1 in the ";
oss << (is_resize ? "Resize operator" : "Upsample operator");
return Status(ONNXRUNTIME, FAIL, oss.str());
}

const int64_t batch_size = dims[0];
const int64_t num_channels = dims[1];
const int64_t input_height = dims[2];
const int64_t input_width = dims[3];
bool is_2D = dims.size() == 2;
const int64_t batch_size = is_2D ? 1 : dims[0];
const int64_t num_channels = is_2D ? 1 : dims[1];
const int64_t input_height = is_2D ? dims[0] : dims[2];
const int64_t input_width = is_2D ? dims[1] : dims[3];

AllocatorPtr alloc;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc));
upsampleBilinear(batch_size, num_channels, input_height, input_width,
scales[2], scales[3], X->template Data<T>(), Y->template MutableData<T>(), alloc);
is_2D ? scales[0] : scales[2], is_2D ? scales[1] : scales[3],
X->template Data<T>(), Y->template MutableData<T>(), alloc);
return Status::OK();
}
default:
return Status(ONNXRUNTIME, FAIL, "Upsample: unexpected mode");
return Status(ONNXRUNTIME, FAIL, is_resize ? "Resize: unexpected mode" : "Upsample: unexpected mode");
}
}

Expand All @@ -380,9 +401,9 @@ Status Upsample<T>::Compute(OpKernelContext* context) const {
const auto* scales = context->Input<Tensor>(1);
ORT_ENFORCE(scales != nullptr);
int64_t scales_size = scales->Shape().Size();
std::vector<float> scales_arrary(scales_size);
ParseScalesData(scales, scales_arrary);
return BaseCompute(context, scales_arrary);
std::vector<float> scales_array(scales_size);
ParseScalesData(scales, scales_array);
return BaseCompute(context, scales_array);
}

} // namespace onnxruntime
7 changes: 4 additions & 3 deletions onnxruntime/core/providers/cpu/tensor/upsample.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,10 @@ class UpsampleBase {
}

if (UpsampleMode::LINEAR == mode) {
ORT_ENFORCE(scales.size() == 4, "Upsample: linear mode upsample only support bilinear with 4 dimension.");
ORT_ENFORCE(((scales[0] == 1) && (scales[1] == 1)),
"Upsample: linear mode upsample only support bilinear, the first 2 scales should be 1.");
ORT_ENFORCE(scales.size() == 2 || (scales.size() == 4 && scales[0] == 1 && scales[1] == 1),
"'Linear' mode only support 2-D inputs ('Bilinear') or 4-D inputs "
"with the corresponding outermost 2 scale values being 1 in the ",
is_resize ? "Resize operator" : "Upsample operator");
}
}

Expand Down
71 changes: 68 additions & 3 deletions onnxruntime/core/providers/cuda/tensor/resize_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,13 @@ __global__ void _ResizeNearestKernel(const size_t rank,
output_data[id] = input_data[input_index];
}

// The following method supports a 4-D input in 'Linear' mode
// that amounts to 'Bilinear' Upsampling/Resizing in the sense that it assumes
// the scale values for the outermost 2 dimensions are 1.
// This is the common use-case where the 4-D input (batched multi-channel images)
// is usually of shape [N, C, H, W] and the scales are [1.0, 1.0, height_scale, width_scale]
template <typename T>
__global__ void _ResizeBilinearKernel(const int64_t input_dim2,
__global__ void _ResizeBilinear4DInputKernel(const int64_t input_dim2,
const int64_t* input_pitches,
const fast_divmod* output_div_pitches,
const float* scales,
Expand Down Expand Up @@ -90,6 +95,62 @@ __global__ void _ResizeBilinearKernel(const int64_t input_dim2,
x11 * static_cast<T>(y_offset_0 * x_offset_0);
}

// The following method supports a 2-D input in 'Linear' mode.
// One thread computes one output element: the flat output index is decomposed
// into a (dim0, dim1) coordinate, mapped back to input space by dividing by the
// per-dimension scales, and the result is a bilinear blend of the 4 nearest
// input neighbors (clamped at the input's lower-right edges).
template <typename T>
__global__ void _ResizeBilinear2DInputKernel(const int64_t input_dim0,
const int64_t* input_pitches,
const fast_divmod* output_div_pitches,
const float* scales,
const T* input_data,
T* output_data,
const size_t N) {
// 'id' is this thread's flat output index; threads beyond N exit early.
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
CUDA_LONG input_index = 0;

// Decompose the flat output index into 2-D output coordinates:
// id = index_of_dim0 * output_pitch0 + index_of_dim1.
int mod;
int index_of_dim0, index_of_dim1;
output_div_pitches[0].divmod(id, index_of_dim0, mod);
index_of_dim1 = mod;
int index_of_input_dim0, index_of_input_dim1;
float x_offset_0, y_offset_0, x_offset_1, y_offset_1;
// Map output coordinates to input coordinates by inverting the scale
// (truncation acts as floor for the non-negative indices involved here).
index_of_input_dim0 = static_cast<int64_t>(index_of_dim0 / scales[0]);
index_of_input_dim1 = static_cast<int64_t>(index_of_dim1 / scales[1]);
input_index = index_of_input_dim0 * input_pitches[0] + index_of_input_dim1;

// Gather the 2x2 neighborhood: x00 is the base sample, x01 is one row down
// (dim0 + 1), x10 is one column right (dim1 + 1), x11 is the diagonal.
T x00 = input_data[input_index];
T x10, x01, x11;

bool end_of_dim0 = false, end_of_dim1 = false;
if (index_of_input_dim0 == (input_dim0 - 1)) {
// It's the end in dimension 0 - clamp by reusing the base sample.
x01 = x00;
end_of_dim0 = true;
} else {
x01 = input_data[input_index + input_pitches[0]];
}

// NOTE(review): for a contiguous 2-D tensor, input_pitches[0] (the stride of
// dim 0) equals the extent of dim 1, so this comparison detects the last
// column — confirm this pitch/extent equivalence holds for all callers.
if (index_of_input_dim1 == (input_pitches[0] - 1)) {
// It's the end in dimension 1 - clamp both right-hand samples.
x10 = x00;
x11 = x01;
end_of_dim1 = true;
} else {
x10 = input_data[input_index + 1];
x11 = end_of_dim0 ? x10 : input_data[input_index + input_pitches[0] + 1];
}

// Interpolation weights are the fractional parts of the input coordinates.
// At a clamped edge a fixed 0.5 offset is used instead — presumably to
// mirror the 4-D kernel's edge behavior; verify against that kernel.
y_offset_0 = end_of_dim0 ? 0.5f : index_of_dim0 / scales[0] - index_of_input_dim0;
y_offset_1 = 1.0f - y_offset_0;
x_offset_0 = end_of_dim1 ? 0.5f : index_of_dim1 / scales[1] - index_of_input_dim1;
x_offset_1 = 1.0f - x_offset_0;

// Standard bilinear blend of the 4 neighbors with the weights above.
output_data[id] =
x00 * static_cast<T>(y_offset_1 * x_offset_1) +
x01 * static_cast<T>(y_offset_0 * x_offset_1) +
x10 * static_cast<T>(y_offset_1 * x_offset_0) +
x11 * static_cast<T>(y_offset_0 * x_offset_0);
}

template <typename T>
void ResizeImpl(const onnxruntime::UpsampleMode upsample_mode,
const size_t rank,
Expand All @@ -105,8 +166,12 @@ void ResizeImpl(const onnxruntime::UpsampleMode upsample_mode,
_ResizeNearestKernel<T><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
rank, input_pitches, output_div_pitches, scales_vals,
input_data, output_data, N);
} else if (onnxruntime::UpsampleMode::LINEAR == upsample_mode) {
_ResizeBilinearKernel<T><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
} else if (onnxruntime::UpsampleMode::LINEAR == upsample_mode && rank == 4) {
_ResizeBilinear4DInputKernel<T><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
input_dim2, input_pitches, output_div_pitches, scales_vals,
input_data, output_data, N);
} else if (onnxruntime::UpsampleMode::LINEAR == upsample_mode && rank == 2) {
_ResizeBilinear2DInputKernel<T><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
input_dim2, input_pitches, output_div_pitches, scales_vals,
input_data, output_data, N);
}
Expand Down
28 changes: 15 additions & 13 deletions onnxruntime/core/providers/cuda/tensor/upsample.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,21 @@ Status Upsample<T>::BaseCompute(OpKernelContext* context, const std::vector<floa
const std::vector<int64_t>& X_dims = X->Shape().GetDims();
auto rank = X_dims.size();
if (rank == 0)
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Upsample: input tensor cannot be scalar.");
return Status(ONNXRUNTIME, INVALID_ARGUMENT,
is_resize ? "Resize: input tensor cannot be scalar." : "Upsample: input tensor cannot be scalar.");

if (rank != scales.size())
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Upsample: input tensor's dimension does not match the scales.");
return Status(ONNXRUNTIME, INVALID_ARGUMENT,
is_resize ? "Resize: input tensor's dimension does not match the scales." :
"Upsample: input tensor's dimension does not match the scales.");

if (UpsampleMode::LINEAR == mode_ && rank != 4 && rank != 2) {
std::ostringstream oss;
oss << "'Linear' mode only support 2-D inputs ('Bilinear') or 4-D inputs "
"with the corresponding outermost 2 scale values being 1 in the ";
oss << (is_resize ? "Resize operator" : "Upsample operator");
return Status(ONNXRUNTIME, FAIL, oss.str());
}

std::vector<int64_t> Y_dims;
for (std::size_t i = 0; i < rank; i++) {
Expand Down Expand Up @@ -69,21 +80,12 @@ Status Upsample<T>::BaseCompute(OpKernelContext* context, const std::vector<floa

size_t output_count = Y->Shape().Size();

if (UpsampleMode::LINEAR == mode_) {
if (rank != 4)
if (is_resize) {
return Status(ONNXRUNTIME, FAIL, "Resize: linear mode only supports 4-D tensor with NCHW layout");
} else {
return Status(ONNXRUNTIME, FAIL, "Upsample: linear mode only supports 4-D tensor with NCHW layout");
}
}

if (is_resize) {
CudaAsyncBuffer<float> scales_vals(this, device_id, scales);
scales_vals.CopyToGpu();
ResizeImpl(mode_,
rank,
(UpsampleMode::LINEAR == mode_) ? X_dims[2] : 0,
(UpsampleMode::LINEAR == mode_) ? (rank == 2 ? X_dims[0] : X_dims[2]) : 0,
input_strides.GpuPtr(),
output_div_pitches.GpuPtr(),
scales_vals.GpuPtr(),
Expand All @@ -101,7 +103,7 @@ Status Upsample<T>::BaseCompute(OpKernelContext* context, const std::vector<floa

UpampleImpl(mode_,
rank,
(UpsampleMode::LINEAR == mode_) ? X_dims[2] : 0,
(UpsampleMode::LINEAR == mode_) ? (rank == 2 ? X_dims[0] : X_dims[2]) : 0,
input_strides.GpuPtr(),
output_div_pitches.GpuPtr(),
scales_div.GpuPtr(),
Expand Down
68 changes: 65 additions & 3 deletions onnxruntime/core/providers/cuda/tensor/upsample_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,13 @@ __global__ void _UpampleNearestKernel(const size_t rank,
output_data[id] = input_data[input_index];
}

// The following method supports a 4-D input in 'Linear' mode
// that amounts to 'Bilinear' Upsampling/Resizing in the sense that it assumes
// the scale values for the outermost 2 dimensions are 1.
// This is the common use-case where the 4-D input (batched multi-channel images)
// is usually of shape [N, C, H, W] and the scales are [1.0, 1.0, height_scale, width_scale]
template <typename T>
__global__ void _UpampleBilinearKernel(const int64_t input_dim2,
__global__ void _UpampleBilinear4DInputKernel(const int64_t input_dim2,
const int64_t* input_pitches,
const fast_divmod* output_div_pitches,
const fast_divmod* scales_div,
Expand Down Expand Up @@ -90,6 +95,59 @@ __global__ void _UpampleBilinearKernel(const int64_t input_dim2,
output_data[id] = y0 + static_cast<T>(x_offset_T * (y1 - y0) / scales_div3_T);
}

// The following method supports a 2-D input in 'Linear' mode.
// One thread computes one output element. Unlike the Resize kernel, the
// scales here are provided as fast_divmod integer divisors ('scales_div'),
// so the input coordinate and the interpolation offset are obtained together
// as quotient and remainder of an integer division.
template <typename T>
__global__ void _UpampleBilinear2DInputKernel(const int64_t input_dim0,
const int64_t* input_pitches,
const fast_divmod* output_div_pitches,
const fast_divmod* scales_div,
const T* input_data,
T* output_data,
const size_t N) {
// 'id' is this thread's flat output index; threads beyond N exit early.
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
CUDA_LONG input_index = 0;

// Decompose the flat output index into 2-D output coordinates.
int mod;
int index_of_dim0, index_of_dim1;
output_div_pitches[0].divmod(id, index_of_dim0, mod);
index_of_dim1 = mod;
// quotient -> input coordinate, remainder -> sub-pixel offset in [0, scale).
int index_of_input_dim0, index_of_input_dim1, x_offset, y_offset;
scales_div[0].divmod(index_of_dim0, index_of_input_dim0, y_offset);
scales_div[1].divmod(index_of_dim1, index_of_input_dim1, x_offset);

input_index = index_of_input_dim0 * input_pitches[0] + index_of_input_dim1;

// Gather the 2x2 neighborhood: x00 is the base sample, x01 is one row down
// (dim0 + 1), x10 is one column right (dim1 + 1), x11 is the diagonal.
T x00 = input_data[input_index];
T x10, x01, x11;

bool end_of_dim0 = false;
if (index_of_input_dim0 == (input_dim0 - 1)) {
// It's the end in dimension 0 - clamp by reusing the base sample.
x01 = x00;
end_of_dim0 = true;
}
 else {
x01 = input_data[input_index + input_pitches[0]];
}

// NOTE(review): for a contiguous 2-D tensor, input_pitches[0] (the stride of
// dim 0) equals the extent of dim 1, so this comparison detects the last
// column — confirm this pitch/extent equivalence holds for all callers.
if (index_of_input_dim1 == (input_pitches[0] - 1)) {
// It's the end in dimension 1 - clamp both right-hand samples.
x10 = x00;
x11 = x01;
} else {
x10 = input_data[input_index + 1];
x11 = end_of_dim0 ? x10 : input_data[input_index + input_pitches[0] + 1];
}

// Bilinear blend: first interpolate along dim 0 (y0, y1), dividing the
// integer remainder by the divisor d_ to recover the fractional weight,
// then interpolate the two partial results along dim 1.
T y_offset_T = static_cast<T>(y_offset);
T x_offset_T = static_cast<T>(x_offset);
T scales_div0_T = static_cast<T>(scales_div[0].d_);
T scales_div1_T = static_cast<T>(scales_div[1].d_);
T y0 = x00 + static_cast<T>(y_offset_T * (x01 - x00) / scales_div0_T);
T y1 = x10 + static_cast<T>(y_offset_T * (x11 - x10) / scales_div0_T);

output_data[id] = y0 + static_cast<T>(x_offset_T * (y1 - y0) / scales_div1_T);
}

template <typename T>
void UpampleImpl(const onnxruntime::UpsampleMode upsample_mode,
const size_t rank,
Expand All @@ -105,8 +163,12 @@ void UpampleImpl(const onnxruntime::UpsampleMode upsample_mode,
_UpampleNearestKernel<T><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
rank, input_pitches, output_div_pitches, scales_div,
input_data, output_data, N);
} else if (onnxruntime::UpsampleMode::LINEAR == upsample_mode) {
_UpampleBilinearKernel<T><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
} else if (onnxruntime::UpsampleMode::LINEAR == upsample_mode && rank == 4) {
_UpampleBilinear4DInputKernel<T><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
input_dim2, input_pitches, output_div_pitches, scales_div,
input_data, output_data, N);
} else if (onnxruntime::UpsampleMode::LINEAR == upsample_mode && rank == 2) {
_UpampleBilinear2DInputKernel<T><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
input_dim2, input_pitches, output_div_pitches, scales_div,
input_data, output_data, N);
}
Expand Down
Loading