diff --git a/src/utility/include/mp-units/math.h b/src/utility/include/mp-units/math.h index 1a6aaf25b..67d7abb83 100644 --- a/src/utility/include/mp-units/math.h +++ b/src/utility/include/mp-units/math.h @@ -136,6 +136,37 @@ template return {static_cast(abs(q.numerical_value_ref_in(q.unit))), R}; } +/** + * @brief Computes the fma of 3 quantities + * + * @param a: Multiplicand + * @param x: Multiplicand + * @param b: Addend + * @return Quantity: The nearest floating point representable to ax+b + */ +template +[[nodiscard]] constexpr QuantityOf auto +fma(const quantity& a, const quantity& x, const quantity& b) noexcept + requires requires { common_quantity_spec(get_quantity_spec(R) * get_quantity_spec(S), get_quantity_spec(T)); } && + (get_unit(R) * get_unit(S) == get_unit(T)) && + ( + requires { + fma(a.numerical_value_ref_in(a.unit), x.numerical_value_ref_in(x.unit), + b.numerical_value_ref_in(b.unit)); + } || + requires { + std::fma(a.numerical_value_ref_in(a.unit), x.numerical_value_ref_in(x.unit), + b.numerical_value_ref_in(b.unit)); + }) +{ + using std::fma; + return quantity{ + fma(a.numerical_value_ref_in(a.unit), x.numerical_value_ref_in(x.unit), b.numerical_value_ref_in(b.unit)), + common_reference(R * S, T)}; +} + + /** * @brief Returns the epsilon of the quantity * diff --git a/test/unit_test/runtime/math_test.cpp b/test/unit_test/runtime/math_test.cpp index e468225f4..714f71e34 100644 --- a/test/unit_test/runtime/math_test.cpp +++ b/test/unit_test/runtime/math_test.cpp @@ -62,6 +62,17 @@ TEST_CASE("'cbrt()' on quantity changes the value and the dimension accordingly" REQUIRE(cbrt(8 * isq::volume[m3]) == 2 * isq::length[m]); } +TEST_CASE("'fma()' on quantity changes the value and the dimension accordingly", "[math][fma]") +{ + REQUIRE(fma(1.0 * isq::length[m], 2.0 * one, 2.0 * isq::length[m]) == 4.0 * isq::length[m]); +} + +TEST_CASE("'fma()' returns a common reference.", "[math][fma]") +{ + REQUIRE(fma(isq::speed(10.0 * m / s), isq::time(2.0 * s), isq::height(42.0 * m)) == isq::length(62.0 * m)); +} + + TEST_CASE("'pow()' on quantity changes the value and the dimension accordingly", "[math][pow]") { REQUIRE(pow<1, 4>(16 * isq::area[m2]) == sqrt(4 * isq::length[m])); diff --git a/test/unit_test/static/math_test.cpp b/test/unit_test/static/math_test.cpp index 980f68e08..4874aabbc 100644 --- a/test/unit_test/static/math_test.cpp +++ b/test/unit_test/static/math_test.cpp @@ -40,6 +40,10 @@ template #if __cpp_lib_constexpr_cmath || MP_UNITS_COMP_GCC +static_assert(compare(fma(2.0 * s, 3.0 * Hz, 1.0 * one), 7.0 * one)); +static_assert(compare(fma(2.0 * one, 3.0 * m, 1.0 * m), 7.0 * m)); +static_assert(compare(fma(2.0 * m, 3.0 * one, 1.0 * m), 7.0 * m)); +static_assert(compare(fma(2 * m, 3.0f * m, 1.0 * m2), 7.0 * m2)); static_assert(compare(pow<0>(2 * m), 1 * one)); static_assert(compare(pow<1>(2 * m), 2 * m)); static_assert(compare(pow<2>(2 * m), 4 * pow<2>(m), 4 * m2));