Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(avm): separate binary and bytes finalization #8010

Merged
merged 1 commit into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 43 additions & 5 deletions barretenberg/cpp/src/barretenberg/vm/avm/trace/binary_trace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,6 @@

namespace bb::avm_trace {

std::vector<AvmBinaryTraceBuilder::BinaryTraceEntry> AvmBinaryTraceBuilder::finalize()
{
return std::move(binary_trace);
}

void AvmBinaryTraceBuilder::reset()
{
binary_trace.clear();
Expand Down Expand Up @@ -166,4 +161,47 @@ FF AvmBinaryTraceBuilder::op_xor(FF const& a, FF const& b, AvmMemoryTag instr_ta
return uint256_t::from_uint128(c_uint128);
}

void AvmBinaryTraceBuilder::finalize(std::vector<AvmFullRow<FF>>& main_trace)
{
for (size_t i = 0; i < size(); i++) {
auto const& src = binary_trace.at(i);
auto& dest = main_trace.at(i);
dest.binary_clk = src.binary_clk;
dest.binary_sel_bin = static_cast<uint8_t>(src.bin_sel);
dest.binary_acc_ia = src.acc_ia;
dest.binary_acc_ib = src.acc_ib;
dest.binary_acc_ic = src.acc_ic;
dest.binary_in_tag = src.in_tag;
dest.binary_op_id = src.op_id;
dest.binary_ia_bytes = src.bin_ia_bytes;
dest.binary_ib_bytes = src.bin_ib_bytes;
dest.binary_ic_bytes = src.bin_ic_bytes;
dest.binary_start = FF(static_cast<uint8_t>(src.start));
dest.binary_mem_tag_ctr = src.mem_tag_ctr;
dest.binary_mem_tag_ctr_inv = src.mem_tag_ctr_inv;
}

reset();
}

void AvmBinaryTraceBuilder::finalize_lookups(std::vector<AvmFullRow<FF>>& main_trace)
{
for (auto const& [clk, count] : byte_operation_counter) {
main_trace.at(clk).lookup_byte_operations_counts = count;
}

for (uint8_t avm_in_tag = 0; avm_in_tag < 5; avm_in_tag++) {
// The +1 here is because the instruction tags we care about (i.e excl U0 and FF) has the range [1,5]
main_trace.at(avm_in_tag).lookup_byte_lengths_counts = byte_length_counter[avm_in_tag + 1];
}
}

void AvmBinaryTraceBuilder::finalize_lookups_for_testing(std::vector<AvmFullRow<FF>>& main_trace)
{
for (uint8_t avm_in_tag = 0; avm_in_tag < 5; avm_in_tag++) {
// The +1 here is because the instruction tags we care about (i.e excl U0 and FF) has the range [1,5]
main_trace.at(avm_in_tag).lookup_byte_lengths_counts = byte_length_counter[avm_in_tag + 1];
}
}

} // namespace bb::avm_trace
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include "barretenberg/numeric/uint128/uint128.hpp"
#include "barretenberg/vm/avm/generated/full_row.hpp"
#include "barretenberg/vm/avm/trace/common.hpp"

#include <unordered_map>
Expand Down Expand Up @@ -32,9 +33,15 @@ class AvmBinaryTraceBuilder {
std::unordered_map<uint32_t, uint32_t> byte_length_counter;

AvmBinaryTraceBuilder() = default;

size_t size() const { return binary_trace.size(); }
void reset();
// Finalize the trace
std::vector<BinaryTraceEntry> finalize();

// These two have to be separate because the lookups need to be finalized
// after the extra first row is inserted in the main trace.
void finalize(std::vector<AvmFullRow<FF>>& main_trace);
void finalize_lookups(std::vector<AvmFullRow<FF>>& main_trace);
void finalize_lookups_for_testing(std::vector<AvmFullRow<FF>>& main_trace);

FF op_and(FF const& a, FF const& b, AvmMemoryTag instr_tag, uint32_t clk);
FF op_or(FF const& a, FF const& b, AvmMemoryTag instr_tag, uint32_t clk);
Expand Down
92 changes: 92 additions & 0 deletions barretenberg/cpp/src/barretenberg/vm/avm/trace/fixed_bytes.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
#include "barretenberg/vm/avm/trace/fixed_bytes.hpp"

namespace bb::avm_trace {

// Singleton.
const FixedBytesTable& FixedBytesTable::get()
{
static FixedBytesTable table;
return table;
}

void FixedBytesTable::finalize(std::vector<AvmFullRow<FF>>& main_trace) const
{
if (main_trace.size() < 3 * (1 << 16)) {
main_trace.resize(3 * (1 << 16));
}
// Generate Lookup Table of all combinations of 2, 8-bit numbers and op_id.
for (uint32_t op_id = 0; op_id < 3; op_id++) {
for (uint32_t input_a = 0; input_a <= UINT8_MAX; input_a++) {
for (uint32_t input_b = 0; input_b <= UINT8_MAX; input_b++) {
auto a = static_cast<uint8_t>(input_a);
auto b = static_cast<uint8_t>(input_b);

// Derive a unique row index given op_id, a, and b.
auto main_trace_index = (op_id << 16) + (input_a << 8) + b;

main_trace.at(main_trace_index).byte_lookup_sel_bin = FF(1);
main_trace.at(main_trace_index).byte_lookup_table_op_id = op_id;
main_trace.at(main_trace_index).byte_lookup_table_input_a = a;
main_trace.at(main_trace_index).byte_lookup_table_input_b = b;
}
}
}

finalize_byte_length(main_trace);
}

void FixedBytesTable::finalize_for_testing(std::vector<AvmFullRow<FF>>& main_trace,
const std::unordered_map<uint32_t, uint32_t>& byte_operation_counter) const
{
// Generate ByteLength Lookup table of instruction tags to the number of bytes
// {U8: 1, U16: 2, U32: 4, U64: 8, U128: 16}
for (auto const& [clk, count] : byte_operation_counter) {
// from the clk we can derive the a and b inputs
auto b = static_cast<uint8_t>(clk);
auto a = static_cast<uint8_t>(clk >> 8);
auto op_id = static_cast<uint8_t>(clk >> 16);
uint8_t bit_op = 0;
if (op_id == 0) {
bit_op = a & b;
} else if (op_id == 1) {
bit_op = a | b;
} else {
bit_op = a ^ b;
}
if (clk > (main_trace.size() - 1)) {
main_trace.push_back(AvmFullRow<FF>{
.byte_lookup_sel_bin = FF(1),
.byte_lookup_table_input_a = a,
.byte_lookup_table_input_b = b,
.byte_lookup_table_op_id = op_id,
.byte_lookup_table_output = bit_op,
.main_clk = FF(clk),
.lookup_byte_operations_counts = count,
});
} else {
main_trace.at(clk).lookup_byte_operations_counts = count;
main_trace.at(clk).byte_lookup_sel_bin = FF(1);
main_trace.at(clk).byte_lookup_table_op_id = op_id;
main_trace.at(clk).byte_lookup_table_input_a = a;
main_trace.at(clk).byte_lookup_table_input_b = b;
main_trace.at(clk).byte_lookup_table_output = bit_op;
}
// Add the counter value stored throughout the execution
}

finalize_byte_length(main_trace);
}

void FixedBytesTable::finalize_byte_length(std::vector<AvmFullRow<FF>>& main_trace)
{
// Generate ByteLength Lookup table of instruction tags to the number of bytes
// {U8: 1, U16: 2, U32: 4, U64: 8, U128: 16}
for (uint8_t avm_in_tag = 0; avm_in_tag < 5; avm_in_tag++) {
// The +1 here is because the instruction tags we care about (i.e excl U0 and FF) has the range 1,5]
main_trace.at(avm_in_tag).byte_lookup_sel_bin = FF(1);
main_trace.at(avm_in_tag).byte_lookup_table_in_tags = avm_in_tag + 1;
main_trace.at(avm_in_tag).byte_lookup_table_byte_lengths = static_cast<uint8_t>(1 << avm_in_tag);
}
}

} // namespace bb::avm_trace
25 changes: 25 additions & 0 deletions barretenberg/cpp/src/barretenberg/vm/avm/trace/fixed_bytes.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#pragma once

#include <cstddef>
#include <cstdint>

#include "barretenberg/ecc/curves/bn254/fr.hpp"
#include "barretenberg/vm/avm/trace/common.hpp"
#include "barretenberg/vm/avm/trace/opcode.hpp"

namespace bb::avm_trace {

class FixedBytesTable {
public:
static const FixedBytesTable& get();

void finalize(std::vector<AvmFullRow<FF>>& main_trace) const;
void finalize_for_testing(std::vector<AvmFullRow<FF>>& main_trace,
const std::unordered_map<uint32_t, uint32_t>& byte_operation_counter) const;

private:
FixedBytesTable() = default;
static void finalize_byte_length(std::vector<AvmFullRow<FF>>& main_trace);
};

} // namespace bb::avm_trace
139 changes: 23 additions & 116 deletions barretenberg/cpp/src/barretenberg/vm/avm/trace/trace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "barretenberg/numeric/uint256/uint256.hpp"
#include "barretenberg/polynomials/univariate.hpp"
#include "barretenberg/vm/avm/trace/common.hpp"
#include "barretenberg/vm/avm/trace/fixed_bytes.hpp"
#include "barretenberg/vm/avm/trace/fixed_gas.hpp"
#include "barretenberg/vm/avm/trace/fixed_powers.hpp"
#include "barretenberg/vm/avm/trace/gadgets/slice_trace.hpp"
Expand All @@ -34,47 +35,6 @@ namespace bb::avm_trace {
* HELPERS IN ANONYMOUS NAMESPACE
**************************************************************************************************/
namespace {
// WARNING: FOR TESTING ONLY
// Generates the minimal lookup table for the binary trace
uint32_t finalize_bin_trace_lookup_for_testing(std::vector<Row>& main_trace, AvmBinaryTraceBuilder& bin_trace_builder)
{
// Generate ByteLength Lookup table of instruction tags to the number of bytes
// {U8: 1, U16: 2, U32: 4, U64: 8, U128: 16}
for (auto const& [clk, count] : bin_trace_builder.byte_operation_counter) {
// from the clk we can derive the a and b inputs
auto b = static_cast<uint8_t>(clk);
auto a = static_cast<uint8_t>(clk >> 8);
auto op_id = static_cast<uint8_t>(clk >> 16);
uint8_t bit_op = 0;
if (op_id == 0) {
bit_op = a & b;
} else if (op_id == 1) {
bit_op = a | b;
} else {
bit_op = a ^ b;
}
if (clk > (main_trace.size() - 1)) {
main_trace.push_back(Row{
.byte_lookup_sel_bin = FF(1),
.byte_lookup_table_input_a = a,
.byte_lookup_table_input_b = b,
.byte_lookup_table_op_id = op_id,
.byte_lookup_table_output = bit_op,
.main_clk = FF(clk),
.lookup_byte_operations_counts = count,
});
} else {
main_trace.at(clk).lookup_byte_operations_counts = count;
main_trace.at(clk).byte_lookup_sel_bin = FF(1);
main_trace.at(clk).byte_lookup_table_op_id = op_id;
main_trace.at(clk).byte_lookup_table_input_a = a;
main_trace.at(clk).byte_lookup_table_input_b = b;
main_trace.at(clk).byte_lookup_table_output = bit_op;
}
// Add the counter value stored throughout the execution
}
return static_cast<uint32_t>(main_trace.size());
}

constexpr size_t L2_HI_GAS_COUNTS_IDX = 0;
constexpr size_t L2_LO_GAS_COUNTS_IDX = 1;
Expand Down Expand Up @@ -3459,7 +3419,6 @@ std::vector<Row> AvmTraceBuilder::finalize(bool range_check_required)
auto poseidon2_trace = poseidon2_trace_builder.finalize();
auto keccak_trace = keccak_trace_builder.finalize();
auto pedersen_trace = pedersen_trace_builder.finalize();
auto bin_trace = bin_trace_builder.finalize();
auto gas_trace = gas_trace_builder.finalize();
auto slice_trace = slice_trace_builder.finalize();
const auto& fixed_gas_table = FixedGasTable::get();
Expand All @@ -3471,7 +3430,7 @@ std::vector<Row> AvmTraceBuilder::finalize(bool range_check_required)
size_t poseidon2_trace_size = poseidon2_trace.size();
size_t keccak_trace_size = keccak_trace.size();
size_t pedersen_trace_size = pedersen_trace.size();
size_t bin_trace_size = bin_trace.size();
size_t bin_trace_size = bin_trace_builder.size();
size_t gas_trace_size = gas_trace.size();
size_t slice_trace_size = slice_trace.size();

Expand All @@ -3480,18 +3439,14 @@ std::vector<Row> AvmTraceBuilder::finalize(bool range_check_required)
std::unordered_map<uint16_t, uint32_t> mem_rng_check_mid_counts;
std::unordered_map<uint8_t, uint32_t> mem_rng_check_hi_counts;

// Main Trace needs to be at least as big as the biggest subtrace.
// If the bin_trace_size has entries, we need the main_trace to be as big as our byte lookup table (3 *
// 2**16 long)
size_t const lookup_table_size = (bin_trace_size > 0 && range_check_required) ? 3 * (1 << 16) : 0;
// Range check size is 1 less than it needs to be since we insert a "first row" at the top of the trace at the
// end, with clk 0 (this doubles as our range check)
size_t const range_check_size = range_check_required ? UINT16_MAX : 0;
std::vector<size_t> trace_sizes = { mem_trace_size, main_trace_size, alu_trace_size,
range_check_size, conv_trace_size, lookup_table_size,
sha256_trace_size, poseidon2_trace_size, pedersen_trace_size,
gas_trace_size + 1, KERNEL_INPUTS_LENGTH, KERNEL_OUTPUTS_LENGTH,
fixed_gas_table.size(), slice_trace_size, calldata.size() };
std::vector<size_t> trace_sizes = { mem_trace_size, main_trace_size, alu_trace_size,
range_check_size, conv_trace_size, sha256_trace_size,
poseidon2_trace_size, pedersen_trace_size, gas_trace_size + 1,
KERNEL_INPUTS_LENGTH, KERNEL_OUTPUTS_LENGTH, fixed_gas_table.size(),
slice_trace_size, calldata.size() };
vinfo("Trace sizes before padding:",
"\n\tmain_trace_size: ",
main_trace_size,
Expand Down Expand Up @@ -3870,70 +3825,7 @@ std::vector<Row> AvmTraceBuilder::finalize(bool range_check_required)
* BINARY TRACE INCLUSION
**********************************************************************************************/

// Add Binary Trace table
for (size_t i = 0; i < bin_trace_size; i++) {
auto const& src = bin_trace.at(i);
auto& dest = main_trace.at(i);
dest.binary_clk = src.binary_clk;
dest.binary_sel_bin = static_cast<uint8_t>(src.bin_sel);
dest.binary_acc_ia = src.acc_ia;
dest.binary_acc_ib = src.acc_ib;
dest.binary_acc_ic = src.acc_ic;
dest.binary_in_tag = src.in_tag;
dest.binary_op_id = src.op_id;
dest.binary_ia_bytes = src.bin_ia_bytes;
dest.binary_ib_bytes = src.bin_ib_bytes;
dest.binary_ic_bytes = src.bin_ic_bytes;
dest.binary_start = FF(static_cast<uint8_t>(src.start));
dest.binary_mem_tag_ctr = src.mem_tag_ctr;
dest.binary_mem_tag_ctr_inv = src.mem_tag_ctr_inv;
}

// Only generate precomputed byte tables if we are actually going to use them in this main trace.
if (bin_trace_size > 0) {
if (!range_check_required) {
finalize_bin_trace_lookup_for_testing(main_trace, bin_trace_builder);
} else {
// Generate Lookup Table of all combinations of 2, 8-bit numbers and op_id.
for (uint32_t op_id = 0; op_id < 3; op_id++) {
for (uint32_t input_a = 0; input_a <= UINT8_MAX; input_a++) {
for (uint32_t input_b = 0; input_b <= UINT8_MAX; input_b++) {
auto a = static_cast<uint8_t>(input_a);
auto b = static_cast<uint8_t>(input_b);

// Derive a unique row index given op_id, a, and b.
auto main_trace_index = (op_id << 16) + (input_a << 8) + b;

main_trace.at(main_trace_index).byte_lookup_sel_bin = FF(1);
main_trace.at(main_trace_index).byte_lookup_table_op_id = op_id;
main_trace.at(main_trace_index).byte_lookup_table_input_a = a;
main_trace.at(main_trace_index).byte_lookup_table_input_b = b;
// Add the counter value stored throughout the execution
main_trace.at(main_trace_index).lookup_byte_operations_counts =
bin_trace_builder.byte_operation_counter[main_trace_index];
if (op_id == 0) {
main_trace.at(main_trace_index).byte_lookup_table_output = a & b;
} else if (op_id == 1) {
main_trace.at(main_trace_index).byte_lookup_table_output = a | b;
} else {
main_trace.at(main_trace_index).byte_lookup_table_output = a ^ b;
}
}
}
}
}
// Generate ByteLength Lookup table of instruction tags to the number of bytes
// {U8: 1, U16: 2, U32: 4, U64: 8, U128: 16}
for (uint8_t avm_in_tag = 0; avm_in_tag < 5; avm_in_tag++) {
// The +1 here is because the instruction tags we care about (i.e excl U0 and FF) has the range
// [1,5]
main_trace.at(avm_in_tag).byte_lookup_sel_bin = FF(1);
main_trace.at(avm_in_tag).byte_lookup_table_in_tags = avm_in_tag + 1;
main_trace.at(avm_in_tag).byte_lookup_table_byte_lengths = static_cast<uint8_t>(pow(2, avm_in_tag));
main_trace.at(avm_in_tag).lookup_byte_lengths_counts =
bin_trace_builder.byte_length_counter[avm_in_tag + 1];
}
}
bin_trace_builder.finalize(main_trace);

/**********************************************************************************************
* GAS TRACE INCLUSION
Expand Down Expand Up @@ -4015,6 +3907,21 @@ std::vector<Row> AvmTraceBuilder::finalize(bool range_check_required)
Row first_row = Row{ .main_sel_first = FF(1), .mem_lastAccess = FF(1) };
main_trace.insert(main_trace.begin(), first_row);

/**********************************************************************************************
* BYTES TRACE INCLUSION
**********************************************************************************************/

// Only generate precomputed byte tables if we are actually going to use them in this main trace.
if (bin_trace_size > 0) {
if (!range_check_required) {
FixedBytesTable::get().finalize_for_testing(main_trace, bin_trace_builder.byte_operation_counter);
bin_trace_builder.finalize_lookups_for_testing(main_trace);
} else {
FixedBytesTable::get().finalize(main_trace);
bin_trace_builder.finalize_lookups(main_trace);
}
}

/**********************************************************************************************
* RANGE CHECKS AND SELECTORS INCLUSION
**********************************************************************************************/
Expand Down
Loading