feat(avm): get_row optimization - 25x faster logderiv inv (#11605)
Proving times (VM1) on 16 cores, 850+ columns, dozens of lookups, bulk_test.

```
** Before **
prove/all_ms: 92606
prove/execute_log_derivative_inverse_round_ms: 21544

** After **
prove/all_ms: 73404
prove/execute_log_derivative_inverse_round_ms: 839
```

No change in sumcheck time.

For reviewing, you can focus on the templates. An explanation follows (with history).

---

This PR is about the `get_row()` method on the prover polynomials of a given flavor. This method is used by the [logderivative library](https://github.com/AztecProtocol/aztec-packages/blob/master/barretenberg/cpp/src/barretenberg/honk/proof_system/logderivative_library.hpp#L36) to compute logderivative inverses.

Originally, `get_row()` was supposed to be debug-only, but it ended up being used in the library above. The reason is as follows: the `accumulate` function of relations (including lookups and perms) takes in a row (or something that looks like one!). However, by the time you have to compute inverses, you no longer have your row-based trace; you only have the prover polynomials, which are column-based. So you need to extract a row from columns.

The following sections explore a way to make things run faster without completely breaking what the caller expects from `get_row()`: that the result behaves like a row (you can do `.column` and it returns the field for that column).

# Phase 1: `AllEntities<FF>`

So far so good. Normal [BB flavors](https://github.com/AztecProtocol/aztec-packages/blob/master/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/mega_flavor.hpp#L366) make `get_row()` return `AllEntities<FF>` which is literally a row with as many fields copied as columns you have. Note that the copy is done even for the columns that may not get used later in the accumulation of a relation, or in the computation of inverses.

This might be ok if you have 10 columns and a handful of lookups, but in our case we have dozens of lookups and 850+ columns (we estimate 3500 by completion of the AVM).
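To make the cost concrete, here is a minimal sketch of the copy-everything approach. It is not the BB code: `FF` is a plain integer stand-in for the field type, and `Columns`, `RowCopy`, `get_row_by_copy` and `selector_is_set` are hypothetical names used only for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Stand-in for a field element (assumption: the real FF is a prime-field type).
using FF = std::uint64_t;

// Column-major storage, like prover polynomials: one vector per column.
struct Columns {
    std::vector<FF> sel;
    std::vector<FF> a;
    std::vector<FF> b;
};

// Row-major snapshot: one copied field per column.
struct RowCopy {
    FF sel;
    FF a;
    FF b;
};

// Copies every column at row_idx, whether or not the caller reads it.
inline RowCopy get_row_by_copy(const Columns& cols, std::size_t row_idx)
{
    assert(row_idx < cols.sel.size());
    return RowCopy{ cols.sel[row_idx], cols.a[row_idx], cols.b[row_idx] };
}

// A lookup-style predicate that only needs one column of the row.
inline bool selector_is_set(const RowCopy& row)
{
    return row.sel == 1;
}
```

With three columns the waste is negligible. With 850+ columns, each call copies hundreds of fields so that a lookup can read a handful of them.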

# Phase 2: something like `AllEntities<const FF&>`

As a quick fix, you might think you can copy references instead and use `AllEntities<const FF&>`. Well, you can't, at least not the way you would use `AllEntities<FF>`. Since the class would have members that are references, you need to define a constructor that initializes them all, maybe from a `RefArray` of sorts. The problem is that the class `AllEntities` is defined as inheriting from other classes, instead of being "flat".

This, for us, added an immense amount of codegen. See `AllConstRefValues` [here](https://github.com/AztecProtocol/aztec-packages/blob/2f05dc02fe7b147c7cd6fc235134279dbf332c08/barretenberg/cpp/src/barretenberg/vm/avm/generated/flavor.cpp).

This improvement was introduced in [this PR](AztecProtocol/aztec-packages#7419) and it gave a **20x** speed improvement over `AllEntities<FF>`.

The code itself was then improved in [this PR](AztecProtocol/aztec-packages#11504) by using a flat class and some fold expressions.
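A minimal sketch of the "flat class of references" idea, under simplifying assumptions: `FF` is a plain integer stand-in, and `Columns`, `RowRefs` and `get_row_by_ref` are hypothetical names. The pack expansion mirrors the shape of the `get_standard_row` lambda in the diff for this commit, but this is an illustration, not the generated code.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Stand-in for a field element (assumption, not the real FF type).
using FF = std::uint64_t;

// Column-major storage: one vector per column.
struct Columns {
    std::vector<FF> sel;
    std::vector<FF> a;
    std::vector<FF> b;
};

// Flat aggregate whose members are references into the columns.
// Because it is flat (no base classes), brace initialization can bind
// every reference in one go, without a hand-written constructor.
struct RowRefs {
    const FF& sel;
    const FF& a;
    const FF& b;
};

inline RowRefs get_row_by_ref(const Columns& cols, std::size_t row_idx)
{
    assert(row_idx < cols.sel.size());
    // The pack expansion binds one reference per column, in declaration order.
    return [row_idx](const auto&... column) -> RowRefs {
        return { column[row_idx]... };
    }(cols.sel, cols.a, cols.b);
}
```

No field is copied here, but one reference per column is still set up on every call, including for columns the caller never reads.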

# Phase 3: Getters

Ideally what we'd want is for `get_row()` to return something like this:
```
    template <typename Polynomials> class PolynomialEntitiesAtFixedRow {
      public:
        PolynomialEntitiesAtFixedRow(const size_t row_idx, const Polynomials& pp)
            : row_idx(row_idx)
            , pp(pp)
        {}
        // what here?

      private:
        const size_t row_idx;
        const Polynomials& pp;
    };
```
such that if you do `row.column` it would secretly do `pp.column[row_idx]` instead. Unfortunately, C++ does not let you overload the `.` operator, and certainly not like this.

Instead, we compromise. I added a macro to generate getters `_column()` for every column, which do exactly that. Then I changed the lookups and permutation codegen to use that (i.e., `in._column()` instead of `in.column`). Note that we _only_ use these getters in lookups and perm, not in the main relations.

However, we are not done. The perms and lookups code that we changed is also called by `accumulate` when doing sumcheck, and `AllEntities` does not provide those getters so it will not compile. Well, we add them, and we are done.

This results in a **25x** speedup in calculating logderiv inverses over Phase 2. Combined with the **20x** gained in Phase 2, that amounts to roughly **500x** over the `AllEntities<FF>` baseline.
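The getter compromise can be sketched as follows. This is a simplified illustration, not the generated flavor code: `FF` is a plain integer stand-in, the macros are invoked once per column instead of through `FOR_EACH`, and `ValueRow`, `Polys`, `RowProxy` and the two column names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Stand-in for a field element (assumption, not the real FF type).
using FF = std::uint64_t;

// Getter for a class that stores the value itself (row-like entities).
#define DEFAULT_GETTER(ENTITY) const auto& _##ENTITY() const { return ENTITY; }
// Getter for a proxy that indexes into column-based polynomials on demand.
#define ROW_PROXY_GETTER(ENTITY) const auto& _##ENTITY() const { return pp.ENTITY[row_idx]; }

// Row that owns its fields, as used when accumulating during sumcheck.
struct ValueRow {
    FF sel_op{};
    FF sel_table{};
    DEFAULT_GETTER(sel_op)
    DEFAULT_GETTER(sel_table)
};

// Column-major storage: one vector per column.
struct Polys {
    std::vector<FF> sel_op;
    std::vector<FF> sel_table;
};

// Lightweight "row": stores only an index and a reference to the columns.
class RowProxy {
  public:
    RowProxy(std::size_t row_idx, const Polys& pp)
        : row_idx(row_idx)
        , pp(pp)
    {
        assert(row_idx < pp.sel_op.size());
    }
    ROW_PROXY_GETTER(sel_op)
    ROW_PROXY_GETTER(sel_table)

  private:
    const std::size_t row_idx;
    const Polys& pp;
};

// Lookup-style code written against the getters compiles for both row types.
template <typename Row> bool inverse_is_computed_at_row(const Row& in)
{
    return in._sel_op() == 1 || in._sel_table() == 1;
}
```

Building a `RowProxy` costs the same no matter how many columns exist, and a column is only touched when its getter is called. That is why the same templated lookup code can serve both the inverse computation (through the proxy) and sumcheck (through the value-holding row).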

# Conclusion

Some things in BB were not designed with a VM in mind :) I wonder if there's any such improvement lurking in sumcheck? :)
fcarreiro authored and AztecBot committed Jan 30, 2025
1 parent 145e7da commit 2b718de
Showing 29 changed files with 1,409 additions and 1,336 deletions.
4 changes: 2 additions & 2 deletions cpp/src/barretenberg/vm/avm/generated/circuit_builder.cpp
@@ -860,7 +860,7 @@ bool AvmCircuitBuilder::check_circuit() const

std::array<bool, result.size()> subrelation_failed = { false };
for (size_t r = 0; r < num_rows; ++r) {
Relation::accumulate(result, polys.get_row(r), {}, 1);
Relation::accumulate(result, polys.get_standard_row(r), {}, 1);
for (size_t j = 0; j < result.size(); ++j) {
if (!subrelation_failed[j] && result[j] != 0) {
signal_error(format("Relation ",
@@ -891,7 +891,7 @@ bool AvmCircuitBuilder::check_circuit() const
r = 0;
}
for (size_t r = 0; r < num_rows; ++r) {
Relation::accumulate(lookup_result, polys.get_row(r), params, 1);
Relation::accumulate(lookup_result, polys.get_standard_row(r), params, 1);
}
for (auto r : lookup_result) {
if (r != 0) {
1 change: 1 addition & 0 deletions cpp/src/barretenberg/vm/avm/generated/columns.hpp
@@ -16,6 +16,7 @@ namespace bb::avm {
#define AVM_TO_BE_SHIFTED(e) e.binary_acc_ia, e.binary_acc_ib, e.binary_acc_ic, e.binary_mem_tag_ctr, e.binary_op_id, e.cmp_a_hi, e.cmp_a_lo, e.cmp_b_hi, e.cmp_b_lo, e.cmp_cmp_rng_ctr, e.cmp_op_gt, e.cmp_p_sub_a_hi, e.cmp_p_sub_a_lo, e.cmp_p_sub_b_hi, e.cmp_p_sub_b_lo, e.cmp_sel_rng_chk, e.main_da_gas_remaining, e.main_l2_gas_remaining, e.main_pc, e.main_sel_execution_end, e.main_sel_execution_row, e.mem_glob_addr, e.mem_rw, e.mem_sel_mem, e.mem_tag, e.mem_tsp, e.mem_val, e.merkle_tree_leaf_index, e.merkle_tree_leaf_value, e.merkle_tree_path_len, e.poseidon2_full_a_0, e.poseidon2_full_a_1, e.poseidon2_full_a_2, e.poseidon2_full_a_3, e.poseidon2_full_execute_poseidon_perm, e.poseidon2_full_input_0, e.poseidon2_full_input_1, e.poseidon2_full_input_2, e.poseidon2_full_num_perm_rounds_rem, e.poseidon2_full_sel_poseidon, e.poseidon2_full_start_poseidon, e.slice_addr, e.slice_clk, e.slice_cnt, e.slice_sel_cd_cpy, e.slice_sel_mem_active, e.slice_sel_return, e.slice_sel_start, e.slice_space_id
#define AVM_ALL_ENTITIES AVM_PRECOMPUTED_ENTITIES, AVM_WIRE_ENTITIES, AVM_DERIVED_WITNESS_ENTITIES, AVM_SHIFTED_ENTITIES
#define AVM_UNSHIFTED_ENTITIES AVM_PRECOMPUTED_ENTITIES, AVM_WIRE_ENTITIES, AVM_DERIVED_WITNESS_ENTITIES
#define AVM_WITNESS_ENTITIES AVM_WIRE_ENTITIES, AVM_DERIVED_WITNESS_ENTITIES

#define AVM_TO_BE_SHIFTED_COLUMNS Column::binary_acc_ia, Column::binary_acc_ib, Column::binary_acc_ic, Column::binary_mem_tag_ctr, Column::binary_op_id, Column::cmp_a_hi, Column::cmp_a_lo, Column::cmp_b_hi, Column::cmp_b_lo, Column::cmp_cmp_rng_ctr, Column::cmp_op_gt, Column::cmp_p_sub_a_hi, Column::cmp_p_sub_a_lo, Column::cmp_p_sub_b_hi, Column::cmp_p_sub_b_lo, Column::cmp_sel_rng_chk, Column::main_da_gas_remaining, Column::main_l2_gas_remaining, Column::main_pc, Column::main_sel_execution_end, Column::main_sel_execution_row, Column::mem_glob_addr, Column::mem_rw, Column::mem_sel_mem, Column::mem_tag, Column::mem_tsp, Column::mem_val, Column::merkle_tree_leaf_index, Column::merkle_tree_leaf_value, Column::merkle_tree_path_len, Column::poseidon2_full_a_0, Column::poseidon2_full_a_1, Column::poseidon2_full_a_2, Column::poseidon2_full_a_3, Column::poseidon2_full_execute_poseidon_perm, Column::poseidon2_full_input_0, Column::poseidon2_full_input_1, Column::poseidon2_full_input_2, Column::poseidon2_full_num_perm_rounds_rem, Column::poseidon2_full_sel_poseidon, Column::poseidon2_full_start_poseidon, Column::slice_addr, Column::slice_clk, Column::slice_cnt, Column::slice_sel_cd_cpy, Column::slice_sel_mem_active, Column::slice_sel_return, Column::slice_sel_start, Column::slice_space_id
#define AVM_SHIFTED_COLUMNS ColumnAndShifts::binary_acc_ia_shift, ColumnAndShifts::binary_acc_ib_shift, ColumnAndShifts::binary_acc_ic_shift, ColumnAndShifts::binary_mem_tag_ctr_shift, ColumnAndShifts::binary_op_id_shift, ColumnAndShifts::cmp_a_hi_shift, ColumnAndShifts::cmp_a_lo_shift, ColumnAndShifts::cmp_b_hi_shift, ColumnAndShifts::cmp_b_lo_shift, ColumnAndShifts::cmp_cmp_rng_ctr_shift, ColumnAndShifts::cmp_op_gt_shift, ColumnAndShifts::cmp_p_sub_a_hi_shift, ColumnAndShifts::cmp_p_sub_a_lo_shift, ColumnAndShifts::cmp_p_sub_b_hi_shift, ColumnAndShifts::cmp_p_sub_b_lo_shift, ColumnAndShifts::cmp_sel_rng_chk_shift, ColumnAndShifts::main_da_gas_remaining_shift, ColumnAndShifts::main_l2_gas_remaining_shift, ColumnAndShifts::main_pc_shift, ColumnAndShifts::main_sel_execution_end_shift, ColumnAndShifts::main_sel_execution_row_shift, ColumnAndShifts::mem_glob_addr_shift, ColumnAndShifts::mem_rw_shift, ColumnAndShifts::mem_sel_mem_shift, ColumnAndShifts::mem_tag_shift, ColumnAndShifts::mem_tsp_shift, ColumnAndShifts::mem_val_shift, ColumnAndShifts::merkle_tree_leaf_index_shift, ColumnAndShifts::merkle_tree_leaf_value_shift, ColumnAndShifts::merkle_tree_path_len_shift, ColumnAndShifts::poseidon2_full_a_0_shift, ColumnAndShifts::poseidon2_full_a_1_shift, ColumnAndShifts::poseidon2_full_a_2_shift, ColumnAndShifts::poseidon2_full_a_3_shift, ColumnAndShifts::poseidon2_full_execute_poseidon_perm_shift, ColumnAndShifts::poseidon2_full_input_0_shift, ColumnAndShifts::poseidon2_full_input_1_shift, ColumnAndShifts::poseidon2_full_input_2_shift, ColumnAndShifts::poseidon2_full_num_perm_rounds_rem_shift, ColumnAndShifts::poseidon2_full_sel_poseidon_shift, ColumnAndShifts::poseidon2_full_start_poseidon_shift, ColumnAndShifts::slice_addr_shift, ColumnAndShifts::slice_clk_shift, ColumnAndShifts::slice_cnt_shift, ColumnAndShifts::slice_sel_cd_cpy_shift, ColumnAndShifts::slice_sel_mem_active_shift, ColumnAndShifts::slice_sel_return_shift, ColumnAndShifts::slice_sel_start_shift, 
ColumnAndShifts::slice_space_id_shift
47 changes: 37 additions & 10 deletions cpp/src/barretenberg/vm/avm/generated/flavor.hpp
@@ -13,6 +13,7 @@
#include "barretenberg/transcript/transcript.hpp"

#include "barretenberg/vm/aztec_constants.hpp"
#include "barretenberg/vm2/common/macros.hpp"
#include "columns.hpp"
#include "flavor_settings.hpp"

@@ -52,6 +53,19 @@
// Metaprogramming to concatenate tuple types.
template <typename... input_t> using tuple_cat_t = decltype(std::tuple_cat(std::declval<input_t>()...));

// clang-format off
// These getters are used to speedup logderivative inverses.
// See https://github.com/AztecProtocol/aztec-packages/pull/11605/ for a full explanation.
#define DEFAULT_GETTERS(ENTITY) \
inline auto& _##ENTITY() { return ENTITY; } \
inline auto& _##ENTITY() const { return ENTITY; }
#define ROW_PROXY_GETTERS(ENTITY) \
inline auto& _##ENTITY() { return pp.ENTITY[row_idx]; } \
inline auto& _##ENTITY() const { return pp.ENTITY[row_idx]; }
#define DEFINE_GETTERS(GETTER_MACRO, ENTITIES) \
FOR_EACH(GETTER_MACRO, ENTITIES)
// clang-format on

namespace bb::avm {

class AvmFlavor {
@@ -210,32 +224,29 @@ class AvmFlavor {
"AVM circuit. In this case, modify AVM_VERIFICATION_LENGTH_IN_FIELDS \n"
"in constants.nr accordingly.");

template <typename DataType_> class PrecomputedEntities : public PrecomputedEntitiesBase {
template <typename DataType> class PrecomputedEntities : public PrecomputedEntitiesBase {
public:
using DataType = DataType_;

DEFINE_FLAVOR_MEMBERS(DataType, AVM_PRECOMPUTED_ENTITIES)

RefVector<DataType> get_selectors() { return get_all(); }
RefVector<DataType> get_sigma_polynomials() { return {}; }
RefVector<DataType> get_id_polynomials() { return {}; }
RefVector<DataType> get_table_polynomials() { return {}; }
DEFINE_GETTERS(DEFAULT_GETTERS, AVM_PRECOMPUTED_ENTITIES)
};

private:
template <typename DataType> class WireEntities {
public:
DEFINE_FLAVOR_MEMBERS(DataType, AVM_WIRE_ENTITIES)
DEFINE_GETTERS(DEFAULT_GETTERS, AVM_WIRE_ENTITIES)
};

template <typename DataType> class DerivedWitnessEntities {
public:
DEFINE_FLAVOR_MEMBERS(DataType, AVM_DERIVED_WITNESS_ENTITIES)
DEFINE_GETTERS(DEFAULT_GETTERS, AVM_DERIVED_WITNESS_ENTITIES)
};

template <typename DataType> class ShiftedEntities {
public:
DEFINE_FLAVOR_MEMBERS(DataType, AVM_SHIFTED_ENTITIES)
DEFINE_GETTERS(DEFAULT_GETTERS, AVM_SHIFTED_ENTITIES)
};

template <typename DataType, typename PrecomputedAndWitnessEntitiesSuperset>
@@ -341,12 +352,26 @@ class AvmFlavor {
using Base::Base;
};

// Only used by VM1 check_circuit. Remove.
class AllConstRefValues {
public:
using BaseDataType = const FF;
using DataType = BaseDataType&;

DEFINE_FLAVOR_MEMBERS(DataType, AVM_ALL_ENTITIES)
DEFINE_GETTERS(DEFAULT_GETTERS, AVM_ALL_ENTITIES)
};

template <typename Polynomials> class PolynomialEntitiesAtFixedRow {
public:
PolynomialEntitiesAtFixedRow(const size_t row_idx, const Polynomials& pp)
: row_idx(row_idx)
, pp(pp)
{}
DEFINE_GETTERS(ROW_PROXY_GETTERS, AVM_ALL_ENTITIES)

private:
const size_t row_idx;
const Polynomials& pp;
};

/**
@@ -365,12 +390,14 @@ class AvmFlavor {
ProverPolynomials(ProvingKey& proving_key);

size_t get_polynomial_size() const { return main_kernel_inputs.size(); }
AllConstRefValues get_row(size_t row_idx) const
// This is only used in VM1 check_circuit. Remove.
AllConstRefValues get_standard_row(size_t row_idx) const
{
return [row_idx](auto&... entities) -> AllConstRefValues {
return { entities[row_idx]... };
}(AVM_ALL_ENTITIES);
}
auto get_row(size_t row_idx) const { return PolynomialEntitiesAtFixedRow<ProverPolynomials>(row_idx, *this); }
};

class PartiallyEvaluatedMultivariates : public AllEntities<Polynomial> {
44 changes: 22 additions & 22 deletions cpp/src/barretenberg/vm/avm/generated/relations/lookups_alu.hpp
@@ -34,15 +34,15 @@ class lookup_pow_2_0_lookup_settings {

template <typename AllEntities> static inline auto inverse_polynomial_is_computed_at_row(const AllEntities& in)
{
return (in.alu_sel_shift_which == 1 || in.main_sel_rng_8 == 1);
return (in._alu_sel_shift_which() == 1 || in._main_sel_rng_8() == 1);
}

template <typename Accumulator, typename AllEntities>
static inline auto compute_inverse_exists(const AllEntities& in)
{
using View = typename Accumulator::View;
const auto is_operation = View(in.alu_sel_shift_which);
const auto is_table_entry = View(in.main_sel_rng_8);
const auto is_operation = View(in._alu_sel_shift_which());
const auto is_table_entry = View(in._main_sel_rng_8());
return (is_operation + is_table_entry - is_operation * is_table_entry);
}

@@ -58,14 +58,14 @@ class lookup_pow_2_0_lookup_settings {

template <typename AllEntities> static inline auto get_entities(AllEntities&& in)
{
return std::forward_as_tuple(in.lookup_pow_2_0_inv,
in.lookup_pow_2_0_counts,
in.alu_sel_shift_which,
in.main_sel_rng_8,
in.alu_ib,
in.alu_b_pow,
in.main_clk,
in.powers_power_of_2);
return std::forward_as_tuple(in._lookup_pow_2_0_inv(),
in._lookup_pow_2_0_counts(),
in._alu_sel_shift_which(),
in._main_sel_rng_8(),
in._alu_ib(),
in._alu_b_pow(),
in._main_clk(),
in._powers_power_of_2());
}
};

@@ -101,15 +101,15 @@ class lookup_pow_2_1_lookup_settings {

template <typename AllEntities> static inline auto inverse_polynomial_is_computed_at_row(const AllEntities& in)
{
return (in.alu_sel_shift_which == 1 || in.main_sel_rng_8 == 1);
return (in._alu_sel_shift_which() == 1 || in._main_sel_rng_8() == 1);
}

template <typename Accumulator, typename AllEntities>
static inline auto compute_inverse_exists(const AllEntities& in)
{
using View = typename Accumulator::View;
const auto is_operation = View(in.alu_sel_shift_which);
const auto is_table_entry = View(in.main_sel_rng_8);
const auto is_operation = View(in._alu_sel_shift_which());
const auto is_table_entry = View(in._main_sel_rng_8());
return (is_operation + is_table_entry - is_operation * is_table_entry);
}

@@ -125,14 +125,14 @@ class lookup_pow_2_1_lookup_settings {

template <typename AllEntities> static inline auto get_entities(AllEntities&& in)
{
return std::forward_as_tuple(in.lookup_pow_2_1_inv,
in.lookup_pow_2_1_counts,
in.alu_sel_shift_which,
in.main_sel_rng_8,
in.alu_max_bits_sub_b_bits,
in.alu_max_bits_sub_b_pow,
in.main_clk,
in.powers_power_of_2);
return std::forward_as_tuple(in._lookup_pow_2_1_inv(),
in._lookup_pow_2_1_counts(),
in._alu_sel_shift_which(),
in._main_sel_rng_8(),
in._alu_max_bits_sub_b_bits(),
in._alu_max_bits_sub_b_pow(),
in._main_clk(),
in._powers_power_of_2());
}
};

52 changes: 26 additions & 26 deletions cpp/src/barretenberg/vm/avm/generated/relations/lookups_binary.hpp
@@ -35,15 +35,15 @@ class lookup_byte_lengths_lookup_settings {

template <typename AllEntities> static inline auto inverse_polynomial_is_computed_at_row(const AllEntities& in)
{
return (in.binary_start == 1 || in.byte_lookup_sel_bin == 1);
return (in._binary_start() == 1 || in._byte_lookup_sel_bin() == 1);
}

template <typename Accumulator, typename AllEntities>
static inline auto compute_inverse_exists(const AllEntities& in)
{
using View = typename Accumulator::View;
const auto is_operation = View(in.binary_start);
const auto is_table_entry = View(in.byte_lookup_sel_bin);
const auto is_operation = View(in._binary_start());
const auto is_table_entry = View(in._byte_lookup_sel_bin());
return (is_operation + is_table_entry - is_operation * is_table_entry);
}

@@ -59,14 +59,14 @@ class lookup_byte_lengths_lookup_settings {

template <typename AllEntities> static inline auto get_entities(AllEntities&& in)
{
return std::forward_as_tuple(in.lookup_byte_lengths_inv,
in.lookup_byte_lengths_counts,
in.binary_start,
in.byte_lookup_sel_bin,
in.binary_in_tag,
in.binary_mem_tag_ctr,
in.byte_lookup_table_in_tags,
in.byte_lookup_table_byte_lengths);
return std::forward_as_tuple(in._lookup_byte_lengths_inv(),
in._lookup_byte_lengths_counts(),
in._binary_start(),
in._byte_lookup_sel_bin(),
in._binary_in_tag(),
in._binary_mem_tag_ctr(),
in._byte_lookup_table_in_tags(),
in._byte_lookup_table_byte_lengths());
}
};

@@ -105,15 +105,15 @@ class lookup_byte_operations_lookup_settings {

template <typename AllEntities> static inline auto inverse_polynomial_is_computed_at_row(const AllEntities& in)
{
return (in.binary_sel_bin == 1 || in.byte_lookup_sel_bin == 1);
return (in._binary_sel_bin() == 1 || in._byte_lookup_sel_bin() == 1);
}

template <typename Accumulator, typename AllEntities>
static inline auto compute_inverse_exists(const AllEntities& in)
{
using View = typename Accumulator::View;
const auto is_operation = View(in.binary_sel_bin);
const auto is_table_entry = View(in.byte_lookup_sel_bin);
const auto is_operation = View(in._binary_sel_bin());
const auto is_table_entry = View(in._byte_lookup_sel_bin());
return (is_operation + is_table_entry - is_operation * is_table_entry);
}

@@ -129,18 +129,18 @@ class lookup_byte_operations_lookup_settings {

template <typename AllEntities> static inline auto get_entities(AllEntities&& in)
{
return std::forward_as_tuple(in.lookup_byte_operations_inv,
in.lookup_byte_operations_counts,
in.binary_sel_bin,
in.byte_lookup_sel_bin,
in.binary_op_id,
in.binary_ia_bytes,
in.binary_ib_bytes,
in.binary_ic_bytes,
in.byte_lookup_table_op_id,
in.byte_lookup_table_input_a,
in.byte_lookup_table_input_b,
in.byte_lookup_table_output);
return std::forward_as_tuple(in._lookup_byte_operations_inv(),
in._lookup_byte_operations_counts(),
in._binary_sel_bin(),
in._byte_lookup_sel_bin(),
in._binary_op_id(),
in._binary_ia_bytes(),
in._binary_ib_bytes(),
in._binary_ic_bytes(),
in._byte_lookup_table_op_id(),
in._byte_lookup_table_input_a(),
in._byte_lookup_table_input_b(),
in._byte_lookup_table_output());
}
};
