Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: added barrett_reduction implementation into uintx #6768

Merged
merged 5 commits into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions barretenberg/cpp/src/barretenberg/benchmark/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ add_subdirectory(indexed_tree_bench)
add_subdirectory(append_only_tree_bench)
add_subdirectory(ultra_bench)
add_subdirectory(stdlib_hash)
add_subdirectory(circuit_construction_bench)
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Declares the circuit_construction_bench module through the project's barretenberg_module() helper.
# NOTE(review): stdlib_primitives is presumably the module this benchmark depends on / links against —
# confirm against the definition of barretenberg_module().
barretenberg_module(circuit_construction_bench stdlib_primitives)
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@

#include <benchmark/benchmark.h>

#include "barretenberg/stdlib/primitives/biggroup/biggroup.hpp"
#include "barretenberg/stdlib/primitives/curves/bn254.hpp"
#include "barretenberg/stdlib_circuit_builders/ultra_circuit_builder.hpp"

using namespace benchmark;
using namespace bb;

namespace {

auto& engine = numeric::get_debug_randomness();
/**
 * @brief Benchmark the circuit-construction cost of `element_ct::batch_mul` over bn254 in an UltraCircuitBuilder.
 *
 * @details Each iteration builds a fresh circuit builder, samples `state.range(0)` random native points and
 * scalars, and turns them into circuit witnesses with the timer paused. Only the `batch_mul` call is timed.
 *
 * The timer is deliberately left running at the end of each iteration: google-benchmark requires the timer to
 * be running both when `PauseTiming()` is called and when the `for (auto _ : state)` loop ends. Pausing again
 * at the bottom of the loop would make the next `PauseTiming()` (and the library's own end-of-loop pause) act on
 * an already-stopped timer.
 *
 * @param state google-benchmark state; `state.range(0)` is the number of points in the multi-scalar multiplication.
 */
void biggroup_construction_bench(State& state)
{
    using Curve = stdlib::bn254<UltraCircuitBuilder>;
    using affine_element = Curve::AffineElementNative;
    using element_ct = Curve::Element;
    using scalar_ct = Curve::ScalarField;
    for (auto _ : state) {
        // --- untimed setup -------------------------------------------------------------------------------------
        state.PauseTiming();

        UltraCircuitBuilder builder;
        const size_t num_points = static_cast<size_t>(state.range(0));

        // Native inputs
        std::vector<affine_element> points;
        std::vector<fr> scalars;
        points.reserve(num_points);
        scalars.reserve(num_points);
        for (size_t i = 0; i < num_points; ++i) {
            points.push_back(affine_element(Curve::ElementNative::random_element()));
            scalars.push_back(fr::random_element());
        }

        // Circuit witnesses for the native inputs
        std::vector<element_ct> circuit_points;
        std::vector<scalar_ct> circuit_scalars;
        circuit_points.reserve(num_points);
        circuit_scalars.reserve(num_points);
        for (size_t i = 0; i < num_points; ++i) {
            circuit_points.push_back(element_ct::from_witness(&builder, points[i]));
            circuit_scalars.push_back(scalar_ct::from_witness(&builder, scalars[i]));
        }

        // --- timed region --------------------------------------------------------------------------------------
        state.ResumeTiming();
        element_ct::batch_mul(circuit_points, circuit_scalars);
        // No PauseTiming() here: the timer must still be running when the iteration ends.
    }
}
} // namespace
BENCHMARK(biggroup_construction_bench)->Unit(kMicrosecond)->DenseRange(2, 20);

BENCHMARK_MAIN();
2 changes: 2 additions & 0 deletions barretenberg/cpp/src/barretenberg/numeric/uintx/uintx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,9 @@ template <class base_uint> class uintx {
base_uint lo;
base_uint hi;

// Division by the compile-time constant `modulus` via Barrett reduction; returns the pair (quotient, remainder).
template <base_uint modulus> constexpr std::pair<uintx, uintx> barrett_reduction() const;
// Division with remainder by `b`, returning (quotient, remainder). Uses barrett_reduction when `b` equals one of
// the base-field moduli hard-coded in the implementation (BN254, secp256k1, secp256r1), otherwise divmod_base.
constexpr std::pair<uintx, uintx> divmod(const uintx& b) const;
// Division with remainder by `b` without the Barrett fast path, returning (quotient, remainder). Asserts b != 0.
constexpr std::pair<uintx, uintx> divmod_base(const uintx& b) const;
};

