diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/memory/dynamic_array.test.cpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/memory/dynamic_array.test.cpp index 456edaf8f7a..bf83bee2f4a 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/memory/dynamic_array.test.cpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/memory/dynamic_array.test.cpp @@ -7,6 +7,7 @@ #include "../bool/bool.hpp" #include "../circuit_builders/circuit_builders.hpp" #include "barretenberg/circuit_checker/circuit_checker.hpp" +#include "barretenberg/transcript/origin_tag.hpp" using namespace bb; @@ -21,6 +22,60 @@ using field_ct = stdlib::field_t; using witness_ct = stdlib::witness_t; using DynamicArray_ct = stdlib::DynamicArray; +STANDARD_TESTING_TAGS + +/** + * @brief Check that tags in Dynamic array are propagated correctly + * + */ +TEST(DynamicArray, TagCorrectness) +{ + + Builder builder; + const size_t max_size = 4; + + DynamicArray_ct array(&builder, max_size); + + // Create random entries + field_ct entry_1 = witness_ct(&builder, bb::fr::random_element()); + field_ct entry_2 = witness_ct(&builder, bb::fr::random_element()); + field_ct entry_3 = witness_ct(&builder, bb::fr::random_element()); + field_ct entry_4 = witness_ct(&builder, bb::fr::random_element()); + + // Assign a different tag to each entry + entry_1.set_origin_tag(submitted_value_origin_tag); + entry_2.set_origin_tag(challenge_origin_tag); + entry_3.set_origin_tag(next_challenge_tag); + // Entry 4 has an "instant death" tag, that triggers an exception when merged with another tag + entry_4.set_origin_tag(instant_death_tag); + + // Fill out the dynamic array with the first 3 entries + array.push(entry_1); + array.push(entry_2); + array.push(entry_3); + + // Check that the tags are preserved + EXPECT_EQ(array.read(1).get_origin_tag(), challenge_origin_tag); + EXPECT_EQ(array.read(2).get_origin_tag(), next_challenge_tag); + EXPECT_EQ(array.read(0).get_origin_tag(), submitted_value_origin_tag); + // Update an element of the array + array.write(0, entry_2); + // Check that the tag changed + EXPECT_EQ(array.read(0).get_origin_tag(), challenge_origin_tag); + +#ifndef NDEBUG + // Check that "instant death" happens when an "instant death"-tagged element is taken from the array and added to + // another one + array.pop(); + array.pop(); + array.push(entry_4); + array.push(entry_2); + array.push(entry_3); + + EXPECT_THROW(array.read(witness_ct(&builder, 1)) + array.read(witness_ct(&builder, 2)), std::runtime_error); +#endif +} + TEST(DynamicArray, DynamicArrayReadWriteConsistency) { diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/memory/ram_table.cpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/memory/ram_table.cpp index 51a0c926daa..02d062b9d80 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/memory/ram_table.cpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/memory/ram_table.cpp @@ -1,6 +1,9 @@ #include "ram_table.hpp" #include "../circuit_builders/circuit_builders.hpp" +#include "barretenberg/numeric/uint256/uint256.hpp" +#include "barretenberg/transcript/origin_tag.hpp" +#include namespace bb::stdlib { @@ -52,6 +55,12 @@ template ram_table::ram_table(const std::vector void ram_table::initialize_table() const return; } ASSERT(_context != nullptr); - _ram_id = _context->create_RAM_array(_length); if (_raw_entries.size() > 0) { @@ -88,6 +96,13 @@ template void ram_table::initialize_table() const } } + // Store the tags of the original entries + _tags.resize(_length); + if (_raw_entries.size() > 0) { + for (size_t i = 0; i < _length; i++) { + _tags[i] = _raw_entries[i].get_origin_tag(); + } + } _ram_table_generated_in_builder = true; } @@ -100,6 +115,7 @@ template void ram_table::initialize_table() const template ram_table::ram_table(const ram_table& other) : _raw_entries(other._raw_entries) + , _tags(other._tags) , _index_initialized(other._index_initialized) , _length(other._length) , _ram_id(other._ram_id) @@ -117,6 +133,7 @@ ram_table::ram_table(const ram_table& other) template ram_table::ram_table(ram_table&& other) : _raw_entries(other._raw_entries) + , _tags(other._tags) , _index_initialized(other._index_initialized) , _length(other._length) , _ram_id(other._ram_id) @@ -135,6 +152,7 @@ ram_table::ram_table(ram_table&& other) template ram_table& ram_table::operator=(const ram_table& other) { _raw_entries = other._raw_entries; + _tags = other._tags; _length = other._length; _ram_id = other._ram_id; _index_initialized = other._index_initialized; @@ -161,6 +179,7 @@ template ram_table& ram_table::operator=(ra _ram_table_generated_in_builder = other._ram_table_generated_in_builder; _all_entries_written_to_with_constant_index = other._all_entries_written_to_with_constant_index; _context = other._context; + _tags = other._tags; return *this; } @@ -176,8 +195,8 @@ template field_t ram_table::read(const fiel if (_context == nullptr) { _context = index.get_context(); } - - if (uint256_t(index.get_value()) >= _length) { + const auto native_index = uint256_t(index.get_value()); + if (native_index >= _length) { // TODO: what's best practise here? We are assuming that this action will generate failing constraints, // and we set failure message here so that it better describes the point of failure. // However, we are not *ensuring* that failing constraints are generated at the point that `failure()` is @@ -197,8 +216,15 @@ template field_t ram_table::read(const fiel index_wire = field_pt::from_witness_index(_context, _context->put_constant_variable(index.get_value())); } - uint32_t output_idx = _context->read_RAM_array(_ram_id, index_wire.normalize().get_witness_index()); - return field_pt::from_witness_index(_context, output_idx); + uint32_t output_idx = _context->read_RAM_array(_ram_id, index_wire.get_normalized_witness_index()); + auto element = field_pt::from_witness_index(_context, output_idx); + + const size_t cast_index = static_cast(static_cast(native_index)); + // If the index is legitimate, restore the tag + if (native_index < _length) { + element.set_origin_tag(_tags[cast_index]); + } + return element; } /** @@ -224,7 +250,7 @@ template void ram_table::write(const field_pt& index initialize_table(); field_pt index_wire = index; - auto native_index = index.get_value(); + const auto native_index = uint256_t(index.get_value()); if (index.is_constant()) { // need to write every array element at a constant index before doing reads/writes at prover-defined indices index_wire = field_pt::from_witness_index(_context, _context->put_constant_variable(native_index)); @@ -247,7 +273,13 @@ template void ram_table::write(const field_pt& index _index_initialized[cast_index] = true; } else { - _context->write_RAM_array(_ram_id, index_wire.normalize().get_witness_index(), value_wire.get_witness_index()); + _context->write_RAM_array( + _ram_id, index_wire.get_normalized_witness_index(), value_wire.get_normalized_witness_index()); + } + // Update the value of the stored tag, if index is legitimate + + if (native_index < _length) { + _tags[cast_index] = value.get_origin_tag(); } } diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/memory/ram_table.hpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/memory/ram_table.hpp index a4bcd52569e..11340666137 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/memory/ram_table.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/memory/ram_table.hpp @@ -1,6 +1,7 @@ #pragma once #include "../circuit_builders/circuit_builders_fwd.hpp" #include "../field/field.hpp" +#include "barretenberg/transcript/origin_tag.hpp" namespace bb::stdlib { @@ -48,6 +49,8 @@ template class ram_table { private: std::vector _raw_entries; + // Origin Tags for detection of dangerous interactions within stdlib primitives + mutable std::vector _tags; mutable std::vector _index_initialized; size_t _length = 0; mutable size_t _ram_id = 0; // Builder identifier for this ROM table diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/memory/ram_table.test.cpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/memory/ram_table.test.cpp index 848e97cf970..e047dbe8730 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/memory/ram_table.test.cpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/memory/ram_table.test.cpp @@ -3,6 +3,7 @@ #include "barretenberg/circuit_checker/circuit_checker.hpp" #include "barretenberg/numeric/random/engine.hpp" #include "barretenberg/stdlib_circuit_builders/ultra_circuit_builder.hpp" +#include "barretenberg/transcript/origin_tag.hpp" #include "ram_table.hpp" using namespace bb; @@ -15,8 +16,58 @@ using ram_table_ct = stdlib::ram_table; namespace { auto& engine = numeric::get_debug_randomness(); } +STANDARD_TESTING_TAGS + +/** + * @brief Check that Origin Tags within the ram table are propagated correctly (when we lookup an element it has the + * same tag as the one inserted originally) + * + */ +TEST(RamTable, TagCorrectness) +{ + + Builder builder; + std::vector table_values; + + // Generate random witnesses + field_ct entry_1 = witness_ct(&builder, bb::fr::random_element()); + field_ct entry_2 = witness_ct(&builder, bb::fr::random_element()); + field_ct entry_3 = witness_ct(&builder, bb::fr::random_element()); + + // Tag them with 3 different tags + entry_1.set_origin_tag(submitted_value_origin_tag); + entry_2.set_origin_tag(challenge_origin_tag); + // The last tag is an instant death tag, that triggers a runtime failure if any computation happens on the element + entry_3.set_origin_tag(instant_death_tag); + + table_values.emplace_back(entry_1); + table_values.emplace_back(entry_2); + table_values.emplace_back(entry_3); + + // Initialize the table + ram_table_ct table(table_values); + + // Check that each element has the same tag as original entries + EXPECT_EQ(table.read(field_ct(0)).get_origin_tag(), submitted_value_origin_tag); + EXPECT_EQ(table.read(field_ct(witness_ct(&builder, 0))).get_origin_tag(), submitted_value_origin_tag); + EXPECT_EQ(table.read(field_ct(1)).get_origin_tag(), challenge_origin_tag); + EXPECT_EQ(table.read(field_ct(witness_ct(&builder, 1))).get_origin_tag(), challenge_origin_tag); + + // Replace one of the elements in the table with a new one + entry_2.set_origin_tag(next_challenge_tag); + table.write(field_ct(1), entry_2); + + // Check that the tag has been updated accordingly + EXPECT_EQ(table.read(field_ct(1)).get_origin_tag(), next_challenge_tag); + EXPECT_EQ(table.read(field_ct(witness_ct(&builder, 1))).get_origin_tag(), next_challenge_tag); + +#ifndef NDEBUG + // Check that interacting with the poisoned element causes a runtime error + EXPECT_THROW(table.read(0) + table.read(2), std::runtime_error); +#endif +} -TEST(ram_table, ram_table_init_read_consistency) +TEST(RamTable, RamTableInitReadConsistency) { Builder builder; @@ -50,7 +101,7 @@ TEST(ram_table, ram_table_init_read_consistency) EXPECT_EQ(verified, true); } -TEST(ram_table, ram_table_read_write_consistency) +TEST(RamTable, RamTableReadWriteConsistency) { Builder builder; const size_t table_size = 10; diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/memory/rom_table.cpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/memory/rom_table.cpp index a217e8ee51b..77c9d08e2ff 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/memory/rom_table.cpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/memory/rom_table.cpp @@ -22,6 +22,12 @@ template rom_table::rom_table(const std::vector void rom_table::initialize_table() const // populate table. Table entries must be normalized and cannot be constants for (const auto& entry : raw_entries) { if (entry.is_constant()) { - entries.emplace_back( - field_pt::from_witness_index(context, context->put_constant_variable(entry.get_value()))); + auto fixed_witness = + field_pt::from_witness_index(context, context->put_constant_variable(entry.get_value())); + fixed_witness.set_origin_tag(entry.get_origin_tag()); + entries.emplace_back(fixed_witness); + } else { entries.emplace_back(entry.normalize()); } @@ -49,6 +58,11 @@ template void rom_table::initialize_table() const context->set_ROM_element(rom_id, i, entries[i].get_witness_index()); } + // Preserve tags to restore them in the future lookups + _tags.resize(raw_entries.size()); + for (size_t i = 0; i < length; ++i) { + _tags[i] = raw_entries[i].get_origin_tag(); + } initialized = true; } @@ -56,6 +70,7 @@ template rom_table::rom_table(const rom_table& other) : raw_entries(other.raw_entries) , entries(other.entries) + , _tags(other._tags) , length(other.length) , rom_id(other.rom_id) , initialized(other.initialized) @@ -66,6 +81,7 @@ template rom_table::rom_table(rom_table&& other) : raw_entries(other.raw_entries) , entries(other.entries) + , _tags(other._tags) , length(other.length) , rom_id(other.rom_id) , initialized(other.initialized) @@ -76,6 +92,7 @@ template rom_table& rom_table::operator=(co { raw_entries = other.raw_entries; entries = other.entries; + _tags = other._tags; length = other.length; rom_id = other.rom_id; initialized = other.initialized; @@ -87,6 +104,7 @@ template rom_table& rom_table::operator=(ro { raw_entries = other.raw_entries; entries = other.entries; + _tags = other._tags; length = other.length; rom_id = other.rom_id; initialized = other.initialized; @@ -112,13 +130,24 @@ template field_t rom_table::operator[](cons if (context == nullptr) { context = index.get_context(); } + initialize_table(); - if (uint256_t(index.get_value()) >= length) { + const auto native_index = uint256_t(index.get_value()); + if (native_index >= length) { context->failure("rom_table: ROM array access out of bounds"); } - uint32_t output_idx = context->read_ROM_array(rom_id, index.normalize().get_witness_index()); - return field_pt::from_witness_index(context, output_idx); + uint32_t output_idx = context->read_ROM_array(rom_id, index.get_normalized_witness_index()); + auto element = field_pt::from_witness_index(context, output_idx); + + const size_t cast_index = static_cast(static_cast(native_index)); + + // If the index is legitimate, restore the tag + if (native_index < length) { + + element.set_origin_tag(_tags[cast_index]); + } + return element; } template class rom_table; diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/memory/rom_table.hpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/memory/rom_table.hpp index 68c0f2d1a13..45aef617ff8 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/memory/rom_table.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/memory/rom_table.hpp @@ -1,6 +1,7 @@ #pragma once #include "../circuit_builders/circuit_builders_fwd.hpp" #include "../field/field.hpp" +#include "barretenberg/transcript/origin_tag.hpp" namespace bb::stdlib { @@ -34,6 +35,8 @@ template class rom_table { private: std::vector raw_entries; mutable std::vector entries; + // Origin Tags for detecting problematic interactions of stdlib primitives + mutable std::vector _tags; size_t length = 0; mutable size_t rom_id = 0; // Builder identifier for this ROM table mutable bool initialized = false; diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/memory/rom_table.test.cpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/memory/rom_table.test.cpp index bb0eecf8756..9b259806464 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/memory/rom_table.test.cpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/memory/rom_table.test.cpp @@ -4,6 +4,7 @@ #include "barretenberg/circuit_checker/circuit_checker.hpp" #include "barretenberg/numeric/random/engine.hpp" #include "barretenberg/stdlib_circuit_builders/ultra_circuit_builder.hpp" +#include "barretenberg/transcript/origin_tag.hpp" #include "rom_table.hpp" using namespace bb; @@ -17,8 +18,46 @@ using rom_table_ct = stdlib::rom_table; namespace { auto& engine = numeric::get_debug_randomness(); } +STANDARD_TESTING_TAGS -TEST(rom_table, rom_table_read_write_consistency) +/** + * @brief Ensure the tags of elements initializing the ROM table are correctly propagated + * + */ +TEST(RomTable, TagCorrectness) +{ + + Builder builder; + std::vector table_values; + // Create random witness elements + field_ct entry_1 = witness_ct(&builder, bb::fr::random_element()); + field_ct entry_2 = witness_ct(&builder, bb::fr::random_element()); + field_ct entry_3 = witness_ct(&builder, bb::fr::random_element()); + + // Tag all 3 with different tags + entry_1.set_origin_tag(submitted_value_origin_tag); + entry_2.set_origin_tag(challenge_origin_tag); + // The last one is "poisoned" (calculating with this element should result in runtime error) + entry_3.set_origin_tag(instant_death_tag); + + table_values.emplace_back(entry_1); + table_values.emplace_back(entry_2); + table_values.emplace_back(entry_3); + + // Initialize the table with them + rom_table_ct table(table_values); + + // Check that the tags of the first two are preserved + EXPECT_EQ(table[field_ct(witness_ct(&builder, 0))].get_origin_tag(), submitted_value_origin_tag); + EXPECT_EQ(table[field_ct(witness_ct(&builder, 1))].get_origin_tag(), challenge_origin_tag); + +#ifndef NDEBUG + // Check that computing the sum with the last once crashes the program + EXPECT_THROW(table[0] + table[2], std::runtime_error); +#endif +} + +TEST(RomTable, RomTableReadWriteConsistency) { Builder builder; diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/memory/twin_rom_table.cpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/memory/twin_rom_table.cpp index 1f1ac0a58ce..fa214d52a1c 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/memory/twin_rom_table.cpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/memory/twin_rom_table.cpp @@ -27,6 +27,12 @@ twin_rom_table::twin_rom_table(const std::vector void twin_rom_table::initialize_table() con context->set_ROM_element_pair( rom_id, i, std::array{ entries[i][0].get_witness_index(), entries[i][1].get_witness_index() }); } + + // Ensure that the origin tags of all entries are preserved so we can assign them on lookups + tags.resize(length); + for (size_t i = 0; i < length; ++i) { + tags[i] = { raw_entries[i][0].get_origin_tag(), raw_entries[i][1].get_origin_tag() }; + } initialized = true; } @@ -68,6 +80,7 @@ template twin_rom_table::twin_rom_table(const twin_rom_table& other) : raw_entries(other.raw_entries) , entries(other.entries) + , tags(other.tags) , length(other.length) , rom_id(other.rom_id) , initialized(other.initialized) @@ -78,6 +91,7 @@ template twin_rom_table::twin_rom_table(twin_rom_table&& other) : raw_entries(other.raw_entries) , entries(other.entries) + , tags(other.tags) , length(other.length) , rom_id(other.rom_id) , initialized(other.initialized) @@ -88,6 +102,7 @@ template twin_rom_table& twin_rom_table::op { raw_entries = other.raw_entries; entries = other.entries; + tags = other.tags; length = other.length; rom_id = other.rom_id; initialized = other.initialized; @@ -99,6 +114,7 @@ template twin_rom_table& twin_rom_table::op { raw_entries = other.raw_entries; entries = other.entries; + tags = other.tags; length = other.length; rom_id = other.rom_id; initialized = other.initialized; @@ -126,16 +142,26 @@ std::array, 2> twin_rom_table::operator[](const field_ if (context == nullptr) { context = index.get_context(); } + initialize_table(); if (uint256_t(index.get_value()) >= length) { context->failure("twin_rom_table: ROM array access out of bounds"); } auto output_indices = context->read_ROM_array_pair(rom_id, index.normalize().get_witness_index()); - return field_pair_pt{ + auto pair = field_pair_pt{ field_pt::from_witness_index(context, output_indices[0]), field_pt::from_witness_index(context, output_indices[1]), }; + + const auto native_index = uint256_t(index.get_value()); + const size_t cast_index = static_cast(static_cast(native_index)); + // In case of a legitimate lookup, restore the tags of the original entries to the output + if (native_index < length) { + pair[0].set_origin_tag(tags[cast_index][0]); + pair[1].set_origin_tag(tags[cast_index][1]); + } + return pair; } template class twin_rom_table; diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/memory/twin_rom_table.hpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/memory/twin_rom_table.hpp index be0d4912d1e..757af5ae696 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/memory/twin_rom_table.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/memory/twin_rom_table.hpp @@ -1,6 +1,7 @@ #pragma once #include "../circuit_builders/circuit_builders_fwd.hpp" #include "../field/field.hpp" +#include "barretenberg/transcript/origin_tag.hpp" namespace bb::stdlib { @@ -36,6 +37,9 @@ template class twin_rom_table { private: std::vector raw_entries; mutable std::vector entries; + + // Origin Tags used for tracking dangerous interactions in stdlib primtives + mutable std::vector> tags; size_t length = 0; mutable size_t rom_id = 0; // Builder identifier for this ROM table mutable bool initialized = false; diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/memory/twin_rom_table.test.cpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/memory/twin_rom_table.test.cpp new file mode 100644 index 00000000000..55aeecf7fb2 --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/memory/twin_rom_table.test.cpp @@ -0,0 +1,131 @@ + +#include +#include + +#include "barretenberg/circuit_checker/circuit_checker.hpp" +#include "barretenberg/numeric/random/engine.hpp" +#include "barretenberg/stdlib_circuit_builders/ultra_circuit_builder.hpp" +#include "barretenberg/transcript/origin_tag.hpp" +#include "twin_rom_table.hpp" + +using namespace bb; + +// Defining ultra-specific types for local testing. +using Builder = UltraCircuitBuilder; +using field_ct = stdlib::field_t; +using witness_ct = stdlib::witness_t; +using twin_rom_table_ct = stdlib::twin_rom_table; +using field_pair_ct = std::array; + +namespace { +auto& engine = numeric::get_debug_randomness(); +} +STANDARD_TESTING_TAGS + +/** + * @brief Check the correctness of tag propagation within the twin rom tables + * + */ +TEST(TwinRomTable, TagCorrectness) +{ + + Builder builder; + std::vector table_values; + + // Create random entries + field_ct entry_1 = witness_ct(&builder, bb::fr::random_element()); + field_ct entry_2 = witness_ct(&builder, bb::fr::random_element()); + field_ct entry_3 = witness_ct(&builder, bb::fr::random_element()); + field_ct entry_4 = witness_ct(&builder, bb::fr::random_element()); + + // Assign different standard tags to them + entry_1.set_origin_tag(submitted_value_origin_tag); + entry_2.set_origin_tag(challenge_origin_tag); + entry_3.set_origin_tag(next_challenge_tag); + + // Assign the instant death tag to one of the + // entries + // It causes an error in Debug if it is being merged with another tag (when arithmetic actions are being performed + // on it) + entry_4.set_origin_tag(instant_death_tag); + + // Form entries in the twin table + table_values.emplace_back(field_pair_ct{ entry_1, entry_2 }); + table_values.emplace_back(field_pair_ct{ entry_3, entry_4 }); + + // Initialize the table + twin_rom_table_ct table(table_values); + + // Check that the tags in positions [0][0], [0][1], [1][0] are preserved + EXPECT_EQ(table[field_ct(witness_ct(&builder, 0))][0].get_origin_tag(), submitted_value_origin_tag); + EXPECT_EQ(table[field_ct(witness_ct(&builder, 0))][1].get_origin_tag(), challenge_origin_tag); + EXPECT_EQ(table[field_ct(1)][0].get_origin_tag(), next_challenge_tag); + +#ifndef NDEBUG + // Check that working with position [1][1] in debug causes "instant death" + EXPECT_THROW(table[1][1] + 1, std::runtime_error); +#endif +} + +/** + * @brief Check the consistency of read-write operations in the TwinRomTable + * + */ +TEST(TwinRomTable, ReadWriteConsistency) +{ + Builder builder; + + std::vector table_values; + const size_t table_size = 10; + // Generate random witness pairs to put in the table + for (size_t i = 0; i < table_size; ++i) { + table_values.emplace_back(field_pair_ct{ witness_ct(&builder, bb::fr::random_element()), + witness_ct(&builder, bb::fr::random_element()) }); + } + + // Initialize the table + twin_rom_table_ct table(table_values); + + field_pair_ct result{ field_ct(0), field_ct(0) }; + std::array expected{ 0, 0 }; + + // Go throught the cycle of accessing all entries + for (size_t i = 0; i < 10; ++i) { + field_ct index(witness_ct(&builder, (uint64_t)i)); + + if (i % 2 == 0) { + const auto before_n = builder.num_gates; + // Get the entry from the table + const auto to_add = table[index]; + const auto after_n = builder.num_gates; + // should cost 1 gates (the ROM read adds 1 extra gate when the proving key is constructed) + // (but not for 1st entry, the 1st ROM read also builts the ROM table, which will cost table_size * 2 gates) + if (i != 0) { + EXPECT_EQ(after_n - before_n, 1ULL); + } + // Accumulate each of the positions in the result + result[0] += to_add[0]; // variable lookup + result[1] += to_add[1]; // variable lookup + } else { + const auto before_n = builder.num_gates; + const auto to_add = table[i]; // constant lookup + const auto after_n = builder.num_gates; + // should cost 0 gates. Constant lookups are free + EXPECT_EQ(after_n - before_n, 0ULL); + result[0] += to_add[0]; + result[1] += to_add[1]; + } + // Accumulate original values + auto expected_values = table_values[i]; + expected[0] += expected_values[0].get_value(); + expected[1] += expected_values[1].get_value(); + } + + // Check that the sum of the original values is the same as the sum of the ones received from the TwinRomTable + // primitive + EXPECT_EQ(result[0].get_value(), expected[0]); + EXPECT_EQ(result[1].get_value(), expected[1]); + + bool verified = CircuitChecker::check(builder); + EXPECT_EQ(verified, true); +}