Skip to content

Commit

Permalink
use type of A to compute, rather than hardcode fp16
Browse files Browse the repository at this point in the history
  • Loading branch information
bopeng1234 committed Feb 17, 2025
1 parent 6695dde commit 44b2e0f
Showing 1 changed file with 5 additions and 7 deletions.
12 changes: 5 additions & 7 deletions src/frontends/onnx/frontend/src/op/com.microsoft/matmulnbits.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,12 +169,12 @@ ov::OutputVector matmulnbits(const ov::frontend::onnx::Node& node) {
// use fp16 for compute

// convert b to fp16
auto converted_b = std::make_shared<v0::Convert>(casted_b, ov::element::f16);
auto converted_zero_points = std::make_shared<v0::Convert>(zero_points, ov::element::f16);
auto converted_b = std::make_shared<v0::Convert>(casted_b, a.get_element_type());
auto converted_zero_points = std::make_shared<v0::Convert>(zero_points, a.get_element_type());

// sub and scale
const auto sub_b = std::make_shared<v1::Subtract>(converted_b, converted_zero_points);
const auto scales_fp16 = std::make_shared<v0::Convert>(scales, ov::element::f16);
const auto scales_fp16 = std::make_shared<v0::Convert>(scales, a.get_element_type());
const auto scales_reshaped =
op::util::reshape(scales_fp16, ov::Shape{static_cast<size_t>(N), static_cast<size_t>(n_blocks_per_col), 1});
const auto scaled_b = std::make_shared<v1::Multiply>(sub_b, scales_reshaped);
Expand All @@ -198,9 +198,7 @@ ov::OutputVector matmulnbits(const ov::frontend::onnx::Node& node) {
}

// mm = matmul(a,b)
auto a_fp16 = std::make_shared<v0::Convert>(a, ov::element::f16);
auto results = std::make_shared<v0::MatMul>(a_fp16, b, false, true);
mm_output = std::make_shared<v0::Convert>(results, a.get_element_type());
mm_output = std::make_shared<v0::MatMul>(a, b, false, true);
}

if (bias.get_node_shared_ptr()) {
Expand All @@ -216,4 +214,4 @@ ONNX_OP("MatMulNBits", OPSET_SINCE(1), com_microsoft::opset_1::matmulnbits, MICR
} // namespace com_microsoft
} // namespace onnx
} // namespace frontend
} // namespace ov
} // namespace ov

0 comments on commit 44b2e0f

Please sign in to comment.