diff --git a/src/simplify_algebra.cpp b/src/simplify_algebra.cpp
index e79d4d48628..655ed0d7a0c 100644
--- a/src/simplify_algebra.cpp
+++ b/src/simplify_algebra.cpp
@@ -911,7 +911,7 @@ struct find_concat_op
     static bool is_valid_op(const operation& op)
     {
         return contains({"broadcast", "multibroadcast", "unpack_int4"}, op.name()) or
-               op.attributes().contains("pointwise");
+               (op.attributes().contains("pointwise") and op.name() != "quantizelinear");
     }
 
     static bool is_valid_concat(std::vector ins, size_t axis)
diff --git a/src/simplify_qdq.cpp b/src/simplify_qdq.cpp
index bd21564b618..e347cfd58db 100644
--- a/src/simplify_qdq.cpp
+++ b/src/simplify_qdq.cpp
@@ -1,7 +1,7 @@
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -34,6 +34,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -348,6 +349,76 @@ struct match_qlinear_reused
     }
 };
 
+// Rewrites quantizelinear(concat(x0, x1, ...)) into
+// concat(quantizelinear(x0), quantizelinear(x1), ...) when the concat is used
+// only once and at least one of its inputs is a pointwise instruction. Each new
+// quantizelinear gets the matching slice of every quantization parameter.
+struct match_concat_qlinear
+{
+    auto matcher() const
+    {
+        auto any_pointwise_input = match::any_of[match::inputs()](match::pointwise());
+        return match::name("quantizelinear")(match::arg(0)(
+            match::name("concat")(match::used_once(), any_pointwise_input).bind("cat")));
+    }
+
+    // Builds one set of slice attributes per input of cat_ins, in input order,
+    // covering that input's extent along the concat axis.
+    auto get_slices(instruction_ref cat_ins) const
+    {
+        std::vector<std::vector<std::pair<std::string, value>>> slices;
+        auto axis = any_cast<op::concat>(cat_ins->get_operator()).axis;
+        // A negative axis would index lens() out of range below, so count it from
+        // the back first. NOTE(review): confirm the axis can still be negative here.
+        if(axis < 0)
+            axis += static_cast<decltype(axis)>(cat_ins->get_shape().lens().size());
+        size_t start = 0;
+        for(auto cat_inp : cat_ins->inputs())
+        {
+            auto end = start + cat_inp->get_shape().lens()[axis];
+            slices.push_back({{"axes", {axis}}, {"starts", {start}}, {"ends", {end}}});
+            start = end;
+        }
+        return slices;
+    }
+
+    // Replaces the matched quantizelinear with a concat of per-input
+    // quantizelinear instructions. Every input after the data input is sliced
+    // and forwarded, so the rewrite works with or without a zero point instead
+    // of reading inputs()[2] unconditionally.
+    void apply(module& m, const match::matcher_result& r) const
+    {
+        auto ins     = r.result;
+        auto cat_ins = r.instructions["cat"];
+
+        // The matcher requires arg(0), so there is a data input to skip.
+        assert(not ins->inputs().empty());
+        const std::vector<instruction_ref> qparams(std::next(ins->inputs().begin()),
+                                                   ins->inputs().end());
+
+        auto slices = get_slices(cat_ins);
+        std::vector<instruction_ref> new_cat_inputs;
+        std::transform(
+            cat_ins->inputs().begin(),
+            cat_ins->inputs().end(),
+            slices.begin(),
+            std::back_inserter(new_cat_inputs),
+            [&](auto i, const auto& slc) {
+                std::vector<instruction_ref> q_inputs{i};
+                std::transform(
+                    qparams.begin(),
+                    qparams.end(),
+                    std::back_inserter(q_inputs),
+                    [&](auto qparam) {
+                        return m.insert_instruction(ins, make_op("slice", slc), {qparam});
+                    });
+                return m.insert_instruction(ins, ins->get_operator(), q_inputs);
+            });
+
+        m.replace_instruction(ins, cat_ins->get_operator(), new_cat_inputs);
+    }
+};
+
 bool is_same_value(instruction_ref a, instruction_ref b)
 {
     if(a == b)
@@ -456,6 +527,8 @@ void simplify_qdq::apply(module& m) const
     migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
     match::find_matches(m, match_qlinear_reused{});
     migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
+    match::find_matches(m, match_concat_qlinear{});
+    migraphx::run_passes(m, {migraphx::dead_code_elimination{}});
     remove_zero_point(m);
 }
 
diff --git a/test/simplify_qdq_test.cpp b/test/simplify_qdq_test.cpp
index cef500fbfd3..c077ff5f815 100644
--- a/test/simplify_qdq_test.cpp
+++ b/test/simplify_qdq_test.cpp
@@ -1,7 +1,7 @@
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -1539,4 +1539,107 @@ TEST_CASE(int4_simplify_qdq_pass_test)
     EXPECT(migraphx::contains(res_1.instructions, "q"));
 }
 
