Skip to content

Commit

Permalink
feat(avm): get_row optimization - 25x faster logderiv inv (#11605)
Browse files Browse the repository at this point in the history
Proving times (VM1) on 16 cores, 850+ columns, dozens of lookups, bulk_test.

```
** Before **
prove/all_ms: 92606
prove/execute_log_derivative_inverse_round_ms: 21544

** After **
prove/all_ms: 73404
prove/execute_log_derivative_inverse_round_ms: 839
```

No change in sumcheck time.

For reviewing, you can focus on the templates. An explanation follows (with history).

---

This PR is about the `get_row()` method on the prover polynomials of a given flavor. This method is used by the [logderivative library](https://github.com/AztecProtocol/aztec-packages/blob/master/barretenberg/cpp/src/barretenberg/honk/proof_system/logderivative_library.hpp#L36) to compute logderivative inverses.

Originally, `get_row()` was supposed to be debug only but it ended up used in the library above. To be fair, the reason is as follows: the `accumulate` function of relations (including lookups and perms), takes in a row (or something that looks like it!). However, by the time that you have to compute inverses, you don't have your row-based trace anymore, you only have the prover polynomials which are column-based. So, you need to extract a row from columns.

The following sections explore a way to make things run faster, without completely breaking the `get_row()` expectations from the caller. That is, that it behaves like a row (you can do `.column` and it will return the field for it).

# Phase 1: `AllEntities<FF>`

So far so good. Normal [BB flavors](https://github.com/AztecProtocol/aztec-packages/blob/master/barretenberg/cpp/src/barretenberg/stdlib_circuit_builders/mega_flavor.hpp#L366) make `get_row()` return `AllEntities<FF>` which is literally a row with as many fields copied as columns you have. Note that the copy is done even for the columns that may not get used later in the accumulation of a relation, or in the computation of inverses.

This might be ok if you have 10 columns and a handful of lookups, but in our case we have dozens of lookups and 850+ columns (we estimate 3500 by completion of the AVM).

# Phase 2: something like `AllEntities<const FF&>`

As a quick fix you might think you can copy references instead and use `AllEntities<const FF&>`. Well you can't, at least not the way you would use `AllEntities<FF>`. Since the class would have members that are references, you need to define a constructor that initializes them all, maybe from a `RefArray` of sorts. The problem is because the class `AllEntities` is defined as inheriting from other classes, instead of being "flat".

This, for us, added an immense amount of codegen. See `AllConstRefValues` [here](https://github.com/AztecProtocol/aztec-packages/blob/2f05dc02fe7b147c7cd6fc235134279dbf332c08/barretenberg/cpp/src/barretenberg/vm/avm/generated/flavor.cpp).

This improvement was introduced in [this PR](#7419) and it gave a **20x** speed improvement over `AllEntities<FF>`.

The code itself was then improved in [this PR](#11504) by using a flat class and some fold expressions.

# Phase 3: Getters

Ideally what we'd want is for `get_row()` to return something like this:
```
    template <typename Polynomials> class PolynomialEntitiesAtFixedRow {
      public:
        PolynomialEntitiesAtFixedRow(const size_t row_idx, const Polynomials& pp)
            : row_idx(row_idx)
            , pp(pp)
        {}
        // what here?

      private:
        const size_t row_idx;
        const Polynomials& pp;
    };
```
such that if you do `row.column` it would secretly do `pp.column[row_idx]` instead. Unfortunately, you cannot override the `.` operator, and certainly not like this.

Instead, we compromise. I added a macro to generate getters `_column()` for every column, which do exactly that. Then I changed the lookups and permutation codegen to use that (i.e., `in._column()` instead of `in.column`). Note that we _only_ use these getters in lookups and perm, not in the main relations.

However, we are not done. The perms and lookups code that we changed is also called by `accumulate` when doing sumcheck, and `AllEntities` does not provide those getters so it will not compile. Well, we add them, and we are done.

This results in a **25x** time improvement in calculating logderiv inverses, amounting to a total of **500x** better than baseline.

# Conclusion

Some thing in BB are not thought for a VM :) I wonder if theres any such improvement lurking in sumcheck? :)
  • Loading branch information
fcarreiro authored Jan 29, 2025
1 parent b44233f commit a273136
Show file tree
Hide file tree
Showing 34 changed files with 1,461 additions and 1,366 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -860,7 +860,7 @@ bool AvmCircuitBuilder::check_circuit() const

std::array<bool, result.size()> subrelation_failed = { false };
for (size_t r = 0; r < num_rows; ++r) {
Relation::accumulate(result, polys.get_row(r), {}, 1);
Relation::accumulate(result, polys.get_standard_row(r), {}, 1);
for (size_t j = 0; j < result.size(); ++j) {
if (!subrelation_failed[j] && result[j] != 0) {
signal_error(format("Relation ",
Expand Down Expand Up @@ -891,7 +891,7 @@ bool AvmCircuitBuilder::check_circuit() const
r = 0;
}
for (size_t r = 0; r < num_rows; ++r) {
Relation::accumulate(lookup_result, polys.get_row(r), params, 1);
Relation::accumulate(lookup_result, polys.get_standard_row(r), params, 1);
}
for (auto r : lookup_result) {
if (r != 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ namespace bb::avm {
#define AVM_TO_BE_SHIFTED(e) e.binary_acc_ia, e.binary_acc_ib, e.binary_acc_ic, e.binary_mem_tag_ctr, e.binary_op_id, e.cmp_a_hi, e.cmp_a_lo, e.cmp_b_hi, e.cmp_b_lo, e.cmp_cmp_rng_ctr, e.cmp_op_gt, e.cmp_p_sub_a_hi, e.cmp_p_sub_a_lo, e.cmp_p_sub_b_hi, e.cmp_p_sub_b_lo, e.cmp_sel_rng_chk, e.main_da_gas_remaining, e.main_l2_gas_remaining, e.main_pc, e.main_sel_execution_end, e.main_sel_execution_row, e.mem_glob_addr, e.mem_rw, e.mem_sel_mem, e.mem_tag, e.mem_tsp, e.mem_val, e.merkle_tree_leaf_index, e.merkle_tree_leaf_value, e.merkle_tree_path_len, e.poseidon2_full_a_0, e.poseidon2_full_a_1, e.poseidon2_full_a_2, e.poseidon2_full_a_3, e.poseidon2_full_execute_poseidon_perm, e.poseidon2_full_input_0, e.poseidon2_full_input_1, e.poseidon2_full_input_2, e.poseidon2_full_num_perm_rounds_rem, e.poseidon2_full_sel_poseidon, e.poseidon2_full_start_poseidon, e.slice_addr, e.slice_clk, e.slice_cnt, e.slice_sel_cd_cpy, e.slice_sel_mem_active, e.slice_sel_return, e.slice_sel_start, e.slice_space_id
#define AVM_ALL_ENTITIES AVM_PRECOMPUTED_ENTITIES, AVM_WIRE_ENTITIES, AVM_DERIVED_WITNESS_ENTITIES, AVM_SHIFTED_ENTITIES
#define AVM_UNSHIFTED_ENTITIES AVM_PRECOMPUTED_ENTITIES, AVM_WIRE_ENTITIES, AVM_DERIVED_WITNESS_ENTITIES
#define AVM_WITNESS_ENTITIES AVM_WIRE_ENTITIES, AVM_DERIVED_WITNESS_ENTITIES

#define AVM_TO_BE_SHIFTED_COLUMNS Column::binary_acc_ia, Column::binary_acc_ib, Column::binary_acc_ic, Column::binary_mem_tag_ctr, Column::binary_op_id, Column::cmp_a_hi, Column::cmp_a_lo, Column::cmp_b_hi, Column::cmp_b_lo, Column::cmp_cmp_rng_ctr, Column::cmp_op_gt, Column::cmp_p_sub_a_hi, Column::cmp_p_sub_a_lo, Column::cmp_p_sub_b_hi, Column::cmp_p_sub_b_lo, Column::cmp_sel_rng_chk, Column::main_da_gas_remaining, Column::main_l2_gas_remaining, Column::main_pc, Column::main_sel_execution_end, Column::main_sel_execution_row, Column::mem_glob_addr, Column::mem_rw, Column::mem_sel_mem, Column::mem_tag, Column::mem_tsp, Column::mem_val, Column::merkle_tree_leaf_index, Column::merkle_tree_leaf_value, Column::merkle_tree_path_len, Column::poseidon2_full_a_0, Column::poseidon2_full_a_1, Column::poseidon2_full_a_2, Column::poseidon2_full_a_3, Column::poseidon2_full_execute_poseidon_perm, Column::poseidon2_full_input_0, Column::poseidon2_full_input_1, Column::poseidon2_full_input_2, Column::poseidon2_full_num_perm_rounds_rem, Column::poseidon2_full_sel_poseidon, Column::poseidon2_full_start_poseidon, Column::slice_addr, Column::slice_clk, Column::slice_cnt, Column::slice_sel_cd_cpy, Column::slice_sel_mem_active, Column::slice_sel_return, Column::slice_sel_start, Column::slice_space_id
#define AVM_SHIFTED_COLUMNS ColumnAndShifts::binary_acc_ia_shift, ColumnAndShifts::binary_acc_ib_shift, ColumnAndShifts::binary_acc_ic_shift, ColumnAndShifts::binary_mem_tag_ctr_shift, ColumnAndShifts::binary_op_id_shift, ColumnAndShifts::cmp_a_hi_shift, ColumnAndShifts::cmp_a_lo_shift, ColumnAndShifts::cmp_b_hi_shift, ColumnAndShifts::cmp_b_lo_shift, ColumnAndShifts::cmp_cmp_rng_ctr_shift, ColumnAndShifts::cmp_op_gt_shift, ColumnAndShifts::cmp_p_sub_a_hi_shift, ColumnAndShifts::cmp_p_sub_a_lo_shift, ColumnAndShifts::cmp_p_sub_b_hi_shift, ColumnAndShifts::cmp_p_sub_b_lo_shift, ColumnAndShifts::cmp_sel_rng_chk_shift, ColumnAndShifts::main_da_gas_remaining_shift, ColumnAndShifts::main_l2_gas_remaining_shift, ColumnAndShifts::main_pc_shift, ColumnAndShifts::main_sel_execution_end_shift, ColumnAndShifts::main_sel_execution_row_shift, ColumnAndShifts::mem_glob_addr_shift, ColumnAndShifts::mem_rw_shift, ColumnAndShifts::mem_sel_mem_shift, ColumnAndShifts::mem_tag_shift, ColumnAndShifts::mem_tsp_shift, ColumnAndShifts::mem_val_shift, ColumnAndShifts::merkle_tree_leaf_index_shift, ColumnAndShifts::merkle_tree_leaf_value_shift, ColumnAndShifts::merkle_tree_path_len_shift, ColumnAndShifts::poseidon2_full_a_0_shift, ColumnAndShifts::poseidon2_full_a_1_shift, ColumnAndShifts::poseidon2_full_a_2_shift, ColumnAndShifts::poseidon2_full_a_3_shift, ColumnAndShifts::poseidon2_full_execute_poseidon_perm_shift, ColumnAndShifts::poseidon2_full_input_0_shift, ColumnAndShifts::poseidon2_full_input_1_shift, ColumnAndShifts::poseidon2_full_input_2_shift, ColumnAndShifts::poseidon2_full_num_perm_rounds_rem_shift, ColumnAndShifts::poseidon2_full_sel_poseidon_shift, ColumnAndShifts::poseidon2_full_start_poseidon_shift, ColumnAndShifts::slice_addr_shift, ColumnAndShifts::slice_clk_shift, ColumnAndShifts::slice_cnt_shift, ColumnAndShifts::slice_sel_cd_cpy_shift, ColumnAndShifts::slice_sel_mem_active_shift, ColumnAndShifts::slice_sel_return_shift, ColumnAndShifts::slice_sel_start_shift, ColumnAndShifts::slice_space_id_shift
Expand Down
47 changes: 37 additions & 10 deletions barretenberg/cpp/src/barretenberg/vm/avm/generated/flavor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "barretenberg/transcript/transcript.hpp"

#include "barretenberg/vm/aztec_constants.hpp"
#include "barretenberg/vm2/common/macros.hpp"
#include "columns.hpp"
#include "flavor_settings.hpp"

Expand Down Expand Up @@ -52,6 +53,19 @@
// Metaprogramming to concatenate tuple types.
template <typename... input_t> using tuple_cat_t = decltype(std::tuple_cat(std::declval<input_t>()...));

// clang-format off
// These getters are used to speedup logderivative inverses.
// See https://github.com/AztecProtocol/aztec-packages/pull/11605/ for a full explanation.
#define DEFAULT_GETTERS(ENTITY) \
inline auto& _##ENTITY() { return ENTITY; } \
inline auto& _##ENTITY() const { return ENTITY; }
#define ROW_PROXY_GETTERS(ENTITY) \
inline auto& _##ENTITY() { return pp.ENTITY[row_idx]; } \
inline auto& _##ENTITY() const { return pp.ENTITY[row_idx]; }
#define DEFINE_GETTERS(GETTER_MACRO, ENTITIES) \
FOR_EACH(GETTER_MACRO, ENTITIES)
// clang-format on

namespace bb::avm {

class AvmFlavor {
Expand Down Expand Up @@ -210,32 +224,29 @@ class AvmFlavor {
"AVM circuit. In this case, modify AVM_VERIFICATION_LENGTH_IN_FIELDS \n"
"in constants.nr accordingly.");

template <typename DataType_> class PrecomputedEntities : public PrecomputedEntitiesBase {
template <typename DataType> class PrecomputedEntities : public PrecomputedEntitiesBase {
public:
using DataType = DataType_;

DEFINE_FLAVOR_MEMBERS(DataType, AVM_PRECOMPUTED_ENTITIES)

RefVector<DataType> get_selectors() { return get_all(); }
RefVector<DataType> get_sigma_polynomials() { return {}; }
RefVector<DataType> get_id_polynomials() { return {}; }
RefVector<DataType> get_table_polynomials() { return {}; }
DEFINE_GETTERS(DEFAULT_GETTERS, AVM_PRECOMPUTED_ENTITIES)
};

private:
template <typename DataType> class WireEntities {
public:
DEFINE_FLAVOR_MEMBERS(DataType, AVM_WIRE_ENTITIES)
DEFINE_GETTERS(DEFAULT_GETTERS, AVM_WIRE_ENTITIES)
};

template <typename DataType> class DerivedWitnessEntities {
public:
DEFINE_FLAVOR_MEMBERS(DataType, AVM_DERIVED_WITNESS_ENTITIES)
DEFINE_GETTERS(DEFAULT_GETTERS, AVM_DERIVED_WITNESS_ENTITIES)
};

template <typename DataType> class ShiftedEntities {
public:
DEFINE_FLAVOR_MEMBERS(DataType, AVM_SHIFTED_ENTITIES)
DEFINE_GETTERS(DEFAULT_GETTERS, AVM_SHIFTED_ENTITIES)
};

template <typename DataType, typename PrecomputedAndWitnessEntitiesSuperset>
Expand Down Expand Up @@ -341,12 +352,26 @@ class AvmFlavor {
using Base::Base;
};

// Only used by VM1 check_circuit. Remove.
class AllConstRefValues {
public:
using BaseDataType = const FF;
using DataType = BaseDataType&;

DEFINE_FLAVOR_MEMBERS(DataType, AVM_ALL_ENTITIES)
DEFINE_GETTERS(DEFAULT_GETTERS, AVM_ALL_ENTITIES)
};

template <typename Polynomials> class PolynomialEntitiesAtFixedRow {
public:
PolynomialEntitiesAtFixedRow(const size_t row_idx, const Polynomials& pp)
: row_idx(row_idx)
, pp(pp)
{}
DEFINE_GETTERS(ROW_PROXY_GETTERS, AVM_ALL_ENTITIES)

private:
const size_t row_idx;
const Polynomials& pp;
};

/**
Expand All @@ -365,12 +390,14 @@ class AvmFlavor {
ProverPolynomials(ProvingKey& proving_key);

size_t get_polynomial_size() const { return main_kernel_inputs.size(); }
AllConstRefValues get_row(size_t row_idx) const
// This is only used in VM1 check_circuit. Remove.
AllConstRefValues get_standard_row(size_t row_idx) const
{
return [row_idx](auto&... entities) -> AllConstRefValues {
return { entities[row_idx]... };
}(AVM_ALL_ENTITIES);
}
auto get_row(size_t row_idx) const { return PolynomialEntitiesAtFixedRow<ProverPolynomials>(row_idx, *this); }
};

class PartiallyEvaluatedMultivariates : public AllEntities<Polynomial> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,15 @@ class lookup_pow_2_0_lookup_settings {

template <typename AllEntities> static inline auto inverse_polynomial_is_computed_at_row(const AllEntities& in)
{
return (in.alu_sel_shift_which == 1 || in.main_sel_rng_8 == 1);
return (in._alu_sel_shift_which() == 1 || in._main_sel_rng_8() == 1);
}

template <typename Accumulator, typename AllEntities>
static inline auto compute_inverse_exists(const AllEntities& in)
{
using View = typename Accumulator::View;
const auto is_operation = View(in.alu_sel_shift_which);
const auto is_table_entry = View(in.main_sel_rng_8);
const auto is_operation = View(in._alu_sel_shift_which());
const auto is_table_entry = View(in._main_sel_rng_8());
return (is_operation + is_table_entry - is_operation * is_table_entry);
}

Expand All @@ -58,14 +58,14 @@ class lookup_pow_2_0_lookup_settings {

template <typename AllEntities> static inline auto get_entities(AllEntities&& in)
{
return std::forward_as_tuple(in.lookup_pow_2_0_inv,
in.lookup_pow_2_0_counts,
in.alu_sel_shift_which,
in.main_sel_rng_8,
in.alu_ib,
in.alu_b_pow,
in.main_clk,
in.powers_power_of_2);
return std::forward_as_tuple(in._lookup_pow_2_0_inv(),
in._lookup_pow_2_0_counts(),
in._alu_sel_shift_which(),
in._main_sel_rng_8(),
in._alu_ib(),
in._alu_b_pow(),
in._main_clk(),
in._powers_power_of_2());
}
};

Expand Down Expand Up @@ -101,15 +101,15 @@ class lookup_pow_2_1_lookup_settings {

template <typename AllEntities> static inline auto inverse_polynomial_is_computed_at_row(const AllEntities& in)
{
return (in.alu_sel_shift_which == 1 || in.main_sel_rng_8 == 1);
return (in._alu_sel_shift_which() == 1 || in._main_sel_rng_8() == 1);
}

template <typename Accumulator, typename AllEntities>
static inline auto compute_inverse_exists(const AllEntities& in)
{
using View = typename Accumulator::View;
const auto is_operation = View(in.alu_sel_shift_which);
const auto is_table_entry = View(in.main_sel_rng_8);
const auto is_operation = View(in._alu_sel_shift_which());
const auto is_table_entry = View(in._main_sel_rng_8());
return (is_operation + is_table_entry - is_operation * is_table_entry);
}

Expand All @@ -125,14 +125,14 @@ class lookup_pow_2_1_lookup_settings {

template <typename AllEntities> static inline auto get_entities(AllEntities&& in)
{
return std::forward_as_tuple(in.lookup_pow_2_1_inv,
in.lookup_pow_2_1_counts,
in.alu_sel_shift_which,
in.main_sel_rng_8,
in.alu_max_bits_sub_b_bits,
in.alu_max_bits_sub_b_pow,
in.main_clk,
in.powers_power_of_2);
return std::forward_as_tuple(in._lookup_pow_2_1_inv(),
in._lookup_pow_2_1_counts(),
in._alu_sel_shift_which(),
in._main_sel_rng_8(),
in._alu_max_bits_sub_b_bits(),
in._alu_max_bits_sub_b_pow(),
in._main_clk(),
in._powers_power_of_2());
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,15 @@ class lookup_byte_lengths_lookup_settings {

template <typename AllEntities> static inline auto inverse_polynomial_is_computed_at_row(const AllEntities& in)
{
return (in.binary_start == 1 || in.byte_lookup_sel_bin == 1);
return (in._binary_start() == 1 || in._byte_lookup_sel_bin() == 1);
}

template <typename Accumulator, typename AllEntities>
static inline auto compute_inverse_exists(const AllEntities& in)
{
using View = typename Accumulator::View;
const auto is_operation = View(in.binary_start);
const auto is_table_entry = View(in.byte_lookup_sel_bin);
const auto is_operation = View(in._binary_start());
const auto is_table_entry = View(in._byte_lookup_sel_bin());
return (is_operation + is_table_entry - is_operation * is_table_entry);
}

Expand All @@ -59,14 +59,14 @@ class lookup_byte_lengths_lookup_settings {

template <typename AllEntities> static inline auto get_entities(AllEntities&& in)
{
return std::forward_as_tuple(in.lookup_byte_lengths_inv,
in.lookup_byte_lengths_counts,
in.binary_start,
in.byte_lookup_sel_bin,
in.binary_in_tag,
in.binary_mem_tag_ctr,
in.byte_lookup_table_in_tags,
in.byte_lookup_table_byte_lengths);
return std::forward_as_tuple(in._lookup_byte_lengths_inv(),
in._lookup_byte_lengths_counts(),
in._binary_start(),
in._byte_lookup_sel_bin(),
in._binary_in_tag(),
in._binary_mem_tag_ctr(),
in._byte_lookup_table_in_tags(),
in._byte_lookup_table_byte_lengths());
}
};

Expand Down Expand Up @@ -105,15 +105,15 @@ class lookup_byte_operations_lookup_settings {

template <typename AllEntities> static inline auto inverse_polynomial_is_computed_at_row(const AllEntities& in)
{
return (in.binary_sel_bin == 1 || in.byte_lookup_sel_bin == 1);
return (in._binary_sel_bin() == 1 || in._byte_lookup_sel_bin() == 1);
}

template <typename Accumulator, typename AllEntities>
static inline auto compute_inverse_exists(const AllEntities& in)
{
using View = typename Accumulator::View;
const auto is_operation = View(in.binary_sel_bin);
const auto is_table_entry = View(in.byte_lookup_sel_bin);
const auto is_operation = View(in._binary_sel_bin());
const auto is_table_entry = View(in._byte_lookup_sel_bin());
return (is_operation + is_table_entry - is_operation * is_table_entry);
}

Expand All @@ -129,18 +129,18 @@ class lookup_byte_operations_lookup_settings {

template <typename AllEntities> static inline auto get_entities(AllEntities&& in)
{
return std::forward_as_tuple(in.lookup_byte_operations_inv,
in.lookup_byte_operations_counts,
in.binary_sel_bin,
in.byte_lookup_sel_bin,
in.binary_op_id,
in.binary_ia_bytes,
in.binary_ib_bytes,
in.binary_ic_bytes,
in.byte_lookup_table_op_id,
in.byte_lookup_table_input_a,
in.byte_lookup_table_input_b,
in.byte_lookup_table_output);
return std::forward_as_tuple(in._lookup_byte_operations_inv(),
in._lookup_byte_operations_counts(),
in._binary_sel_bin(),
in._byte_lookup_sel_bin(),
in._binary_op_id(),
in._binary_ia_bytes(),
in._binary_ib_bytes(),
in._binary_ic_bytes(),
in._byte_lookup_table_op_id(),
in._byte_lookup_table_input_a(),
in._byte_lookup_table_input_b(),
in._byte_lookup_table_output());
}
};

Expand Down
Loading

1 comment on commit a273136

@AztecBot
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'C++ Benchmark'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 1.05.

Benchmark suite Current: a273136 Previous: 13863eb Ratio
commit(t) 3677696919 ns/iter 3108473453 ns/iter 1.18
Goblin::merge(t) 166983290 ns/iter 141495750 ns/iter 1.18

This comment was automatically generated by workflow using github-action-benchmark.

CC: @ludamad @codygunton

Please sign in to comment.