Skip to content

Commit

Permalink
Added Tests for 4D and 5D images
Browse files Browse the repository at this point in the history
4d col2im works, 5d and higher doesn't
  • Loading branch information
Thiago Crepaldi committed Aug 3, 2022
1 parent 72e056c commit 346dea5
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 17 deletions.
50 changes: 36 additions & 14 deletions onnxruntime/core/providers/cpu/tensor/col2im.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,20 +66,42 @@ Status Col2Im<T>::Compute(OpKernelContext* context) const {

std::cout << "\n\tStatus Col2Im<T>::Compute() --> math::Col2imNd<>()" << std::endl;

math::Col2imNd<T, CPUMathUtil, StorageOrder::NCHW>(
col_input_data, // const T* data_col,
image_shape->Data<int64_t>(), // const int64_t* img_shape,
Yshape.Slice(2).GetDims().data(), // const int64_t* output_shape,
col_input_C, // int64_t channels_col, --> output_num_channels * kernel_shape_size
image_shape_size, // int64_t img_size,
kernel_shape->Data<int64_t>(), // const int64_t* kernel_shape,
col2im_attrs_.strides.data(), // const int64_t* stride,
col2im_attrs_.dilations.data(), // const int64_t* dilation,
col2im_attrs_.pads.data(), // const int64_t* pad,
kernel_shape->Shape().Size(), // ptrdiff_t N, --> number of spatial dims for image
Ydata, // T* data_img,
&CPUMathUtil::Instance() // Provider* provider
);
if (image_shape->Shape()[0] == 2) {
std::cout << "image_shape->Shape()[0] == 2 --> Col2Im" << std::endl;
math::Col2im<float, CPUMathUtil, StorageOrder::NCHW>(
col_input_data,
col_input_C,
image_shape->Data<int64_t>()[0],
image_shape->Data<int64_t>()[1],
kernel_shape->Data<int64_t>()[0],
kernel_shape->Data<int64_t>()[1],
col2im_attrs_.dilations[0],
col2im_attrs_.dilations[1],
col2im_attrs_.pads[0],
col2im_attrs_.pads[1],
col2im_attrs_.pads[2],
col2im_attrs_.pads[3],
col2im_attrs_.strides[0],
col2im_attrs_.strides[1],
Ydata,
&CPUMathUtil::Instance());
} else {
std::cout << "image_shape->Shape()[0] != 2 --> Col2ImNd (nd=" << image_shape->Shape()[0] << ") " << std::endl;
math::Col2imNd<T, CPUMathUtil, StorageOrder::NCHW>(
col_input_data, // const T* data_col,
image_shape->Data<int64_t>(), // const int64_t* img_shape,
Yshape.Slice(2).GetDims().data(), // const int64_t* output_shape,
col_input_C, // int64_t channels_col, --> output_num_channels * kernel_shape_size
image_shape_size, // int64_t img_size,
kernel_shape->Data<int64_t>(), // const int64_t* kernel_shape,
col2im_attrs_.strides.data(), // const int64_t* stride,
col2im_attrs_.dilations.data(), // const int64_t* dilation,
col2im_attrs_.pads.data(), // const int64_t* pad,
kernel_shape->Shape().Size(), // ptrdiff_t N, --> number of spatial dims for image
Ydata, // T* data_img,
&CPUMathUtil::Instance() // Provider* provider
);
}
std::cout << "\n\n Return Col2Im<T>::Compute() --> "; for (auto i=0; i < Yshape.Size(); ++i) std::cout << Ydata[i] << ", "; std::cout << ") with shape " << Yshape << std::endl << std::endl;

return Status::OK();
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/util/math_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ void Im2col<T, StorageOrder::NCHW>::operator()(
std::cout << ",\n\tconst int64_t* kernel_shape={"; for (auto i=0; i < rank; ++i) std::cout << kernel_shape[i] << ", "; std::cout << "}";
std::cout << ",\n\tconst int64_t* stride={"; for (auto i=0; i < rank; ++i) std::cout << stride[i] << ", "; std::cout << "}";
std::cout << ",\n\tconst int64_t* dilation={"; for (auto i=0; i < rank; ++i) std::cout << dilation[i] << ", "; std::cout << "}";
std::cout << ",\n\tconst int64_t* pad={"; for (auto i=0; i < rank; ++i) std::cout << pad[i] << ", "; std::cout << "}";
std::cout << ",\n\tconst int64_t* pad={"; for (auto i=0; i < 2*rank; ++i) std::cout << pad[i] << ", "; std::cout << "}";
std::cout << ",\n\tptrdiff_t rank=" << rank;
std::cout << ",\n\tT* data_col= preallocated pointer to write at {"; for (auto i=0; i < output_shape_size; ++i) std::cout << data_col[i] << ", "; std::cout << "}";
std::cout << ",\n\tbool accumulate_output=" << accumulate_output;
Expand Down
19 changes: 17 additions & 2 deletions onnxruntime/test/contrib_ops/col2im_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,36 @@
namespace onnxruntime {
namespace test {

TEST(Col2ImContribOpTest, simple) {
TEST(Col2ImContribOpTest, simple4dNCHW) {
OpTester test("Col2Im", 1, kMSDomain);

test.AddAttribute("strides", std::vector<int64_t>{1, 1});
test.AddAttribute("dilations", std::vector<int64_t>{1, 1});
test.AddAttribute("pads", std::vector<int64_t>{0, 0, 0, 0});

test.AddInput<float>("input", {1, 5, 5}, std::vector<float>{1.f, 6.f, 11.f, 16.f, 21.f, 2.f, 7.f, 12.f, 17.f, 22.f, 3.f, 8.f, 13.f, 18.f, 23.f, 4.f, 9.f, 14.f, 19.f, 24.f, 5.f, 0.f, 15.f, 20.f, 25.f});
test.AddInput<float>("input", {1, 5, 5}, std::vector<float>{1.f, 6.f, 11.f, 16.f, 21.f, 2.f, 7.f, 12.f, 17.f, 22.f, 3.f, 8.f, 13.f, 18.f, 23.f, 4.f, 9.f, 14.f, 19.f, 24.f, 5.f, 10.f, 15.f, 20.f, 25.f});
test.AddInput<int64_t>("image_shape", {2}, std::vector<int64_t>{5, 5});
test.AddInput<int64_t>("block_shape", {2}, std::vector<int64_t>{1, 5});

test.AddOutput<float>("output", {1, 1, 5, 5}, std::vector<float>{1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 25.f});
test.Run();
}

TEST(Col2ImContribOpTest, simple5dNCHWD) {
OpTester test("Col2Im", 1, kMSDomain);

test.AddAttribute("strides", std::vector<int64_t>{1, 1, 1});
test.AddAttribute("dilations", std::vector<int64_t>{1, 1, 1});
test.AddAttribute("pads", std::vector<int64_t>{0, 0, 0, 0, 0, 0});

test.AddInput<float>("input", {1, 5, 5}, std::vector<float>{1.f, 6.f, 11.f, 16.f, 21.f, 2.f, 7.f, 12.f, 17.f, 22.f, 3.f, 8.f, 13.f, 18.f, 23.f, 4.f, 9.f, 14.f, 19.f, 24.f, 5.f, 10.f, 15.f, 20.f, 25.f});
test.AddInput<int64_t>("image_shape", {3}, std::vector<int64_t>{1, 5, 5});
test.AddInput<int64_t>("block_shape", {3}, std::vector<int64_t>{1, 1, 5});

test.AddOutput<float>("output", {1, 1, 1, 5, 5}, std::vector<float>{1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 25.f});
test.Run();
}


} // namespace test
} // namespace onnxruntime

0 comments on commit 346dea5

Please sign in to comment.