template <typename B, typename Params> inline void read(B& it, uintx<Params>& value)
Expand Down
33 changes: 33 additions & 0 deletions barretenberg/cpp/src/barretenberg/numeric/uintx/uintx.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,39 @@ namespace {
auto& engine = numeric::get_debug_randomness();
} // namespace

// Checks that barrett_reduction agrees with divmod_base on a random 512-bit input, using the same four limbs
// that uintx_impl.hpp names BN254FQMODULUS256 as the modulus.
TEST(uintx, BarrettReduction512)
{
    constexpr uint256_t modulus(
        0x3C208C16D87CFD47UL, 0x97816a916871ca8dUL, 0xb85045b68181585dUL, 0x30644e72e131a029UL);

    const uint512_t input = engine.get_random_uint512();

    // (quotient, remainder) from the Barrett path and from the generic division
    const auto barrett = input.barrett_reduction<modulus>();
    const auto reference = input.divmod_base(uint512_t(modulus));

    EXPECT_EQ(barrett.first, reference.first);
    EXPECT_EQ(barrett.second, reference.second);
}

// Checks that barrett_reduction agrees with divmod_base on a random 1024-bit input. The 512-bit modulus is the
// square of the 256-bit value that uintx_impl.hpp names BN254FQMODULUS256.
TEST(uintx, BarrettReduction1024)
{
    constexpr uint256_t modulus_partial(
        0x3C208C16D87CFD47UL, 0x97816a916871ca8dUL, 0xb85045b68181585dUL, 0x30644e72e131a029UL);
    constexpr uint512_t modulus_partial_wide(modulus_partial);
    constexpr uint512_t modulus = modulus_partial_wide * modulus_partial_wide;

    const uint1024_t input = engine.get_random_uint1024();

    // (quotient, remainder) from the Barrett path and from the generic division
    const auto barrett = input.barrett_reduction<modulus>();
    const auto reference = input.divmod_base(uint1024_t(modulus));

    EXPECT_EQ(barrett.first, reference.first);
    EXPECT_EQ(barrett.second, reference.second);
}