+// A quantizelinear with a scalar scale, fed by a concat that has a pointwise
+// (relu) input, is expected to become a concat of per-input quantizelinear ops.
+TEST_CASE(pointwise_concat_quant_per_tensor)
+{
+    migraphx::shape s1{migraphx::shape::float_type, {1, 4, 28, 28}};
+    migraphx::shape s2{migraphx::shape::float_type, {1, 2, 28, 28}};
+
+    migraphx::module m1;
+    {
+        auto i1    = m1.add_parameter("i1", s1);
+        auto i2    = m1.add_parameter("i2", s2);
+        auto scale = m1.add_literal(0.5f);
+        auto zero  = m1.add_literal(std::int8_t{0});
+
+        auto relu = m1.add_instruction(migraphx::make_op("relu"), i2);
+        auto cat  = m1.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), i1, relu);
+        auto q    = add_quantize_op(m1, "quantizelinear", cat, scale, zero);
+        m1.add_return({q});
+    }
+
+    migraphx::module m2;
+    {
+        std::vector cat_lens{1, 6, 28, 28};
+        auto i1    = m2.add_parameter("i1", s1);
+        auto i2    = m2.add_parameter("i2", s2);
+        auto scale = m2.add_literal(0.5f);
+        auto zero  = m2.add_literal(std::int8_t{0});
+
+        auto relu     = m2.add_instruction(migraphx::make_op("relu"), i2);
+        auto scale_mb = broadcast_scale(m2, scale, cat_lens, 1);
+        auto zero_mb  = broadcast_shift(m2, zero, cat_lens);
+
+        auto sc1 = m2.add_instruction(
+            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {4}}}), scale_mb);
+        auto zp1 = m2.add_instruction(
+            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {4}}}), zero_mb);
+        auto q1 = add_quantize_op(m2, "quantizelinear", i1, sc1, zp1);
+
+        auto sc2 = m2.add_instruction(
+            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {4}}, {"ends", {6}}}), scale_mb);
+        auto zp2 = m2.add_instruction(
+            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {4}}, {"ends", {6}}}), zero_mb);
+        auto q2 = add_quantize_op(m2, "quantizelinear", relu, sc2, zp2);
+
+        auto cat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), q1, q2);
+        m2.add_return({cat});
+    }
+    run_pass(m1);
+    EXPECT(m1 == m2);
+}
+
+// Same rewrite with a six-element scale literal: the expected graph slices the
+// broadcast scale into channels [0, 4) and [4, 6) along axis 1.
+TEST_CASE(pointwise_concat_quant_per_channel)
+{
+    migraphx::shape s1{migraphx::shape::float_type, {1, 4, 28, 28}};
+    migraphx::shape s2{migraphx::shape::float_type, {1, 2, 28, 28}};
+    migraphx::shape s3{migraphx::shape::float_type, {6}};
+
+    migraphx::module m1;
+    {
+        auto i1    = m1.add_parameter("i1", s1);
+        auto i2    = m1.add_parameter("i2", s2);
+        auto scale = m1.add_literal(migraphx::generate_literal(s3, 0));
+        auto zero  = m1.add_literal(std::int8_t{0});
+
+        auto relu = m1.add_instruction(migraphx::make_op("relu"), i2);
+        auto cat  = m1.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), i1, relu);
+        auto q    = add_quantize_op(m1, "quantizelinear", cat, scale, zero);
+        m1.add_return({q});
+    }
+
+    migraphx::module m2;
+    {
+        std::vector cat_lens{1, 6, 28, 28};
+        auto i1    = m2.add_parameter("i1", s1);
+        auto i2    = m2.add_parameter("i2", s2);
+        auto scale = m2.add_literal(migraphx::generate_literal(s3, 0));
+        auto zero  = m2.add_literal(std::int8_t{0});
+
+        auto relu     = m2.add_instruction(migraphx::make_op("relu"), i2);
+        auto scale_mb = broadcast_scale(m2, scale, cat_lens, 1);
+        auto zero_mb  = broadcast_shift(m2, zero, cat_lens);
+
+        auto sc1 = m2.add_instruction(
+            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {4}}}), scale_mb);
+        auto zp1 = m2.add_instruction(
+            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {4}}}), zero_mb);
+        auto q1 = add_quantize_op(m2, "quantizelinear", i1, sc1, zp1);
+
+        auto sc2 = m2.add_instruction(
+            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {4}}, {"ends", {6}}}), scale_mb);
+        auto zp2 = m2.add_instruction(
+            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {4}}, {"ends", {6}}}), zero_mb);
+        auto q2 = add_quantize_op(m2, "quantizelinear", relu, sc2, zp2);
+
+        auto cat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), q1, q2);
+        m2.add_return({cat});
+    }
+    run_pass(m1);
+    EXPECT(m1 == m2);
+}
+
 int main(int argc, const char* argv[]) { test::run(argc, argv); }