From 44b2e0fe7bc5c63cc018917aeaffe4a60cf7df8a Mon Sep 17 00:00:00 2001 From: "Peng, Bo" Date: Mon, 17 Feb 2025 15:25:35 +0800 Subject: [PATCH] use type of A to compute, rather than hardcode fp16 --- .../frontend/src/op/com.microsoft/matmulnbits.cpp | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/frontends/onnx/frontend/src/op/com.microsoft/matmulnbits.cpp b/src/frontends/onnx/frontend/src/op/com.microsoft/matmulnbits.cpp index 8f6cf5d4a9c853..4b401bbd2f9037 100644 --- a/src/frontends/onnx/frontend/src/op/com.microsoft/matmulnbits.cpp +++ b/src/frontends/onnx/frontend/src/op/com.microsoft/matmulnbits.cpp @@ -169,12 +169,12 @@ ov::OutputVector matmulnbits(const ov::frontend::onnx::Node& node) { // use fp16 for compute // convert b to fp16 - auto converted_b = std::make_shared(casted_b, ov::element::f16); - auto converted_zero_points = std::make_shared(zero_points, ov::element::f16); + auto converted_b = std::make_shared(casted_b, a.get_element_type()); + auto converted_zero_points = std::make_shared(zero_points, a.get_element_type()); // sub and scale const auto sub_b = std::make_shared(converted_b, converted_zero_points); - const auto scales_fp16 = std::make_shared(scales, ov::element::f16); + const auto scales_fp16 = std::make_shared(scales, a.get_element_type()); const auto scales_reshaped = op::util::reshape(scales_fp16, ov::Shape{static_cast(N), static_cast(n_blocks_per_col), 1}); const auto scaled_b = std::make_shared(sub_b, scales_reshaped); @@ -198,9 +198,7 @@ ov::OutputVector matmulnbits(const ov::frontend::onnx::Node& node) { } // mm = matmul(a,b) - auto a_fp16 = std::make_shared(a, ov::element::f16); - auto results = std::make_shared(a_fp16, b, false, true); - mm_output = std::make_shared(results, a.get_element_type()); + mm_output = std::make_shared(a, b, false, true); } if (bias.get_node_shared_ptr()) { @@ -216,4 +214,4 @@ ONNX_OP("MatMulNBits", OPSET_SINCE(1), com_microsoft::opset_1::matmulnbits, MICR } // namespace com_microsoft } // namespace onnx } // namespace frontend -} // namespace ov +} // namespace ov \ No newline at end of file