Skip to content

Commit

Permalink
Add onnx support for com.microsoft.NhwcConv (#3796)
Browse files Browse the repository at this point in the history
  • Loading branch information
mirza-halilcevic authored Feb 10, 2025
1 parent a75a891 commit ea9d54a
Show file tree
Hide file tree
Showing 7 changed files with 211 additions and 3 deletions.
25 changes: 24 additions & 1 deletion src/onnx/conv.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand All @@ -22,6 +22,9 @@
* THE SOFTWARE.
*/
#include <migraphx/onnx/conv.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/permutation.hpp>
#include <algorithm>

namespace migraphx {
Expand All @@ -47,6 +50,26 @@ void recalc_conv_attributes(value& v, size_t kdims)
}
}

static instruction_ref
apply_nhwc_perm(const onnx_parser::node_info& info, instruction_ref ins, bool invert)
{
std::vector<int64_t> perm(ins->get_shape().ndim());
std::iota(begin(perm) + 1, end(perm) - 1, 2);
perm.back() = 1;
return info.add_instruction(
make_op("transpose", {{"permutation", invert ? invert_permutation(perm) : perm}}), ins);
}

instruction_ref from_nhwc(const onnx_parser::node_info& info, instruction_ref ins)
{
return apply_nhwc_perm(info, ins, true);
}

instruction_ref to_nhwc(const onnx_parser::node_info& info, instruction_ref ins)
{
return apply_nhwc_perm(info, ins, false);
}

} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
7 changes: 6 additions & 1 deletion src/onnx/include/migraphx/onnx/conv.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand All @@ -26,13 +26,18 @@

#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/onnx/onnx_parser.hpp>
#include <migraphx/instruction_ref.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {

void recalc_conv_attributes(value& v, size_t kdims);

instruction_ref from_nhwc(const onnx_parser::node_info& info, instruction_ref ins);
instruction_ref to_nhwc(const onnx_parser::node_info& info, instruction_ref ins);

} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Expand Down
16 changes: 15 additions & 1 deletion src/onnx/parse_convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ struct parse_convolution : op_parser<parse_convolution>
{
std::vector<op_desc> operators() const
{
return {{"Conv", "convolution"}, {"ConvInteger", "quant_convolution"}};
return {{"Conv", "convolution"},
{"ConvInteger", "quant_convolution"},
{"NhwcConv", "convolution"}};
}

// Convert to half prior to a shift to ensure we preserve accuracy here then
Expand Down Expand Up @@ -240,6 +242,13 @@ struct parse_convolution : op_parser<parse_convolution>
auto values = op.to_value();
auto x = args[0];
auto weights = args[1];

if(opd.onnx_name == "NhwcConv")
{
x = from_nhwc(info, x);
weights = from_nhwc(info, weights);
}

auto x_shape = x->get_shape();
auto w_shape = weights->get_shape();
auto in_lens = x_shape.max_lens();
Expand Down Expand Up @@ -362,6 +371,11 @@ struct parse_convolution : op_parser<parse_convolution>
ret = info.add_bias(args, ret, 1);
}

if(opd.onnx_name == "NhwcConv")
{
ret = to_nhwc(info, ret);
}

return ret;
}
};
Expand Down
11 changes: 11 additions & 0 deletions test/onnx/gen_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -8802,6 +8802,17 @@ def neg_dynamic_test():
return ([node], [x], [y])


@onnx_test()
def nhwcconv_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 7, 7, 1])
w = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 1, 1, 1])
out = helper.make_tensor_value_info('2', TensorProto.FLOAT, [1, 7, 7, 1])

node = onnx.helper.make_node('NhwcConv', inputs=['0', '1'], outputs=['2'])

return ([node], [x, w], [out])


@onnx_test()
def nms_test():
b = helper.make_tensor_value_info('boxes', TensorProto.FLOAT, [1, 6, 4])
Expand Down
23 changes: 23 additions & 0 deletions test/onnx/nhwcconv_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@

nhwcconv_test:{

0
12"NhwcConvnhwcconv_testZ
0




Z
1




b
2




B
Expand Down
45 changes: 45 additions & 0 deletions test/onnx/parse/nhwcconv_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include <onnx_test.hpp>

TEST_CASE(nhwcconv_test)
{
migraphx::program p;
auto* mm = p.get_main_module();

migraphx::shape x_shape{migraphx::shape::float_type, {1, 7, 7, 1}};
auto x = mm->add_parameter("0", x_shape);
x = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), x);

