From e8809b0a993845221ff6c47d2b9ed7abde1ac5b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dino=20Musi=C4=87?= Date: Fri, 27 Sep 2024 07:52:44 +0000 Subject: [PATCH 1/7] Base case implementation --- src/onnx/parse_matmulnbits.cpp | 119 ++++++++++++++++++++++++++ test/onnx/gen_onnx.py | 18 ++++ test/onnx/matmulnbits_test.onnx | 27 ++++++ test/onnx/verify/matmulnbits_test.cpp | 65 ++++++++++++++ 4 files changed, 229 insertions(+) create mode 100644 src/onnx/parse_matmulnbits.cpp create mode 100644 test/onnx/matmulnbits_test.onnx create mode 100644 test/onnx/verify/matmulnbits_test.cpp diff --git a/src/onnx/parse_matmulnbits.cpp b/src/onnx/parse_matmulnbits.cpp new file mode 100644 index 00000000000..d04bdba9abb --- /dev/null +++ b/src/onnx/parse_matmulnbits.cpp @@ -0,0 +1,119 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
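
Background for this series: MatMulNBits is the com.microsoft contrib operator computing C = A x dequantize(B), where B is an [N, K] matrix quantized to n bits (4 here) in blocks along K. A plain-Python sketch of that math follows; the blockwise layout and the default zero point of 8 are assumptions based on the contrib-op convention, not taken from this patch:

    # Sketch only: C = A @ dequantize(B).T, with b_q holding unpacked 4-bit
    # values (ints in 0..15) and one scale/zero point per (column, block).
    def matmulnbits_ref(a, b_q, scales, block_size, zp=None):
        m, k_dim, n = len(a), len(a[0]), len(b_q)

        def b_dq(col, k):
            blk = k // block_size
            z = 8 if zp is None else zp[col][blk]  # assumed default zero point
            return (b_q[col][k] - z) * scales[col][blk]

        return [[sum(a[i][k] * b_dq(col, k) for k in range(k_dim))
                 for col in range(n)] for i in range(m)]

    # One block: the dequantized column is [0, 2, 4, 6]; dot with ones is 12.
    assert matmulnbits_ref([[1.0] * 4], [[8, 9, 10, 11]], [[2.0]], 16) == [[12.0]]
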
+ */
+#include "migraphx/instruction_ref.hpp"
+#include "migraphx/onnx/onnx_parser.hpp"
+#include
+#include
+#include
+#include
+#include
+
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {
+namespace onnx {
+
+struct parse_matmulnbits : op_parser<parse_matmulnbits>
+{
+    std::vector<op_desc> operators() const { return {{"MatMulNBits"}}; }
+
+    instruction_ref parse(const op_desc& /*opd*/,
+                          const onnx_parser& parser,
+                          onnx_parser::node_info info,
+                          const std::vector<instruction_ref>& args) const
+    {
+        std::cout << "MatMulNBits parser" << std::endl;
+        const auto N = *parse_attribute<int>("N", parser, info);
+        const auto K = *parse_attribute<int>("K", parser, info);
+        // TODO only 4 is valid
+        const auto bits = *parse_attribute<int>("bits", parser, info);
+        // TODO check that it is >= 16 and a power of 2
+        const auto block_size = *parse_attribute<int>("block_size", parser, info);
+        std::cout << "N: " << N << std::endl;
+        std::cout << "K: " << K << std::endl;
+        std::cout << "bits: " << bits << std::endl;
+        std::cout << "block_size: " << block_size << std::endl;
+
+        std::cout << "A shape: " << args[0]->get_shape() << std::endl;
+        std::cout << "B shape: " << args[1]->get_shape() << std::endl;
+
+        auto b = info.add_instruction(make_op("reshape", {{"dims", {N, -1}}}), args[1]);
+        b      = info.add_instruction(make_op("unpack_int4"), b);
+        // Shape: [N x n_blocks_per_col] -> reshape to [N, n_blocks_per_col]
+        auto n_blocks_per_col = (K + block_size - 1) / block_size;
+        auto scales =
+            info.add_instruction(make_op("reshape", {{"dims", {N, n_blocks_per_col}}}), args[2]);
+        std::cout << scales->get_shape() << std::endl;
+        b = add_dequantize(info, block_size, 1, b, scales);
+        std::cout << "b shape after dq: " << b->get_shape() << std::endl;
+        b = info.add_instruction(make_op("transpose", {{"permutation", {1, 0}}}), b);
+        // Replace with proper matmul
+        return info.add_instruction(make_op("dot"), args[0], b);
+    }
+
+    private:
+    template <class T>
+    std::optional<T> parse_attribute(const std::string& attribute_name,
+                                     const onnx_parser& parser,
+                                     onnx_parser::node_info& info) const
+    {
+        if(not contains(info.attributes, attribute_name))
+            return std::nullopt;
+
+        return parser.parse_value(info.attributes[attribute_name]).at<T>();
+    }
+
+    instruction_ref add_dequantize(onnx_parser::node_info& info,
+                                   int block_size,
+                                   int axis,
+                                   instruction_ref b,
+                                   instruction_ref scales) const
+    {
+        scales = info.add_instruction(make_op("unsqueeze", {{"axes", {axis + 1}}}), scales);
+
+        auto bc_lens      = scales->get_shape().lens();
+        bc_lens[axis + 1] = block_size;
+        scales = info.add_instruction(make_op("multibroadcast", {{"out_lens", bc_lens}}), scales);
+        std::cout << scales->get_shape() << std::endl;
+
+        auto reshape_lens  = b->get_shape().lens();
+        reshape_lens[axis] = scales->get_shape().lens()[axis] * block_size;
+        scales = info.add_instruction(make_op("reshape", {{"dims", reshape_lens}}), scales);
+
+        // TODO: Runt blocks shouldn't be able to happen, but double check
+        // // Detect runt block
+        // if(x_lens[axis] < reshape_lens[axis])
+        // {
+        //     ins = info.add_instruction(
+        //         make_op("slice", {{"axes", {axis}}, {"starts", {0}}, {"ends",
+        //         {x_lens[axis]}}}), ins);
+        // }
+
+        // TODO if zeropoint input is not present, it should not be assumed to be zero, it should
+        // be -8
+        return info.add_instruction(make_op("dequantizelinear"), {b, scales});
+    }
+};
+
+} // namespace onnx
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py
index a1c9d0dd81f..bfaa8b6e0fa 100644
--- a/test/onnx/gen_onnx.py
+++ b/test/onnx/gen_onnx.py
@@ 
-8566,6 +8566,24 @@ def qlinearmatmul_3D_test(): [sc_a, zero_pt_a, sc_b, zero_pt_b, sc_c, zero_pt_c]) +@onnx_test() +def matmulnbits_test(): + a = onnx.helper.make_tensor_value_info("a", onnx.TensorProto.FLOAT, [2, 16]) + b = onnx.helper.make_tensor_value_info("b", onnx.TensorProto.UINT8, [4, 1, 8]) + scales = onnx.helper.make_tensor_value_info("scales", onnx.TensorProto.FLOAT, [4]) + c = onnx.helper.make_tensor_value_info("c", onnx.TensorProto.FLOAT, [2, 4]) + + node = onnx.helper.make_node("MatMulNBits", + inputs=["a", "b", "scales"], + outputs=["c"], + bits=4, + block_size=16, + K=16, + N=4, + domain='com.microsoft') + return ([node], [a, b, scales], [c]) + + @onnx_test() def qlinearmul_test(): a = helper.make_tensor_value_info('A', TensorProto.UINT8, [64]) diff --git a/test/onnx/matmulnbits_test.onnx b/test/onnx/matmulnbits_test.onnx new file mode 100644 index 00000000000..ead42e821df --- /dev/null +++ b/test/onnx/matmulnbits_test.onnx @@ -0,0 +1,27 @@ + matmulnbits_test:Î +a +a +b +scalesc" MatMulNBits* +K * +N * +bits * + +block_size : com.microsoftmatmulnbits_testZ +a +  + +Z +b + + + +Z +scales + + +b +c +  + +B \ No newline at end of file diff --git a/test/onnx/verify/matmulnbits_test.cpp b/test/onnx/verify/matmulnbits_test.cpp new file mode 100644 index 00000000000..4d83bce5d28 --- /dev/null +++ b/test/onnx/verify/matmulnbits_test.cpp @@ -0,0 +1,65 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
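
One note on the test data that follows: input b is filled with the byte 0x21. Each uint8 packs two 4-bit values, assumed low-nibble-first (consistent with how unpack_int4 is used in this series), so every byte decodes to the pair (1, 2):

    # Sketch: one packed byte -> two 4-bit values, low nibble first (assumed).
    def unpack_byte(v):
        return (v & 0x0F, (v >> 4) & 0x0F)

    assert unpack_byte(0x21) == (1, 2)
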
+ */
+
+#include "migraphx/module.hpp"
+#include
+#include
+#include
+#include
+#include
+
+TEST_CASE(matmulnbits_test)
+{
+    auto p = optimize_onnx("matmulnbits_test.onnx");
+    p.compile(migraphx::make_target("ref"));
+
+    migraphx::parameter_map pm;
+    auto a_shape = migraphx::shape{migraphx::shape::float_type, {2, 16}};
+    std::vector<float> a(a_shape.elements());
+    std::iota(a.begin(), a.end(), 0);
+    pm["a"] = migraphx::argument(a_shape, a.data());
+
+    auto b_shape = migraphx::shape{migraphx::shape::uint8_type, {4, 1, 8}};
+    std::vector<uint8_t> b(b_shape.elements(), 0x21);
+    pm["b"] = migraphx::argument(b_shape, b.data());
+
+    auto scales_shape = migraphx::shape{migraphx::shape::float_type, {4}};
+    std::vector<float> scales{1, 2, 3, 4};
+    pm["scales"] = migraphx::argument(scales_shape, scales.data());
+
+    auto result = p.eval(pm).back();
+    std::vector<float> result_vector;
+    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
+
+    std::cout << result.get_shape() << std::endl;
+    std::cout << migraphx::to_string_range(result_vector) << std::endl;
+    // for(auto i = 0; i < 4; ++i)
+    // {
+    //     for(auto j = 0; j < 16; ++j)
+    //     {
+    //         std::cout << static_cast<int>(result_vector[i * 16 + j]) << " ";
+    //     }
+    //     std::cout << std::endl;
+    // }
+}

From 3267543f2c007fd8ca9fc01e73937422b5fffd09 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dino=20Musi=C4=87?=
Date: Tue, 1 Oct 2024 10:36:43 +0000
Subject: [PATCH 2/7] Fix error in b input shape assumption, implement support for zero point input

---
 src/onnx/parse_matmulnbits.cpp        | 158 ++++++++++++++++----------
 test/onnx/gen_onnx.py                 |  45 ++++++--
 test/onnx/matmulnbits2_test.onnx      |  27 +++++
 test/onnx/matmulnbits_test.onnx       |  11 +-
 test/onnx/verify/matmulnbits_test.cpp |  65 +++++++++--
 5 files changed, 225 insertions(+), 81 deletions(-)
 create mode 100644 test/onnx/matmulnbits2_test.onnx

diff --git a/src/onnx/parse_matmulnbits.cpp b/src/onnx/parse_matmulnbits.cpp
index d04bdba9abb..03c15597d36 100644
--- a/src/onnx/parse_matmulnbits.cpp
+++ b/src/onnx/parse_matmulnbits.cpp
@@ -21,6 +21,8 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  * THE SOFTWARE.
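
Patch 2 below corrects the layout assumed for B: it is [N, n_blocks_per_col, blob_size], not a flat packed [N, K/2]. The arithmetic, sketched with the values used by matmulnbits_test (N=4, K=16, bits=4, block_size=16):

    import math

    # Sketch of the shape checks the updated parser performs.
    N, K, bits, block_size = 4, 16, 4, 16
    n_blocks_per_col = (K + block_size - 1) // block_size   # ceil(K/block_size) == 1
    blob_size = math.ceil(block_size * bits / 8)            # packed bytes per block == 8
    assert [N, n_blocks_per_col, blob_size] == [4, 1, 8]    # dims of input "b"
    assert N * n_blocks_per_col == 4                        # elements in "scales"
    assert N * math.ceil(n_blocks_per_col * bits / 8) == 4  # elements in "zp"
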
 */
+#include "migraphx/errors.hpp"
+#include "migraphx/functional.hpp"
 #include "migraphx/instruction_ref.hpp"
 #include "migraphx/onnx/onnx_parser.hpp"
 #include
@@ -42,75 +44,117 @@ struct parse_matmulnbits : op_parser<parse_matmulnbits>
                           onnx_parser::node_info info,
                           const std::vector<instruction_ref>& args) const
     {
-        std::cout << "MatMulNBits parser" << std::endl;
-        const auto N = *parse_attribute<int>("N", parser, info);
-        const auto K = *parse_attribute<int>("K", parser, info);
-        // TODO only 4 is valid
-        const auto bits = *parse_attribute<int>("bits", parser, info);
-        // TODO check that it is >= 16 and a power of 2
-        const auto block_size = *parse_attribute<int>("block_size", parser, info);
-        std::cout << "N: " << N << std::endl;
-        std::cout << "K: " << K << std::endl;
-        std::cout << "bits: " << bits << std::endl;
-        std::cout << "block_size: " << block_size << std::endl;
-
-        std::cout << "A shape: " << args[0]->get_shape() << std::endl;
-        std::cout << "B shape: " << args[1]->get_shape() << std::endl;
-
-        auto b = info.add_instruction(make_op("reshape", {{"dims", {N, -1}}}), args[1]);
-        b      = info.add_instruction(make_op("unpack_int4"), b);
-        // Shape: [N x n_blocks_per_col] -> reshape to [N, n_blocks_per_col]
-        auto n_blocks_per_col = (K + block_size - 1) / block_size;
-        auto scales =
-            info.add_instruction(make_op("reshape", {{"dims", {N, n_blocks_per_col}}}), args[2]);
-        std::cout << scales->get_shape() << std::endl;
-        b = add_dequantize(info, block_size, 1, b, scales);
-        std::cout << "b shape after dq: " << b->get_shape() << std::endl;
+        const int N          = parse_attribute(parser, info, "N");
+        const int K          = parse_attribute(parser, info, "K");
+        const int bits       = parse_attribute(parser, info, "bits");
+        const int block_size = parse_attribute(parser, info, "block_size");
+
+        if(bits != 4)
+            MIGRAPHX_THROW("MatMulNBits: bits only supported for value of 4, actual value " +
+                           std::to_string(bits));
+
+        if(block_size < 16 and (block_size & (block_size - 1)) != 0)
+            MIGRAPHX_THROW("MatMulNBits: block_size must be a power of 2 and greater or equal to "
+                           "16, actual value " +
+                           std::to_string(block_size));
+
+        int n_blocks_per_col = (K + block_size - 1) / block_size;
+        int blob_size        = std::ceil(block_size * bits / 8.0f);
+
+        std::vector<size_t> expected_b_lens{static_cast<size_t>(N),
+                                            static_cast<size_t>(n_blocks_per_col),
+                                            static_cast<size_t>(blob_size)};
+        if(args[1]->get_shape().lens() != expected_b_lens)
+            MIGRAPHX_THROW("Input B does not match expected dims TODO");
+
+        std::vector<size_t> expected_scales_lens{static_cast<size_t>(N * n_blocks_per_col)};
+        if(args[2]->get_shape().lens() != expected_scales_lens)
+            MIGRAPHX_THROW("Input Scales does not match expected dims TODO");
+
+        if(args.size() > 3)
+        {
+            std::vector<size_t> expected_zp_lens{
+                static_cast<size_t>(N * std::ceil(n_blocks_per_col * bits / 8.0f))};
+            if(args[3]->get_shape().lens() != expected_zp_lens)
+                MIGRAPHX_THROW("MatMulNBits: TODO");
+        }
+
+        auto b = dequantize_b(info, N, K, block_size, args);
+
         b = info.add_instruction(make_op("transpose", {{"permutation", {1, 0}}}), b);
         // Replace with proper matmul
         return info.add_instruction(make_op("dot"), args[0], b);
     }
 
     private:
-    template <class T>
-    std::optional<T> parse_attribute(const std::string& attribute_name,
-                                     const onnx_parser& parser,
-                                     onnx_parser::node_info& info) const
+    int parse_attribute(const onnx_parser& parser,
+                        onnx_parser::node_info& info,
+                        const std::string& attribute_name) const
     {
         if(not contains(info.attributes, attribute_name))
-            return std::nullopt;
+            MIGRAPHX_THROW("MatMulNBits: Attribute " + attribute_name +
+                           " required, but is missing");
 
-        return parser.parse_value(info.attributes[attribute_name]).at<T>();
+        return parser.parse_value(info.attributes[attribute_name]).at<int>();
    }
 
-    instruction_ref add_dequantize(onnx_parser::node_info& info,
-                                   int block_size,
-                                   int axis,
-                                   instruction_ref b,
-                                   instruction_ref scales) const
+    instruction_ref dequantize_b(onnx_parser::node_info& info,
+                                 int N,
+                                 int K,
+                                 int block_size,
+                                 const std::vector<instruction_ref>& args) const
     {
-        scales = info.add_instruction(make_op("unsqueeze", {{"axes", {axis + 1}}}), scales);
-
-        auto bc_lens      = scales->get_shape().lens();
-        bc_lens[axis + 1] = block_size;
-        scales = info.add_instruction(make_op("multibroadcast", {{"out_lens", bc_lens}}), scales);
-        std::cout << scales->get_shape() << std::endl;
-
-        auto reshape_lens  = b->get_shape().lens();
-        reshape_lens[axis] = scales->get_shape().lens()[axis] * block_size;
-        scales = info.add_instruction(make_op("reshape", {{"dims", reshape_lens}}), scales);
-
-        // TODO: Runt blocks shouldn't be able to happen, but double check
-        // // Detect runt block
-        // if(x_lens[axis] < reshape_lens[axis])
-        // {
-        //     ins = info.add_instruction(
-        //         make_op("slice", {{"axes", {axis}}, {"starts", {0}}, {"ends",
-        //         {x_lens[axis]}}}), ins);
-        // }
-
-        // TODO if zeropoint input is not present, it should not be assumed to be zero, it should
-        // be -8
-        return info.add_instruction(make_op("dequantizelinear"), {b, scales});
+        auto b = unpack(info, args[1], N, K);
+
+        auto n_blocks_per_col = (K + block_size - 1) / block_size;
+        auto scales = info.add_instruction(make_op("reshape", {{"dims", {N, -1}}}), args[2]);
+        scales      = prepare_blockwise_dq_arg(info, scales, N, K, block_size);
+
+        instruction_ref zp;
+        if(args.size() == 4)
+        {
+            zp = unpack(info, args[3], N, n_blocks_per_col);
+            zp = prepare_blockwise_dq_arg(info, zp, N, K, block_size);
+        }
+        else
+        {
+            zp = info.add_literal(literal{shape{shape::uint8_type, {1}}, {8}});
+            zp = info.add_instruction(
+                make_op("multibroadcast", {{"out_lens", b->get_shape().lens()}}), zp);
+        }
+        return info.add_instruction(make_op("dequantizelinear"), {b, scales, zp});
+    }
+
+    instruction_ref unpack(onnx_parser::node_info& info, instruction_ref x, int N, int dim1) const
+    {
+        x = info.add_instruction(make_op("unpack_int4"), x);
+        x = info.add_instruction(make_op("reshape", {{"dims", {N, -1}}}), x);
+        if(x->get_shape().lens()[1] > dim1)
+        {
+            x = info.add_instruction(
+                make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {dim1}}}), x);
+        }
+        return x;
+    }
+
+    instruction_ref prepare_blockwise_dq_arg(
+        onnx_parser::node_info& info, instruction_ref x, int N, int K, int block_size) const
+    {
+        x = info.add_instruction(make_op("unsqueeze", {{"axes", {2}}}), x);
+
+        auto bc_lens = x->get_shape().lens();
+        bc_lens[2]   = block_size;
+        x = info.add_instruction(make_op("multibroadcast", {{"out_lens", bc_lens}}), x);
+        x = info.add_instruction(make_op("reshape", {{"dims", {N, -1}}}), x);
+
+        // Detect runt block
+        if(x->get_shape().lens()[1] > K)
+        {
+            x = info.add_instruction(
+                make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {K}}}), x);
+        }
+
+        return x;
    }
 };
 
diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py
index bfaa8b6e0fa..5555d218cb5 100644
--- a/test/onnx/gen_onnx.py
+++ b/test/onnx/gen_onnx.py
@@ -8568,19 +8568,44 @@ def qlinearmatmul_3D_test():
 
 @onnx_test()
 def matmulnbits_test():
-    a = onnx.helper.make_tensor_value_info("a", onnx.TensorProto.FLOAT, [2, 16])
-    b = onnx.helper.make_tensor_value_info("b", onnx.TensorProto.UINT8, [4, 1, 8])
-    scales = onnx.helper.make_tensor_value_info("scales", 
onnx.TensorProto.FLOAT, [4]) + a = onnx.helper.make_tensor_value_info("a", onnx.TensorProto.FLOAT, + [2, 16]) + b = onnx.helper.make_tensor_value_info("b", onnx.TensorProto.UINT8, + [4, 1, 8]) + scales = onnx.helper.make_tensor_value_info("scales", + onnx.TensorProto.FLOAT, [4]) + zp = onnx.helper.make_tensor_value_info("zp", onnx.TensorProto.UINT8, [4]) c = onnx.helper.make_tensor_value_info("c", onnx.TensorProto.FLOAT, [2, 4]) node = onnx.helper.make_node("MatMulNBits", - inputs=["a", "b", "scales"], - outputs=["c"], - bits=4, - block_size=16, - K=16, - N=4, - domain='com.microsoft') + inputs=["a", "b", "scales", "zp"], + outputs=["c"], + bits=4, + block_size=16, + K=16, + N=4, + domain='com.microsoft') + return ([node], [a, b, scales, zp], [c]) + + +@onnx_test() +def matmulnbits2_test(): + a = onnx.helper.make_tensor_value_info("a", onnx.TensorProto.FLOAT, + [2, 33]) + b = onnx.helper.make_tensor_value_info("b", onnx.TensorProto.UINT8, + [2, 3, 8]) + scales = onnx.helper.make_tensor_value_info("scales", + onnx.TensorProto.FLOAT, [6]) + c = onnx.helper.make_tensor_value_info("c", onnx.TensorProto.FLOAT, [2, 2]) + + node = onnx.helper.make_node("MatMulNBits", + inputs=["a", "b", "scales"], + outputs=["c"], + bits=4, + block_size=16, + K=33, + N=2, + domain='com.microsoft') return ([node], [a, b, scales], [c]) diff --git a/test/onnx/matmulnbits2_test.onnx b/test/onnx/matmulnbits2_test.onnx new file mode 100644 index 00000000000..872a4289910 --- /dev/null +++ b/test/onnx/matmulnbits2_test.onnx @@ -0,0 +1,27 @@ + matmulnbits2_test:Ï +a +a +b +scalesc" MatMulNBits* +K! * +N * +bits * + +block_size : com.microsoftmatmulnbits2_testZ +a +  + +!Z +b + + + +Z +scales + + +b +c +  + +B \ No newline at end of file diff --git a/test/onnx/matmulnbits_test.onnx b/test/onnx/matmulnbits_test.onnx index ead42e821df..2579e23b86f 100644 --- a/test/onnx/matmulnbits_test.onnx +++ b/test/onnx/matmulnbits_test.onnx @@ -1,8 +1,9 @@ - matmulnbits_test:Î -a + matmulnbits_test:ä +e a b -scalesc" MatMulNBits* +scales +zpc" MatMulNBits* K * N * bits * @@ -20,6 +21,10 @@ block_size scales  +Z +zp + + b c  diff --git a/test/onnx/verify/matmulnbits_test.cpp b/test/onnx/verify/matmulnbits_test.cpp index 4d83bce5d28..d1268ca0f5c 100644 --- a/test/onnx/verify/matmulnbits_test.cpp +++ b/test/onnx/verify/matmulnbits_test.cpp @@ -22,6 +22,7 @@ * THE SOFTWARE. 
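
What prepare_blockwise_dq_arg above builds, restated as a sketch: one scale (or zero point) per (column, block) is repeated block_size times along K, then sliced back to K so the padding of a runt (partial) final block is dropped. In numpy (the block_size below is illustrative only; the op itself requires a power of two >= 16):

    import numpy as np

    def expand_blockwise(x, n, k, block_size):               # x: [n, n_blocks_per_col]
        x = x.reshape(n, -1, 1)                              # unsqueeze axis 2
        x = np.broadcast_to(x, (n, x.shape[1], block_size))  # multibroadcast
        return x.reshape(n, -1)[:, :k]                       # reshape + runt-block slice

    s = expand_blockwise(np.array([[2.0, 3.0]]), n=1, k=5, block_size=3)
    assert s.tolist() == [[2.0, 2.0, 2.0, 3.0, 3.0]]
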
 */
+#include "migraphx/argument.hpp"
 #include "migraphx/module.hpp"
 #include
 #include
@@ -41,25 +42,67 @@ TEST_CASE(matmulnbits_test)
     pm["a"] = migraphx::argument(a_shape, a.data());
 
     auto b_shape = migraphx::shape{migraphx::shape::uint8_type, {4, 1, 8}};
-    std::vector<uint8_t> b(b_shape.elements(), 0x21);
+    std::vector<uint8_t> b{0x2,  0xe3, 0xc7, 0x89, 0xbd, 0xbe, 0x50, 0x41, 0xe9, 0xb4, 0xd4,
+                           0x54, 0xc6, 0xb2, 0xfa, 0x27, 0x14, 0x3d, 0xbb, 0xe7, 0xa5, 0x0,
+                           0x52, 0x28, 0xc1, 0xd9, 0x1f, 0x33, 0x16, 0x1e, 0x8b, 0x3c};
     pm["b"] = migraphx::argument(b_shape, b.data());
 
     auto scales_shape = migraphx::shape{migraphx::shape::float_type, {4}};
     std::vector<float> scales{1, 2, 3, 4};
     pm["scales"] = migraphx::argument(scales_shape, scales.data());
 
+    auto zp_shape = migraphx::shape{migraphx::shape::uint8_type, {4}};
+    std::vector<uint8_t> zp{0x08, 0x09, 0x0a, 0x0b};
+    pm["zp"] = migraphx::argument{zp_shape, zp.data()};
+
     auto result = p.eval(pm).back();
     std::vector<float> result_vector;
     result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
+    std::vector<float> gold{
+        -111.0f, -290.0f, -1692.0f, -1960.0f, -335.0f, -770.0f, -4764.0f, -5992.0f};
 
-    std::cout << result.get_shape() << std::endl;
-    std::cout << migraphx::to_string_range(result_vector) << std::endl;
-    // for(auto i = 0; i < 4; ++i)
-    // {
-    //     for(auto j = 0; j < 16; ++j)
-    //     {
-    //         std::cout << static_cast<int>(result_vector[i * 16 + j]) << " ";
-    //     }
-    //     std::cout << std::endl;
-    // }
+    EXPECT(result.get_shape().lens() == std::vector<std::size_t>{2, 4});
+    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
 }
+
+TEST_CASE(matmulnbits2_test)
+{
+    auto p = optimize_onnx("matmulnbits2_test.onnx");
+    p.compile(migraphx::make_target("ref"));
+
+    migraphx::parameter_map pm;
+    auto a_shape = migraphx::shape{migraphx::shape::float_type, {2, 33}};
+    std::vector<float> a{
+        0.15541f, 0.24434f, 0.66716f, 0.13632f, 0.76915f, 0.21328f, 0.17331f, 0.93251f, 0.14816f,
+        0.08181f, 0.54035f, 0.86664f, 0.92605f, 0.89766f, 0.02441f, 0.33504f, 0.60488f, 0.25918f,
+        0.64644f, 0.98881f, 0.27669f, 0.94888f, 0.21201f, 0.33377f, 0.95608f, 0.40923f, 0.66899f,
+        0.58904f, 0.41560f, 0.87399f, 0.74596f, 0.10849f, 0.94527f, 0.88573f, 0.66875f, 0.57536f,
+        0.81454f, 0.15699f, 0.15464f, 0.17399f, 0.08090f, 0.99368f, 0.45535f, 0.92528f, 0.91968f,
+        0.76970f, 0.59638f, 0.23635f, 0.54877f, 0.96025f, 0.48969f, 0.55297f, 0.52498f, 0.29102f,
+        0.01359f, 0.77372f, 0.81897f, 0.03003f, 0.00822f, 0.55477f, 0.54635f, 0.91918f, 0.76486f,
+        0.73698f, 0.29821f, 0.41801};
+    pm["a"] = migraphx::argument(a_shape, a.data());
+
+    auto b_shape = migraphx::shape{migraphx::shape::uint8_type, {2, 3, 8}};
+    std::vector<uint8_t> b{0x18, 0x9,  0x8b, 0xe1, 0xfb, 0x94, 0x11, 0x56, 0x4e, 0xac, 0xd3, 0x4b,
+                           0xf7, 0x8e, 0x54, 0xef, 0x0b, 0x0,  0x0,  0x0,  0x0,  0x0,  0x0,  0x0,
+                           0x6e, 0xb7, 0x20, 0x4f, 0xa7, 0x82, 0x83, 0xbf, 0x20, 0xde, 0xa4, 0xf,
+                           0x72, 0x81, 0x8,  0x83, 0x0a, 0x0,  0x0,  0x0,  0x0,  0x0,  0x0,  0x0};
+    pm["b"] = migraphx::argument(b_shape, b.data());
+
+    auto scales_shape = migraphx::shape{migraphx::shape::float_type, {6}};
+    std::vector<float> scales{0.29033, 0.80435, 2.60200, 2.39623, 1.40796, 2.38139};
+    pm["scales"] = migraphx::argument(scales_shape, scales.data());
+
+    // auto zp_shape = migraphx::shape{migraphx::shape::uint8_type, {4}};
+    // std::vector<uint8_t> zp{0x08, 0x09, 0x0a, 0x0b};
+    // pm["zp"] = migraphx::argument{zp_shape, zp.data()};
+
+    auto result = p.eval(pm).back();
+    std::vector<float> result_vector;
+    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
+    std::vector<float> gold{18.54672f, -62.38305f, 4.978874f, -31.228657f};
+
+    EXPECT(result.get_shape().lens() == std::vector<std::size_t>{2, 2});
+    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
+}
\ No newline at end of file

From 1d1b8a3f03b677ea1d9c6b4112573b010eb5254f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dino=20Musi=C4=87?=
Date: Wed, 2 Oct 2024 07:32:31 +0000
Subject: [PATCH 3/7] Implement onnx verify tests for MatMulNBits

---
 src/onnx/parse_matmulnbits.cpp                | 63 ++++++++++--
 test/onnx/gen_onnx.py                         | 47 ++++++++-
 test/onnx/matmulnbits_bmm_test.onnx           | 29 ++++++
 ...s2_test.onnx => matmulnbits_mm2_test.onnx} |  4 +-
 ...its_test.onnx => matmulnbits_mm_test.onnx} |  4 +-
 test/onnx/matmulnbits_vm_test.onnx            | 32 ++++++
 test/onnx/verify/matmulnbits_test.cpp         | 99 +++++++++++++++++--
 7 files changed, 253 insertions(+), 25 deletions(-)
 create mode 100644 test/onnx/matmulnbits_bmm_test.onnx
 rename test/onnx/{matmulnbits2_test.onnx => matmulnbits_mm2_test.onnx} (65%)
 rename test/onnx/{matmulnbits_test.onnx => matmulnbits_mm_test.onnx} (68%)
 create mode 100644 test/onnx/matmulnbits_vm_test.onnx

diff --git a/src/onnx/parse_matmulnbits.cpp b/src/onnx/parse_matmulnbits.cpp
index 03c15597d36..29815580e3e 100644
--- a/src/onnx/parse_matmulnbits.cpp
+++ b/src/onnx/parse_matmulnbits.cpp
@@ -21,8 +21,8 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  * THE SOFTWARE.
  */
+#include "migraphx/common.hpp"
 #include "migraphx/errors.hpp"
-#include "migraphx/functional.hpp"
 #include "migraphx/instruction_ref.hpp"
 #include "migraphx/onnx/onnx_parser.hpp"
@@ -53,7 +53,7 @@ struct parse_matmulnbits : op_parser<parse_matmulnbits>
             MIGRAPHX_THROW("MatMulNBits: bits only supported for value of 4, actual value " +
                            std::to_string(bits));
 
-        if(block_size < 16 and block_size % 2 != 0)
+        if(block_size < 16 and (block_size & (block_size - 1)) != 0)
             MIGRAPHX_THROW("MatMulNBits: block_size must be a power of 2 and greater or equal to "
                            "16, actual value " +
                            std::to_string(block_size));
@@ -65,25 +65,29 @@ struct parse_matmulnbits : op_parser<parse_matmulnbits>
                                             static_cast<size_t>(n_blocks_per_col),
                                             static_cast<size_t>(blob_size)};
         if(args[1]->get_shape().lens() != expected_b_lens)
-            MIGRAPHX_THROW("Input B does not match expected dims TODO");
+            MIGRAPHX_THROW("MatMulNBits: Input B does not match expected dims: " +
+                           to_string_range(expected_b_lens) +
+                           ". Actual dims: " + to_string_range(args[1]->get_shape().lens()));
 
         std::vector<size_t> expected_scales_lens{static_cast<size_t>(N * n_blocks_per_col)};
         if(args[2]->get_shape().lens() != expected_scales_lens)
-            MIGRAPHX_THROW("Input Scales does not match expected dims TODO");
+            MIGRAPHX_THROW("MatMulNBits: Input scales does not match expected dims: " +
+                           to_string_range(expected_scales_lens) +
+                           ". Actual dims: " + to_string_range(args[2]->get_shape().lens()));
 
         if(args.size() > 3)
         {
             std::vector<size_t> expected_zp_lens{
                 static_cast<size_t>(N * std::ceil(n_blocks_per_col * bits / 8.0f))};
             if(args[3]->get_shape().lens() != expected_zp_lens)
-                MIGRAPHX_THROW("MatMulNBits: TODO");
+                MIGRAPHX_THROW("MatMulNBits: Input zero_points does not match expected dims: " +
+                               to_string_range(expected_zp_lens) +
+                               ". Actual dims: " + to_string_range(args[3]->get_shape().lens()));
         }
 
         auto b = dequantize_b(info, N, K, block_size, args);
-
-        b = info.add_instruction(make_op("transpose", {{"permutation", {1, 0}}}), b);
-        // Replace with proper matmul
-        return info.add_instruction(make_op("dot"), args[0], b);
+        b      = info.add_instruction(make_op("transpose", {{"permutation", {1, 0}}}), b);
+        return matmul(info, args[0], b);
     }
 
     private:
@@ -156,6 +160,47 @@ struct parse_matmulnbits : op_parser<parse_matmulnbits>
 
         return x;
     }
+
+    instruction_ref matmul(onnx_parser::node_info& info, instruction_ref a, instruction_ref b) const
+    {
+        bool is_a_prepended = false;
+        // B will always be a rank 2 matrix([N, K]), only need to check for A
+        if(a->get_shape().ndim() == 1)
+        {
+            is_a_prepended = true;
+            a              = info.add_instruction(make_op("unsqueeze", {{"axes", {0}}}), a);
+        }
+
+        auto a_lens = a->get_shape().lens();
+        auto b_lens = b->get_shape().lens();
+        if(not std::equal(a_lens.rbegin() + 2, a_lens.rend(), b_lens.rbegin() + 2, b_lens.rend()))
+        {
+            auto a_it = a_lens.begin() + a_lens.size() - 2;
+            std::vector<size_t> a_broadcasted_lens(a_lens.begin(), a_it);
+            auto b_it = b_lens.begin() + b_lens.size() - 2;
+            std::vector<size_t> b_broadcasted_lens(b_lens.begin(), b_it);
+            auto output_lens = compute_broadcasted_lens(a_broadcasted_lens, b_broadcasted_lens);
+            a_broadcasted_lens = output_lens;
+            a_broadcasted_lens.insert(a_broadcasted_lens.end(), a_it, a_lens.end());
+            b_broadcasted_lens = output_lens;
+            b_broadcasted_lens.insert(b_broadcasted_lens.end(), b_it, b_lens.end());
+
+            if(a_lens != a_broadcasted_lens)
+                a = info.add_instruction(
+                    make_op("multibroadcast", {{"out_lens", a_broadcasted_lens}}), a);
+
+            if(b_lens != b_broadcasted_lens)
+                b = info.add_instruction(
+                    make_op("multibroadcast", {{"out_lens", b_broadcasted_lens}}), b);
+        }
+        auto dot = info.add_instruction(make_op("dot"), a, b);
+
+        if(is_a_prepended)
+            dot = info.add_instruction(
+                make_op("squeeze", {{"axes", {dot->get_shape().ndim() - 2}}}), dot);
+
+        return dot;
+    }
 };
 
diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py
index 5555d218cb5..e313475014f 100644
--- a/test/onnx/gen_onnx.py
+++ b/test/onnx/gen_onnx.py
@@ -8567,7 +8567,7 @@ def qlinearmatmul_3D_test():
 
 
 @onnx_test()
-def matmulnbits_test():
+def matmulnbits_mm_test():
     a = onnx.helper.make_tensor_value_info("a", onnx.TensorProto.FLOAT,
                                            [2, 16])
     b = onnx.helper.make_tensor_value_info("b", onnx.TensorProto.UINT8,
@@ -8589,7 +8589,7 @@ def matmulnbits_test():
 
 
 @onnx_test()
-def matmulnbits2_test():
+def matmulnbits_mm2_test():
     a = onnx.helper.make_tensor_value_info("a", onnx.TensorProto.FLOAT,
                                            [2, 33])
     b = onnx.helper.make_tensor_value_info("b", onnx.TensorProto.UINT8,
@@ -8609,6 +8609,49 @@ def matmulnbits2_test():
     return ([node], [a, b, scales], [c])
 
 
+@onnx_test()
+def matmulnbits_vm_test():
+    a = onnx.helper.make_tensor_value_info("a", onnx.TensorProto.FLOAT, [20])
+    b = onnx.helper.make_tensor_value_info("b", onnx.TensorProto.UINT8,
+                                           [3, 2, 8])
+    scales = onnx.helper.make_tensor_value_info("scales",
+                                                onnx.TensorProto.FLOAT, [6])
+    zp = onnx.helper.make_tensor_value_info("zp", onnx.TensorProto.UINT8, [3])
+    c = onnx.helper.make_tensor_value_info("c", onnx.TensorProto.FLOAT, [3])
+
+    node = onnx.helper.make_node("MatMulNBits",
+                                 inputs=["a", "b", "scales", "zp"],
+                                 outputs=["c"],
+                                 bits=4,
+                                 block_size=16,
+                                 K=20,
+                                 N=3,
+                                 domain='com.microsoft')
+    return ([node], [a, b, scales, zp], [c])
+
+
+@onnx_test()
+def matmulnbits_bmm_test():
+    a = onnx.helper.make_tensor_value_info("a", onnx.TensorProto.FLOAT,
+ [2, 3, 8]) + b = onnx.helper.make_tensor_value_info("b", onnx.TensorProto.UINT8, + [2, 1, 8]) + scales = onnx.helper.make_tensor_value_info("scales", + onnx.TensorProto.FLOAT, [2]) + c = onnx.helper.make_tensor_value_info("c", onnx.TensorProto.FLOAT, + [2, 3, 2]) + + node = onnx.helper.make_node("MatMulNBits", + inputs=["a", "b", "scales"], + outputs=["c"], + bits=4, + block_size=16, + K=8, + N=2, + domain='com.microsoft') + return ([node], [a, b, scales], [c]) + + @onnx_test() def qlinearmul_test(): a = helper.make_tensor_value_info('A', TensorProto.UINT8, [64]) diff --git a/test/onnx/matmulnbits_bmm_test.onnx b/test/onnx/matmulnbits_bmm_test.onnx new file mode 100644 index 00000000000..3e792bf7d5d --- /dev/null +++ b/test/onnx/matmulnbits_bmm_test.onnx @@ -0,0 +1,29 @@ + matmulnbits_bmm_test:Ú +a +a +b +scalesc" MatMulNBits* +K * +N * +bits * + +block_size : com.microsoftmatmulnbits_bmm_testZ +a + + + +Z +b + + + +Z +scales + + +b +c + + + +B \ No newline at end of file diff --git a/test/onnx/matmulnbits2_test.onnx b/test/onnx/matmulnbits_mm2_test.onnx similarity index 65% rename from test/onnx/matmulnbits2_test.onnx rename to test/onnx/matmulnbits_mm2_test.onnx index 872a4289910..2c2283f08c5 100644 --- a/test/onnx/matmulnbits2_test.onnx +++ b/test/onnx/matmulnbits_mm2_test.onnx @@ -1,4 +1,4 @@ - matmulnbits2_test:Ï + matmulnbits_mm2_test:Ò a a b @@ -7,7 +7,7 @@ a N * bits * -block_size : com.microsoftmatmulnbits2_testZ +block_size : com.microsoftmatmulnbits_mm2_testZ a   diff --git a/test/onnx/matmulnbits_test.onnx b/test/onnx/matmulnbits_mm_test.onnx similarity index 68% rename from test/onnx/matmulnbits_test.onnx rename to test/onnx/matmulnbits_mm_test.onnx index 2579e23b86f..14efed868cd 100644 --- a/test/onnx/matmulnbits_test.onnx +++ b/test/onnx/matmulnbits_mm_test.onnx @@ -1,4 +1,4 @@ - matmulnbits_test:ä + matmulnbits_mm_test:ç e a b @@ -8,7 +8,7 @@ e N * bits * -block_size : com.microsoftmatmulnbits_testZ +block_size : com.microsoftmatmulnbits_mm_testZ a   diff --git a/test/onnx/matmulnbits_vm_test.onnx b/test/onnx/matmulnbits_vm_test.onnx new file mode 100644 index 00000000000..e526131b288 --- /dev/null +++ b/test/onnx/matmulnbits_vm_test.onnx @@ -0,0 +1,32 @@ + matmulnbits_vm_test:ß +e +a +b +scales +zpc" MatMulNBits* +K * +N * +bits * + +block_size : com.microsoftmatmulnbits_vm_testZ +a + + +Z +b + + + +Z +scales + + +Z +zp + + +b +c + + +B \ No newline at end of file diff --git a/test/onnx/verify/matmulnbits_test.cpp b/test/onnx/verify/matmulnbits_test.cpp index d1268ca0f5c..4fa4c22768c 100644 --- a/test/onnx/verify/matmulnbits_test.cpp +++ b/test/onnx/verify/matmulnbits_test.cpp @@ -30,9 +30,9 @@ #include #include -TEST_CASE(matmulnbits_test) +TEST_CASE(matmulnbits_mm_test) { - auto p = optimize_onnx("matmulnbits_test.onnx"); + auto p = optimize_onnx("matmulnbits_mm_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::parameter_map pm; @@ -65,9 +65,9 @@ TEST_CASE(matmulnbits_test) EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); } -TEST_CASE(matmulnbits2_test) +TEST_CASE(matmulnbits_mm2_test) { - auto p = optimize_onnx("matmulnbits2_test.onnx"); + auto p = optimize_onnx("matmulnbits_mm2_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::parameter_map pm; @@ -80,7 +80,7 @@ TEST_CASE(matmulnbits2_test) 0.81454f, 0.15699f, 0.15464f, 0.17399f, 0.08090f, 0.99368f, 0.45535f, 0.92528f, 0.91968f, 0.76970f, 0.59638f, 0.23635f, 0.54877f, 0.96025f, 0.48969f, 0.55297f, 0.52498f, 0.29102f, 0.01359f, 0.77372f, 0.81897f, 0.03003f, 0.00822f, 
0.55477f, 0.54635f, 0.91918f, 0.76486f,
-        0.73698f, 0.29821f, 0.41801};
+        0.73698f, 0.29821f, 0.41801f};
     pm["a"] = migraphx::argument(a_shape, a.data());
 
     auto b_shape = migraphx::shape{migraphx::shape::uint8_type, {2, 3, 8}};
@@ -94,10 +94,6 @@ TEST_CASE(matmulnbits2_test)
     std::vector<float> scales{0.29033, 0.80435, 2.60200, 2.39623, 1.40796, 2.38139};
     pm["scales"] = migraphx::argument(scales_shape, scales.data());
 
-    // auto zp_shape = migraphx::shape{migraphx::shape::uint8_type, {4}};
-    // std::vector<uint8_t> zp{0x08, 0x09, 0x0a, 0x0b};
-    // pm["zp"] = migraphx::argument{zp_shape, zp.data()};
-
     auto result = p.eval(pm).back();
     std::vector<float> result_vector;
     result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
@@ -105,4 +101,87 @@ TEST_CASE(matmulnbits2_test)
 
     EXPECT(result.get_shape().lens() == std::vector<std::size_t>{2, 2});
     EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
-}
\ No newline at end of file
+}
+
+TEST_CASE(matmulnbits_vm_test)
+{
+    auto p = optimize_onnx("matmulnbits_vm_test.onnx");
+    p.compile(migraphx::make_target("ref"));
+
+    migraphx::parameter_map pm;
+    auto a_shape = migraphx::shape{migraphx::shape::float_type, {20}};
+    std::vector<float> a{0.10266f, 0.12772f, 0.10865f, 0.66181f, 0.49644f, 0.30307f, 0.11225f,
+                         0.65619f, 0.06290f, 0.29208f, 0.63246f, 0.22758f, 0.99302f, 0.09735f,
+                         0.68126f, 0.93334f, 0.90533f, 0.31082f, 0.58161f, 0.61385f};
+    pm["a"] = migraphx::argument(a_shape, a.data());
+
+    auto b_shape = migraphx::shape{migraphx::shape::uint8_type, {3, 2, 8}};
+    std::vector<uint8_t> b{0xb7, 0x55, 0xfc, 0xc3, 0x66, 0xf9, 0x97, 0x83, 0xdd, 0x79, 0x0,  0x0,
+                           0x0,  0x0,  0x0,  0x0,  0xcb, 0x52, 0xaf, 0x1d, 0x85, 0xbb, 0x64, 0x60,
+                           0x23, 0x42, 0x0,  0x0,  0x0,  0x0,  0x0,  0x0,  0x38, 0xc6, 0xf7, 0x7a,
+                           0x68, 0xb1, 0x5,  0xc3, 0x37, 0xbb, 0x0,  0x0,  0x0,  0x0,  0x0,  0x0};
+    pm["b"] = migraphx::argument(b_shape, b.data());
+
+    auto scales_shape = migraphx::shape{migraphx::shape::float_type, {6}};
+    std::vector<float> scales{3.74611f, 0.29444f, 0.29047f, 0.55739f, 3.94635f, 2.86177f};
+    pm["scales"] = migraphx::argument(scales_shape, scales.data());
+
+    auto zp_shape = migraphx::shape{migraphx::shape::uint8_type, {3}};
+    std::vector<uint8_t> zp{0x43, 0x28, 0x65};
+    pm["zp"] = migraphx::argument{zp_shape, zp.data()};
+
+    auto result = p.eval(pm).back();
+    std::vector<float> result_vector;
+    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
+    std::vector<float> gold{131.22989f, -1.9659958f, 75.00621f};
+
+    EXPECT(result.get_shape().lens() == std::vector<std::size_t>{3});
+    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
+}
+
+TEST_CASE(matmulnbits_bmm_test)
+{
+    auto p = optimize_onnx("matmulnbits_bmm_test.onnx");
+    p.compile(migraphx::make_target("ref"));
+
+    migraphx::parameter_map pm;
+    auto a_shape = migraphx::shape{migraphx::shape::float_type, {2, 3, 8}};
+    std::vector<float> a{0.01602f, 0.41420f, 0.97385f, 0.31764f, 0.40434f, 0.46265f, 0.93490f,
+                         0.16076f, 0.62340f, 0.39614f, 0.45347f, 0.98619f, 0.65113f, 0.56039f,
+                         0.33137f, 0.51959f, 0.70136f, 0.73935f, 0.95997f, 0.25623f, 0.26716f,
+                         0.27764f, 0.52128f, 0.55242f, 0.31295f, 0.54679f, 0.43674f, 0.21178f,
+                         0.99311f, 0.86172f, 0.10848f, 0.34330f, 0.36977f, 0.00948f, 0.93841f,
+                         0.88137f, 0.31069f, 0.39034f, 0.22825f, 0.29626f, 0.22664f, 0.51612f,
+                         0.39870f, 0.73411f, 0.07540f, 0.36283f, 0.62662f, 0.49075f};
+    pm["a"] = migraphx::argument(a_shape, a.data());
+
+    auto b_shape = migraphx::shape{migraphx::shape::uint8_type, {2, 1, 8}};
+    std::vector<uint8_t> b{
+        0xed, 0xf8, 0xa0, 0xac, 0x0, 0x0, 0x0, 0x0, 0x34, 0xf7, 0x42, 0x1f, 0x0, 0x0, 0x0, 0x0};
+    pm["b"] = migraphx::argument(b_shape, b.data());
+
+    auto scales_shape = migraphx::shape{migraphx::shape::float_type, {2}};
+    std::vector<float> scales{1.43507, 1.28074};
+    pm["scales"] = migraphx::argument(scales_shape, scales.data());
+
+    auto result = p.eval(pm).back();
+    std::vector<float> result_vector;
+    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
+    std::vector<float> gold{
+        9.386047f,
+        0.32900935f,
+        15.317321f,
+        -7.0316725f,
+        16.28011f,
+        -11.014428f,
+        1.7608745f,
+        -17.91667f,
+        11.302611f,
+        -0.2521392f,
+        18.625961f,
+        0.38458022f,
+    };
+
+    EXPECT(result.get_shape().lens() == std::vector<std::size_t>{2, 3, 2});
+    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
+}

From ebf438ef624e0a9ec858b402b9b3f0634359b27e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dino=20Musi=C4=87?=
Date: Wed, 2 Oct 2024 10:51:42 +0000
Subject: [PATCH 4/7] Refactor matmulnbits parser, implement onnx parse tests

---
 src/onnx/parse_matmulnbits.cpp       |  44 +++---
 test/onnx/parse/matmulnbits_test.cpp | 176 +++++++++++++++++++++++++++
 2 files changed, 192 insertions(+), 28 deletions(-)
 create mode 100644 test/onnx/parse/matmulnbits_test.cpp

diff --git a/src/onnx/parse_matmulnbits.cpp b/src/onnx/parse_matmulnbits.cpp
index 29815580e3e..09447fae088 100644
--- a/src/onnx/parse_matmulnbits.cpp
+++ b/src/onnx/parse_matmulnbits.cpp
@@ -25,6 +25,7 @@
 #include "migraphx/errors.hpp"
 #include "migraphx/instruction_ref.hpp"
 #include "migraphx/onnx/onnx_parser.hpp"
+#include
 #include
 #include
 #include
@@ -131,8 +132,8 @@ struct parse_matmulnbits : op_parser<parse_matmulnbits>
 
     instruction_ref unpack(onnx_parser::node_info& info, instruction_ref x, int N, int dim1) const
     {
-        x = info.add_instruction(make_op("unpack_int4"), x);
         x = info.add_instruction(make_op("reshape", {{"dims", {N, -1}}}), x);
+        x = info.add_instruction(make_op("unpack_int4"), x);
         if(x->get_shape().lens()[1] > dim1)
         {
             x = info.add_instruction(
@@ -163,39 +164,26 @@ struct parse_matmulnbits : op_parser<parse_matmulnbits>
         return x;
     }
 
     instruction_ref matmul(onnx_parser::node_info& info, instruction_ref a, instruction_ref b) const
     {
-        bool is_a_prepended = false;
-        // B will always be a rank 2 matrix([N, K]), only need to check for A
-        if(a->get_shape().ndim() == 1)
+        const auto a_rank = a->get_shape().ndim();
+        // B is always rank 2:
+        // If A is rank 1, unsqueeze A to make it rank 2 to prepare for dot
+        // If A is rank 2, just a regular dot
+        // If A is rank > 2, broadcast B to match outer dims of A to prepare for dot
+        if(a_rank == 1)
         {
-            is_a_prepended = true;
-            a              = info.add_instruction(make_op("unsqueeze", {{"axes", {0}}}), a);
+            a = info.add_instruction(make_op("unsqueeze", {{"axes", {0}}}), a);
         }
-
-        auto a_lens = a->get_shape().lens();
-        auto b_lens = b->get_shape().lens();
-        if(not std::equal(a_lens.rbegin() + 2, a_lens.rend(), b_lens.rbegin() + 2, b_lens.rend()))
+        else if(a_rank > 2)
         {
-            auto a_it = a_lens.begin() + a_lens.size() - 2;
-            std::vector<size_t> a_broadcasted_lens(a_lens.begin(), a_it);
-            auto b_it = b_lens.begin() + b_lens.size() - 2;
-            std::vector<size_t> b_broadcasted_lens(b_lens.begin(), b_it);
-            auto output_lens = compute_broadcasted_lens(a_broadcasted_lens, b_broadcasted_lens);
-            a_broadcasted_lens = output_lens;
-            a_broadcasted_lens.insert(a_broadcasted_lens.end(), a_it, a_lens.end());
-            b_broadcasted_lens = output_lens;
-            b_broadcasted_lens.insert(b_broadcasted_lens.end(), b_it, b_lens.end());
-
-            if(a_lens != a_broadcasted_lens)
-                a = info.add_instruction(
-                    make_op("multibroadcast", {{"out_lens", a_broadcasted_lens}}), a);
-
-            if(b_lens != 
b_broadcasted_lens) - b = info.add_instruction( - make_op("multibroadcast", {{"out_lens", b_broadcasted_lens}}), b); + auto b_lens = b->get_shape().lens(); + auto b_bc_lens = a->get_shape().lens(); + std::copy(b_lens.begin(), b_lens.end(), b_bc_lens.end() - 2); + b = info.add_instruction(make_op("multibroadcast", {{"out_lens", b_bc_lens}}), b); } + auto dot = info.add_instruction(make_op("dot"), a, b); - if(is_a_prepended) + if(a_rank == 1) dot = info.add_instruction( make_op("squeeze", {{"axes", {dot->get_shape().ndim() - 2}}}), dot); diff --git a/test/onnx/parse/matmulnbits_test.cpp b/test/onnx/parse/matmulnbits_test.cpp new file mode 100644 index 00000000000..f4bed4f5868 --- /dev/null +++ b/test/onnx/parse/matmulnbits_test.cpp @@ -0,0 +1,176 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
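
The rank handling that patch 4's matmul() implements above, as a numpy sketch (B is always 2-D by this point, after dequantization and transpose):

    import numpy as np

    def matmul_like(a, b):                       # b: [K, N]
        if a.ndim == 1:
            return (a[None, :] @ b)[0]           # unsqueeze, dot, squeeze
        if a.ndim > 2:                           # broadcast b over a's batch dims
            b = np.broadcast_to(b, (*a.shape[:-2], *b.shape))
        return a @ b

    assert matmul_like(np.ones(4), np.ones((4, 2))).shape == (2,)
    assert matmul_like(np.ones((3, 4)), np.ones((4, 2))).shape == (3, 2)
    assert matmul_like(np.ones((5, 3, 4)), np.ones((4, 2))).shape == (5, 3, 2)
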
+ */ + +#include "migraphx/make_op.hpp" +#include + +TEST_CASE(matmulnbits_mm_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {2, 16}}); + auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::uint8_type, {4, 1, 8}}); + auto scales = mm->add_parameter("scales", migraphx::shape{migraphx::shape::float_type, {4}}); + auto zp = mm->add_parameter("zp", migraphx::shape{migraphx::shape::uint8_type, {4}}); + + scales = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4, -1}}}), scales); + scales = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), scales); + scales = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {4, 1, 16}}}), + scales); + scales = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4, -1}}}), scales); + + zp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4, -1}}}), zp); + zp = mm->add_instruction(migraphx::make_op("unpack_int4"), zp); + zp = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {1}}}), zp); + zp = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), zp); + zp = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {4, 1, 16}}}), zp); + zp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4, -1}}}), zp); + + b = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4, -1}}}), b); + b = mm->add_instruction(migraphx::make_op("unpack_int4"), b); + b = mm->add_instruction(migraphx::make_op("dequantizelinear"), b, scales, zp); + b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), b); + mm->add_instruction(migraphx::make_op("dot"), a, b); + + auto prog = optimize_onnx("matmulnbits_mm_test.onnx"); + + p.sort(); + prog.sort(); + EXPECT(p == prog); +} + +TEST_CASE(matmulnbits_mm2_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {2, 33}}); + auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::uint8_type, {2, 3, 8}}); + auto scales = mm->add_parameter("scales", migraphx::shape{migraphx::shape::float_type, {6}}); + + scales = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, -1}}}), scales); + scales = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), scales); + scales = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 16}}}), + scales); + scales = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, -1}}}), scales); + scales = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {33}}}), scales); + + auto zp = + mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::uint8_type, {1}}, {8}}); + zp = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 33}}}), zp); + + b = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, -1}}}), b); + b = mm->add_instruction(migraphx::make_op("unpack_int4"), b); + b = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {33}}}), b); + b = mm->add_instruction(migraphx::make_op("dequantizelinear"), b, scales, zp); + b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), b); + mm->add_instruction(migraphx::make_op("dot"), a, b); + + auto prog = optimize_onnx("matmulnbits_mm2_test.onnx"); + + p.sort(); + prog.sort(); + 
EXPECT(p == prog); +} + +TEST_CASE(matmulnbits_vm_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {20}}); + auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::uint8_type, {3, 2, 8}}); + auto scales = mm->add_parameter("scales", migraphx::shape{migraphx::shape::float_type, {6}}); + auto zp = mm->add_parameter("zp", migraphx::shape{migraphx::shape::uint8_type, {3}}); + + scales = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {3, -1}}}), scales); + scales = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), scales); + scales = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 2, 16}}}), + scales); + scales = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {3, -1}}}), scales); + scales = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {20}}}), scales); + + zp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {3, -1}}}), zp); + zp = mm->add_instruction(migraphx::make_op("unpack_int4"), zp); + zp = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), zp); + zp = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 2, 16}}}), zp); + zp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {3, -1}}}), zp); + zp = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {20}}}), zp); + + b = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {3, -1}}}), b); + b = mm->add_instruction(migraphx::make_op("unpack_int4"), b); + b = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {20}}}), b); + b = mm->add_instruction(migraphx::make_op("dequantizelinear"), b, scales, zp); + b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), b); + + a = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {0}}}), a); + auto dot = mm->add_instruction(migraphx::make_op("dot"), a, b); + mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), dot); + + auto prog = optimize_onnx("matmulnbits_vm_test.onnx"); + + p.sort(); + prog.sort(); + EXPECT(p == prog); +} + +TEST_CASE(matmulnbits_bmm_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto a = mm->add_parameter("a", migraphx::shape{migraphx::shape::float_type, {2, 3, 8}}); + auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::uint8_type, {2, 1, 8}}); + auto scales = mm->add_parameter("scales", migraphx::shape{migraphx::shape::float_type, {2}}); + + scales = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, -1}}}), scales); + scales = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), scales); + scales = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 1, 16}}}), + scales); + scales = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, -1}}}), scales); + scales = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {8}}}), scales); + + auto zp = + mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::uint8_type, {1}}, {8}}); + zp = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 8}}}), zp); + + b = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, -1}}}), b); + b = mm->add_instruction(migraphx::make_op("unpack_int4"), b); + b = mm->add_instruction( + migraphx::make_op("slice", 
{{"axes", {1}}, {"starts", {0}}, {"ends", {8}}}), b); + b = mm->add_instruction(migraphx::make_op("dequantizelinear"), b, scales, zp); + b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), b); + b = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 8, 2}}}), b); + mm->add_instruction(migraphx::make_op("dot"), a, b); + + auto prog = optimize_onnx("matmulnbits_bmm_test.onnx"); + + p.sort(); + prog.sort(); + EXPECT(p == prog); +} From cccb35a0254adccd533b1ed3530601628306a281 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dino=20Musi=C4=87?= Date: Wed, 2 Oct 2024 11:28:03 +0000 Subject: [PATCH 5/7] Implement negative parse tests for MatMulNBits --- test/onnx/gen_onnx.py | 58 +++++++++++++++++++ ...bits_block_size_not_power_of_two_test.onnx | 32 ++++++++++ ...matmulnbits_block_size_too_small_test.onnx | 32 ++++++++++ .../onnx/matmulnbits_invalid_b_dims_test.onnx | 32 ++++++++++ .../matmulnbits_invalid_bits_value_test.onnx | 32 ++++++++++ .../matmulnbits_invalid_scales_dims_test.onnx | 32 ++++++++++ .../matmulnbits_invalid_zp_dims_test.onnx | 32 ++++++++++ .../onnx/parse/matmulnbits_negative_tests.cpp | 55 ++++++++++++++++++ ...ulnbits_test.cpp => matmulnbits_tests.cpp} | 0 ...ulnbits_test.cpp => matmulnbits_tests.cpp} | 0 10 files changed, 305 insertions(+) create mode 100644 test/onnx/matmulnbits_block_size_not_power_of_two_test.onnx create mode 100644 test/onnx/matmulnbits_block_size_too_small_test.onnx create mode 100644 test/onnx/matmulnbits_invalid_b_dims_test.onnx create mode 100644 test/onnx/matmulnbits_invalid_bits_value_test.onnx create mode 100644 test/onnx/matmulnbits_invalid_scales_dims_test.onnx create mode 100644 test/onnx/matmulnbits_invalid_zp_dims_test.onnx create mode 100644 test/onnx/parse/matmulnbits_negative_tests.cpp rename test/onnx/parse/{matmulnbits_test.cpp => matmulnbits_tests.cpp} (100%) rename test/onnx/verify/{matmulnbits_test.cpp => matmulnbits_tests.cpp} (100%) diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py index e313475014f..10dfa847e1d 100644 --- a/test/onnx/gen_onnx.py +++ b/test/onnx/gen_onnx.py @@ -8652,6 +8652,64 @@ def matmulnbits_bmm_test(): return ([node], [a, b, scales], [c]) +def matmulnbits_negative_test(bits=4, + block_size=16, + a_dims=[2, 16], + b_dims=[4, 1, 8], + scales_dims=[4], + zp_dims=[4], + out_dims=[2, 4]): + a = onnx.helper.make_tensor_value_info("a", onnx.TensorProto.FLOAT, a_dims) + b = onnx.helper.make_tensor_value_info("b", onnx.TensorProto.UINT8, b_dims) + scales = onnx.helper.make_tensor_value_info("scales", + onnx.TensorProto.FLOAT, + scales_dims) + zp = onnx.helper.make_tensor_value_info("zp", onnx.TensorProto.UINT8, + zp_dims) + c = onnx.helper.make_tensor_value_info("c", onnx.TensorProto.FLOAT, + out_dims) + + node = onnx.helper.make_node("MatMulNBits", + inputs=["a", "b", "scales", "zp"], + outputs=["c"], + bits=bits, + block_size=block_size, + K=16, + N=4, + domain='com.microsoft') + return ([node], [a, b, scales, zp], [c]) + + +@onnx_test() +def matmulnbits_invalid_bits_value_test(): + return matmulnbits_negative_test(bits=5) + + +@onnx_test() +def matmulnbits_block_size_too_small_test(): + return matmulnbits_negative_test(block_size=8) + + +@onnx_test() +def matmulnbits_block_size_not_power_of_two_test(): + return matmulnbits_negative_test(block_size=20) + + +@onnx_test() +def matmulnbits_invalid_b_dims_test(): + return matmulnbits_negative_test(b_dims=[4, 2, 8]) + + +@onnx_test() +def matmulnbits_invalid_scales_dims_test(): + return 
matmulnbits_negative_test(scales_dims=[3]) + + +@onnx_test() +def matmulnbits_invalid_zp_dims_test(): + return matmulnbits_negative_test(zp_dims=[5]) + + @onnx_test() def qlinearmul_test(): a = helper.make_tensor_value_info('A', TensorProto.UINT8, [64]) diff --git a/test/onnx/matmulnbits_block_size_not_power_of_two_test.onnx b/test/onnx/matmulnbits_block_size_not_power_of_two_test.onnx new file mode 100644 index 00000000000..52dfbe82146 --- /dev/null +++ b/test/onnx/matmulnbits_block_size_not_power_of_two_test.onnx @@ -0,0 +1,32 @@ + ,matmulnbits_block_size_not_power_of_two_test:€ +e +a +b +scales +zpc" MatMulNBits* +K * +N * +bits * + +block_size : com.microsoft,matmulnbits_block_size_not_power_of_two_testZ +a +  + +Z +b + + + +Z +scales + + +Z +zp + + +b +c +  + +B \ No newline at end of file diff --git a/test/onnx/matmulnbits_block_size_too_small_test.onnx b/test/onnx/matmulnbits_block_size_too_small_test.onnx new file mode 100644 index 00000000000..e4fc07fb3de --- /dev/null +++ b/test/onnx/matmulnbits_block_size_too_small_test.onnx @@ -0,0 +1,32 @@ + %matmulnbits_block_size_too_small_test:ù +e +a +b +scales +zpc" MatMulNBits* +K * +N * +bits * + +block_size : com.microsoft%matmulnbits_block_size_too_small_testZ +a +  + +Z +b + + + +Z +scales + + +Z +zp + + +b +c +  + +B \ No newline at end of file diff --git a/test/onnx/matmulnbits_invalid_b_dims_test.onnx b/test/onnx/matmulnbits_invalid_b_dims_test.onnx new file mode 100644 index 00000000000..8e33dafc2e9 --- /dev/null +++ b/test/onnx/matmulnbits_invalid_b_dims_test.onnx @@ -0,0 +1,32 @@ + matmulnbits_invalid_b_dims_test:ó +e +a +b +scales +zpc" MatMulNBits* +K * +N * +bits * + +block_size : com.microsoftmatmulnbits_invalid_b_dims_testZ +a +  + +Z +b + + + +Z +scales + + +Z +zp + + +b +c +  + +B \ No newline at end of file diff --git a/test/onnx/matmulnbits_invalid_bits_value_test.onnx b/test/onnx/matmulnbits_invalid_bits_value_test.onnx new file mode 100644 index 00000000000..e084c8002a5 --- /dev/null +++ b/test/onnx/matmulnbits_invalid_bits_value_test.onnx @@ -0,0 +1,32 @@ + #matmulnbits_invalid_bits_value_test:÷ +e +a +b +scales +zpc" MatMulNBits* +K * +N * +bits * + +block_size : com.microsoft#matmulnbits_invalid_bits_value_testZ +a +  + +Z +b + + + +Z +scales + + +Z +zp + + +b +c +  + +B \ No newline at end of file diff --git a/test/onnx/matmulnbits_invalid_scales_dims_test.onnx b/test/onnx/matmulnbits_invalid_scales_dims_test.onnx new file mode 100644 index 00000000000..675e3cfc26f --- /dev/null +++ b/test/onnx/matmulnbits_invalid_scales_dims_test.onnx @@ -0,0 +1,32 @@ + $matmulnbits_invalid_scales_dims_test:ø +e +a +b +scales +zpc" MatMulNBits* +K * +N * +bits * + +block_size : com.microsoft$matmulnbits_invalid_scales_dims_testZ +a +  + +Z +b + + + +Z +scales + + +Z +zp + + +b +c +  + +B \ No newline at end of file diff --git a/test/onnx/matmulnbits_invalid_zp_dims_test.onnx b/test/onnx/matmulnbits_invalid_zp_dims_test.onnx new file mode 100644 index 00000000000..63242a15592 --- /dev/null +++ b/test/onnx/matmulnbits_invalid_zp_dims_test.onnx @@ -0,0 +1,32 @@ +  matmulnbits_invalid_zp_dims_test:ô +e +a +b +scales +zpc" MatMulNBits* +K * +N * +bits * + +block_size : com.microsoft matmulnbits_invalid_zp_dims_testZ +a +  + +Z +b + + + +Z +scales + + +Z +zp + + +b +c +  + +B \ No newline at end of file diff --git a/test/onnx/parse/matmulnbits_negative_tests.cpp b/test/onnx/parse/matmulnbits_negative_tests.cpp new file mode 100644 index 00000000000..e5e8ad44336 --- /dev/null +++ b/test/onnx/parse/matmulnbits_negative_tests.cpp @@ -0,0 
+1,55 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include + +TEST_CASE(matmulnbits_invalid_bits_value_test) +{ + EXPECT(test::throws([&] { read_onnx("matmulnbits_invalid_bits_value_test.onnx"); })); +} + +TEST_CASE(matmulnbits_block_size_too_small_test) +{ + EXPECT(test::throws([&] { read_onnx("matmulnbits_block_size_too_small_test.onnx"); })); +} + +TEST_CASE(matmulnbits_block_size_not_power_of_two_test) +{ + EXPECT(test::throws([&] { read_onnx("matmulnbits_block_size_not_power_of_two_test.onnx"); })); +} + +TEST_CASE(matmulnbits_invalid_b_dims_test) +{ + EXPECT(test::throws([&] { read_onnx("matmulnbits_invalid_b_dims_test.onnx"); })); +} + +TEST_CASE(matmulnbits_invalid_scales_dims_test) +{ + EXPECT(test::throws([&] { read_onnx("matmulnbits_invalid_scales_dims_test.onnx"); })); +} + +TEST_CASE(matmulnbits_invalid_zp_dims_test) +{ + EXPECT(test::throws([&] { read_onnx("matmulnbits_invalid_zp_dims_test.onnx"); })); +} diff --git a/test/onnx/parse/matmulnbits_test.cpp b/test/onnx/parse/matmulnbits_tests.cpp similarity index 100% rename from test/onnx/parse/matmulnbits_test.cpp rename to test/onnx/parse/matmulnbits_tests.cpp diff --git a/test/onnx/verify/matmulnbits_test.cpp b/test/onnx/verify/matmulnbits_tests.cpp similarity index 100% rename from test/onnx/verify/matmulnbits_test.cpp rename to test/onnx/verify/matmulnbits_tests.cpp From 879176b48a6cd6ec3f9956a14c2c490ed523254c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dino=20Musi=C4=87?= Date: Wed, 2 Oct 2024 12:08:57 +0000 Subject: [PATCH 6/7] Offload unpack_int4 to cpu during gpu lowering --- src/onnx/parse_matmulnbits.cpp | 5 ++--- src/targets/gpu/lowering.cpp | 21 +++++++++++++++++++++ 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/src/onnx/parse_matmulnbits.cpp b/src/onnx/parse_matmulnbits.cpp index 09447fae088..52c1a64d25c 100644 --- a/src/onnx/parse_matmulnbits.cpp +++ b/src/onnx/parse_matmulnbits.cpp @@ -54,9 +54,8 @@ struct parse_matmulnbits : op_parser MIGRAPHX_THROW("MatMulNBits: bits only supported for value of 4, actual value " + std::to_string(bits)); - if(block_size < 16 and (block_size & (block_size - 1)) != 0) - MIGRAPHX_THROW("MatMulNBits: block_size must be a power of 2 and greater or equal to " - "16, actual value " + + if(block_size < 16 or (block_size & (block_size - 1)) != 0) + MIGRAPHX_THROW("MatMulNBits: block_size must be a power of 2 and >=16, 
actual value " + std::to_string(block_size)); int n_blocks_per_col = (K + block_size - 1) / block_size; diff --git a/src/targets/gpu/lowering.cpp b/src/targets/gpu/lowering.cpp index 967a98ba936..2dbe32ccba2 100644 --- a/src/targets/gpu/lowering.cpp +++ b/src/targets/gpu/lowering.cpp @@ -122,6 +122,7 @@ struct miopen_apply add_select_module_op(); add_reshape_lazy_op(); add_scan_slice_op(); + add_unpack_int4_op(); } void copy_params() const @@ -509,6 +510,26 @@ struct miopen_apply ins, mod->insert_instruction(ins, ins->get_operator(), inputs)); }); } + + void add_unpack_int4_op() + { + apply_map.emplace("unpack_int4", [=](instruction_ref ins) { + auto inputs = ins->inputs(); + auto output = insert_allocation(ins, ins->get_shape()); + std::vector cpu_inputs; + auto gpu_inputs = ins->inputs(); + std::transform( + gpu_inputs.begin(), gpu_inputs.end(), std::back_inserter(cpu_inputs), [&](auto in) { + return mod->insert_instruction(ins, make_op("hip::copy_from_gpu"), in); + }); + cpu_inputs.front() = + mod->insert_instruction(ins, make_op("hip::sync_stream"), cpu_inputs); + auto cpu_out = mod->insert_instruction(ins, ins->get_operator(), cpu_inputs); + auto gpu_out = + mod->insert_instruction(ins, make_op("hip::copy_to_gpu"), cpu_out, output); + return mod->replace_instruction(ins, gpu_out); + }); + } }; void lowering::apply(module_pass_manager& mpm) const From ffc530467a11b95f90bdf4792e456f408c0857c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dino=20Musi=C4=87?= Date: Wed, 2 Oct 2024 20:14:21 +0000 Subject: [PATCH 7/7] Fix clang-tidy issues, change type for attribute variables --- src/onnx/parse_matmulnbits.cpp | 51 ++++++++++++++++------------------ 1 file changed, 24 insertions(+), 27 deletions(-) diff --git a/src/onnx/parse_matmulnbits.cpp b/src/onnx/parse_matmulnbits.cpp index 52c1a64d25c..af9f09790aa 100644 --- a/src/onnx/parse_matmulnbits.cpp +++ b/src/onnx/parse_matmulnbits.cpp @@ -21,7 +21,6 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. 
From ffc530467a11b95f90bdf4792e456f408c0857c5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Dino=20Musi=C4=87?=
Date: Wed, 2 Oct 2024 20:14:21 +0000
Subject: [PATCH 7/7] Fix clang-tidy issues, change type for attribute variables

---
 src/onnx/parse_matmulnbits.cpp | 51 ++++++++++++++++------------------
 1 file changed, 24 insertions(+), 27 deletions(-)

diff --git a/src/onnx/parse_matmulnbits.cpp b/src/onnx/parse_matmulnbits.cpp
index 52c1a64d25c..af9f09790aa 100644
--- a/src/onnx/parse_matmulnbits.cpp
+++ b/src/onnx/parse_matmulnbits.cpp
@@ -21,7 +21,6 @@
  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  * THE SOFTWARE.
  */
-#include "migraphx/common.hpp"
 #include "migraphx/errors.hpp"
 #include "migraphx/instruction_ref.hpp"
 #include "migraphx/onnx/onnx_parser.hpp"
@@ -45,10 +44,10 @@ struct parse_matmulnbits : op_parser<parse_matmulnbits>
                            onnx_parser::node_info info,
                            const std::vector<instruction_ref>& args) const
     {
-        const int N          = parse_attribute(parser, info, "N");
-        const int K          = parse_attribute(parser, info, "K");
-        const int bits       = parse_attribute(parser, info, "bits");
-        const int block_size = parse_attribute(parser, info, "block_size");
+        const size_t n          = parse_attribute(parser, info, "N");
+        const size_t k          = parse_attribute(parser, info, "K");
+        const size_t bits       = parse_attribute(parser, info, "bits");
+        const size_t block_size = parse_attribute(parser, info, "block_size");
 
         if(bits != 4)
             MIGRAPHX_THROW("MatMulNBits: bits only supported for value of 4, actual value " +
@@ -58,18 +57,16 @@ struct parse_matmulnbits : op_parser<parse_matmulnbits>
             MIGRAPHX_THROW("MatMulNBits: block_size must be a power of 2 and >=16, actual value " +
                            std::to_string(block_size));
 
-        int n_blocks_per_col = (K + block_size - 1) / block_size;
-        int blob_size        = std::ceil(block_size * bits / 8.0f);
+        const size_t n_blocks_per_col = (k + block_size - 1) / block_size;
+        const size_t blob_size        = std::ceil(block_size * bits / 8.0f);
 
-        std::vector<size_t> expected_b_lens{static_cast<size_t>(N),
-                                            static_cast<size_t>(n_blocks_per_col),
-                                            static_cast<size_t>(blob_size)};
+        std::vector<size_t> expected_b_lens{n, n_blocks_per_col, blob_size};
         if(args[1]->get_shape().lens() != expected_b_lens)
             MIGRAPHX_THROW("MatMulNBits: Input B does not match expected dims: " +
                            to_string_range(expected_b_lens) +
                            ". Actual dims: " + to_string_range(args[1]->get_shape().lens()));
 
-        std::vector<size_t> expected_scales_lens{static_cast<size_t>(N * n_blocks_per_col)};
+        std::vector<size_t> expected_scales_lens{n * n_blocks_per_col};
         if(args[2]->get_shape().lens() != expected_scales_lens)
             MIGRAPHX_THROW("MatMulNBits: Input scales does not match expected dims: " +
                            to_string_range(expected_scales_lens) +
@@ -78,14 +75,14 @@ struct parse_matmulnbits : op_parser<parse_matmulnbits>
         if(args.size() > 3)
         {
             std::vector<size_t> expected_zp_lens{
-                static_cast<size_t>(N * std::ceil(n_blocks_per_col * bits / 8.0f))};
+                static_cast<size_t>(n * std::ceil(n_blocks_per_col * bits / 8.0f))};
             if(args[3]->get_shape().lens() != expected_zp_lens)
                 MIGRAPHX_THROW("MatMulNBits: Input zero_points does not match expected dims: " +
                                to_string_range(expected_zp_lens) +
                                ". Actual dims: " + to_string_range(args[3]->get_shape().lens()));
         }
 
-        auto b = dequantize_b(info, N, K, block_size, args);
+        auto b = dequantize_b(info, n, k, block_size, args);
         b      = info.add_instruction(make_op("transpose", {{"permutation", {1, 0}}}), b);
         return matmul(info, args[0], b);
     }
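Taken together, parse() reshapes and unpacks B to [n, k], expands scales (and zero points, when given) blockwise to that same layout, dequantizes, transposes, and multiplies. In plain form, per element of the dequantized weight matrix:

    B_dq[i][j] = (B[i][j] - zp[i][j / block_size]) * scales[i][j / block_size]
    output     = dot(A, transpose(B_dq))

which is exactly the transpose-and-matmul sequence above plus the dequantizelinear built in dequantize_b below.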
Actual dims: " + to_string_range(args[3]->get_shape().lens())); } - auto b = dequantize_b(info, N, K, block_size, args); + auto b = dequantize_b(info, n, k, block_size, args); b = info.add_instruction(make_op("transpose", {{"permutation", {1, 0}}}), b); return matmul(info, args[0], b); } @@ -103,22 +100,22 @@ struct parse_matmulnbits : op_parser } instruction_ref dequantize_b(onnx_parser::node_info& info, - int N, - int K, + int n, + int k, int block_size, const std::vector& args) const { - auto b = unpack(info, args[1], N, K); + auto b = unpack(info, n, k, args[1]); - auto n_blocks_per_col = (K + block_size - 1) / block_size; - auto scales = info.add_instruction(make_op("reshape", {{"dims", {N, -1}}}), args[2]); - scales = prepare_blockwise_dq_arg(info, scales, N, K, block_size); + auto n_blocks_per_col = (k + block_size - 1) / block_size; + auto scales = info.add_instruction(make_op("reshape", {{"dims", {n, -1}}}), args[2]); + scales = prepare_blockwise_dq_arg(info, n, k, block_size, scales); instruction_ref zp; if(args.size() == 4) { - zp = unpack(info, args[3], N, n_blocks_per_col); - zp = prepare_blockwise_dq_arg(info, zp, N, K, block_size); + zp = unpack(info, n, n_blocks_per_col, args[3]); + zp = prepare_blockwise_dq_arg(info, n, k, block_size, zp); } else { @@ -129,9 +126,9 @@ struct parse_matmulnbits : op_parser return info.add_instruction(make_op("dequantizelinear"), {b, scales, zp}); } - instruction_ref unpack(onnx_parser::node_info& info, instruction_ref x, int N, int dim1) const + instruction_ref unpack(onnx_parser::node_info& info, int n, int dim1, instruction_ref x) const { - x = info.add_instruction(make_op("reshape", {{"dims", {N, -1}}}), x); + x = info.add_instruction(make_op("reshape", {{"dims", {n, -1}}}), x); x = info.add_instruction(make_op("unpack_int4"), x); if(x->get_shape().lens()[1] > dim1) { @@ -142,20 +139,20 @@ struct parse_matmulnbits : op_parser } instruction_ref prepare_blockwise_dq_arg( - onnx_parser::node_info& info, instruction_ref x, int N, int K, int block_size) const + onnx_parser::node_info& info, int n, int k, int block_size, instruction_ref x) const { x = info.add_instruction(make_op("unsqueeze", {{"axes", {2}}}), x); auto bc_lens = x->get_shape().lens(); bc_lens[2] = block_size; x = info.add_instruction(make_op("multibroadcast", {{"out_lens", bc_lens}}), x); - x = info.add_instruction(make_op("reshape", {{"dims", {N, -1}}}), x); + x = info.add_instruction(make_op("reshape", {{"dims", {n, -1}}}), x); // Detect runt block - if(x->get_shape().lens()[1] > K) + if(x->get_shape().lens()[1] > k) { x = info.add_instruction( - make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {K}}}), x); + make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {k}}}), x); } return x;