TEST(uintx, GetBit)
{
constexpr uint256_t lo{ 0b0110011001110010011001100111001001100110011100100110011001110011,
Expand Down
81 changes: 80 additions & 1 deletion barretenberg/cpp/src/barretenberg/numeric/uintx/uintx_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

namespace bb::numeric {
template <class base_uint>
constexpr std::pair<uintx<base_uint>, uintx<base_uint>> uintx<base_uint>::divmod(const uintx& b) const
constexpr std::pair<uintx<base_uint>, uintx<base_uint>> uintx<base_uint>::divmod_base(const uintx& b) const
{
ASSERT(b != 0);
if (*this == 0) {
Expand Down Expand Up @@ -336,4 +336,83 @@ template <class base_uint> constexpr uintx<base_uint> uintx<base_uint>::operator
}
return result;
}

/**
 * @brief Division with remainder, with a fast path for a fixed set of known divisors.
 * @details If `b` equals one of the three 256-bit moduli listed below, the division is carried out by the
 * matching `barrett_reduction` instantiation. Every other divisor is handled by `divmod_base`.
 *
 * @param b the divisor
 * @return constexpr std::pair<uintx<base_uint>, uintx<base_uint>> the pair returned by the selected routine
 */
template <class base_uint>
constexpr std::pair<uintx<base_uint>, uintx<base_uint>> uintx<base_uint>::divmod(const uintx& b) const
{
    // Divisors that have a dedicated Barrett instantiation
    constexpr uint256_t bn254_fq_modulus(
        0x3C208C16D87CFD47UL, 0x97816a916871ca8dUL, 0xb85045b68181585dUL, 0x30644e72e131a029UL);
    constexpr uint256_t secp256k1_fq_modulus(
        0xFFFFFFFEFFFFFC2FULL, 0xFFFFFFFFFFFFFFFFULL, 0xFFFFFFFFFFFFFFFFULL, 0xFFFFFFFFFFFFFFFFULL);
    constexpr uint256_t secp256r1_fq_modulus(
        0xFFFFFFFFFFFFFFFFULL, 0x00000000FFFFFFFFULL, 0x0000000000000000ULL, 0xFFFFFFFF00000001ULL);

    if (b == uintx(bn254_fq_modulus)) {
        return this->template barrett_reduction<bn254_fq_modulus>();
    } else if (b == uintx(secp256k1_fq_modulus)) {
        return this->template barrett_reduction<secp256k1_fq_modulus>();
    } else if (b == uintx(secp256r1_fq_modulus)) {
        return this->template barrett_reduction<secp256r1_fq_modulus>();
    }

    // No precomputed parameter for this divisor: use the generic division
    return divmod_base(b);
}

/**
 * @brief Compute fast division via a barrett reduction
 *        Evaluates x = qm + r where m = modulus. returns q, r
 * @details This implementation is less efficient due to making no assumptions about the value of *self.
 *          When using this method to perform modular reductions e.g. (*self) mod m, if (*self) < m^2 a lot of the
 *          `uintx` operations in this method could be replaced with `base_uint` operations
 *
 *          Accuracy of the quotient estimate: with mu = floor(2^{2k} / m), the estimate
 *          q' = floor(x * mu / 2^{2k}) never exceeds floor(x / m) and undershoots it by less than x / 2^{2k} + 1.
 *          Because no bound on x is assumed, x / 2^{2k} is not bounded by 1 here (k is one less than the bit
 *          length of `base_uint`), so q' may be short by more than one multiple of m. The final correction is
 *          therefore a loop rather than a single conditional subtraction.
 *
 *          `modulus` must be nonzero.
 *
 * @tparam base_uint
 * @tparam modulus
 * @return constexpr std::pair<uintx<base_uint>, uintx<base_uint>> the pair (quotient, remainder)
 */
template <class base_uint>
template <base_uint modulus>
constexpr std::pair<uintx<base_uint>, uintx<base_uint>> uintx<base_uint>::barrett_reduction() const
{
    // N.B. k could be modulus.get_msb() + 1 if we have strong bounds on the max value of (*self)
    //      (a smaller k would allow us to fit `redc_parameter` into `base_uint` and not `uintx`)
    constexpr size_t k = base_uint::length() - 1;
    // N.B. computation of redc_parameter requires division operation - if this cannot be precomputed (or amortized over
    // multiple reductions over the same modulus), barrett_reduction is much slower than divmod
    constexpr uintx redc_parameter = ((uintx(1) << (k * 2)).divmod_base(uintx(modulus))).first;

    const auto x = *this;

    // compute x * redc_parameter
    const auto mul_result = x.mul_extended(redc_parameter);
    constexpr size_t shift = 2 * k;

    // compute (x * redc_parameter) >> 2k
    // This is equivalent to (x * (2^{2k} / modulus) / 2^{2k})
    // which approximates to x / modulus
    // The low `shift` bits of the high product word are moved to the top of the result; the low product word
    // supplies the remaining bits.
    const uintx downshifted_hi_bits = mul_result.second & ((uintx(1) << shift) - 1);
    const uintx mul_hi_underflow = uintx(downshifted_hi_bits) << (length() - shift);
    uintx quotient = (mul_result.first >> shift) | mul_hi_underflow;

    // compute remainder by determining value of x - quotient * modulus
    // Only the low word of quotient * modulus is needed: the true remainder is non-negative and no larger than x.
    uintx qm_lo(0);
    {
        const auto lolo = quotient.lo.mul_extended(modulus);
        const auto lohi = quotient.hi.mul_extended(modulus);
        base_uint t0 = lolo.first;
        base_uint t1 = lolo.second;
        t1 = t1 + lohi.first;
        qm_lo = uintx(t0, t1);
    }
    uintx remainder = x - qm_lo;

    // because redc_parameter is an imperfect representation of 2^{2k} / n (might be too small),
    // the computed quotient may be too small. The shortfall is below x / 2^{2k} + 1, which can exceed 1 for
    // large x, so keep correcting until the remainder is fully reduced (a small constant number of iterations).
    while (remainder >= uintx(modulus)) {
        remainder = remainder - modulus;
        quotient = quotient + 1;
    }
    return std::make_pair(quotient, remainder);
}
} // namespace bb::numeric
Loading