migraphx::shape w_shape{migraphx::shape::float_type, {1, 1, 1, 1}};
auto w = mm->add_parameter("1", w_shape);
w = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), w);

auto y = mm->add_instruction(migraphx::make_op("convolution"), x, w);
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), y);

migraphx::program prog = optimize_onnx("nhwcconv_test.onnx");
EXPECT(p == prog);
}
87 changes: 87 additions & 0 deletions test/onnx/verify/nhwcconv_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <onnx_test.hpp>
#include <onnx_verify_utils.hpp>

TEST_CASE(nhwcconv_test)
{
migraphx::program p = read_onnx("nhwcconv_test.onnx");
p.compile(migraphx::make_target("ref"));

migraphx::shape x_shape{migraphx::shape::float_type, {1, 7, 7, 1}};
std::vector<float> x_data = {
0.45246148109436035f, 0.15498268604278564f, 0.11199361085891724f,
-0.39421093463897705f, 0.2626858949661255f, 0.13414543867111206f,
-0.27184486389160156f, -0.43028733134269714f, -0.26825493574142456f,
0.3893144130706787f, -0.13631996512413025f, -0.009590476751327515f,
-0.48771554231643677f, -0.25256502628326416f, -0.2812897562980652f,
0.4043201804161072f, 0.07795023918151855f, 0.326981782913208f,
0.13114392757415771f, -0.4416425824165344f, 0.12446999549865723f,
0.36739975214004517f, 0.1698915958404541f, 0.2008744478225708f,
0.23339951038360596f, 0.38613730669021606f, 0.11117297410964966f,
0.3877097964286804f, 0.20812749862670898f, -0.34297940135002136f,
-0.029246658086776733f, -0.20483523607254028f, -0.19244328141212463f,
-0.11104947328567505f, -0.32830488681793213f, -0.01800677180290222f,
0.3618946671485901f, -0.40949052572250366f, -0.18248388171195984f,
-0.3349453806877136f, -0.34091079235076904f, 0.006497859954833984f,
0.4537564516067505f, 0.08006560802459717f, -0.14788749814033508f,
0.034442365169525146f, -0.33322954177856445f, 0.06049239635467529f,
0.42619407176971436f};

migraphx::shape w_shape{migraphx::shape::float_type, {1, 1, 1, 1}};
std::vector<float> w_data = {-0.4406261742115021f};

migraphx::parameter_map pm;
pm["0"] = migraphx::argument{x_shape, x_data.data()};
pm["1"] = migraphx::argument{w_shape, w_data.data()};

auto result = p.eval(pm).back();
EXPECT(result.get_shape().lens() == std::vector<std::size_t>{1, 7, 7, 1});

std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });

std::vector<float> gold = {
-0.19936637580394745f, -0.06828942894935608f, -0.04934731498360634f,
0.17369966208934784f, -0.11574628204107285f, -0.05910799279808998f,
0.1197819635272026f, 0.18959586322307587f, 0.1182001456618309f,
-0.17154212296009064f, 0.06006614491343498f, 0.0042258151806890965f,
0.21490024030208588f, 0.11128675937652588f, 0.12394362688064575f,
-0.17815405130386353f, -0.034346915781497955f, -0.14407673478126526f,
-0.05778544768691063f, 0.19459928572177887f, -0.05484473705291748f,
-0.16188594698905945f, -0.07485868036746979f, -0.08851054310798645f,
-0.10284193605184555f, -0.17014220356941223f, -0.04898572340607643f,
-0.17083507776260376f, -0.09170642495155334f, 0.1511256992816925f,
0.012886842712759972f, 0.09025576710700989f, 0.08479554951190948f,
0.0489313043653965f, 0.14465972781181335f, 0.007934254594147205f,
-0.15946026146411896f, 0.1804322451353073f, 0.08040717244148254f,
0.1475857049226761f, 0.15021422505378723f, -0.0028631272725760937f,
-0.19993697106838226f, -0.03527900204062462f, 0.06516310572624207f,
-0.015176207758486271f, 0.14682966470718384f, -0.02665453404188156f,
-0.18779225647449493f};
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}

0 comments on commit ea9d54a

Please sign in to comment.