Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cheaper bootloader #2529

Draft
wants to merge 2 commits into
base: poseidon2_half_output
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions riscv-executor/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2429,7 +2429,9 @@ impl<F: FieldElement> Executor<'_, '_, F> {
None
}
Instruction::poseidon2_gl => {
let input_ptr = self.proc.get_reg_mem(args[0].u()).u();
let stride = self.proc.get_reg_mem(args[2].u()).u();

let input_ptr = self.proc.get_reg_mem(args[0].u()).u() + stride * 32;
assert!(is_multiple_of_4(input_ptr));

let inputs: [F; 8] = (0..8)
Expand All @@ -2438,7 +2440,7 @@ impl<F: FieldElement> Executor<'_, '_, F> {
.try_into()
.unwrap();

let output_half = self.proc.get_reg_mem(args[2].u()).u();
let output_half = self.proc.get_reg_mem(args[3].u()).u();

let result = poseidon2_gl::poseidon2_gl(&inputs);
let result = match output_half {
Expand All @@ -2449,7 +2451,8 @@ impl<F: FieldElement> Executor<'_, '_, F> {
_ => unreachable!(),
};

let output_ptr = self.proc.get_reg_mem(args[1].u()).u();
let output_byte_size = result.len() as u32 * 4;
let output_ptr = self.proc.get_reg_mem(args[1].u()).u() + stride * output_byte_size;
assert!(is_multiple_of_4(output_ptr));
result.iter().enumerate().for_each(|(i, v)| {
self.proc
Expand Down
5 changes: 3 additions & 2 deletions riscv/src/continuations/bootloader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,16 @@ pub const N_LEAVES_LOG: usize = MEMORY_SIZE_LOG - PAGE_SIZE_BYTES_LOG;
pub const MERKLE_TREE_DEPTH: usize = N_LEAVES_LOG + 1;
pub const PAGE_SIZE_BYTES: usize = 1 << PAGE_SIZE_BYTES_LOG;
pub const PAGE_NUMBER_MASK: usize = (1 << N_LEAVES_LOG) - 1;
pub const WORDS_PER_HASH: usize = 8;
pub const WORDS_PER_HASH: usize = 4;
pub const BOOTLOADER_INPUTS_PER_PAGE: usize =
WORDS_PER_PAGE + 1 + WORDS_PER_HASH + (MERKLE_TREE_DEPTH - 1) * WORDS_PER_HASH;
pub const MEMORY_HASH_START_INDEX: usize = 2 * (REGISTER_MEMORY_NAMES.len() + REGISTER_NAMES.len());
pub const NUM_PAGES_INDEX: usize = MEMORY_HASH_START_INDEX + WORDS_PER_HASH * 2;
pub const PAGE_INPUTS_OFFSET: usize = NUM_PAGES_INDEX + 1;

// Ensure we have enough addresses for the scratch space.
const_assert!(PAGE_SIZE_BYTES > 384);
// TODO: review this number
const_assert!(PAGE_SIZE_BYTES > 1024);

/// Computes the size of the bootloader given the number of input pages.
pub fn bootloader_size(accessed_pages: &BTreeSet<u32>) -> usize {
Expand Down
101 changes: 50 additions & 51 deletions riscv/src/large_field/bootloader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,8 @@ bootloader_init:
// - x2: Current page index
// - x3: Current page number
// - x4: The ith bit of the page number (during Merkle proof validation)
// - x18-x25: The current memory hash
// - x5: Constant 1
// - x18-x21: The current memory hash
// - x9: 0: Merkle tree validation phase; 1: Merkle tree update phase
// - Scratch space for hash operations:
// - [0], [4], [8], [12], [16], [20], [24], [28] will usually contain the "current" hash (either in the context
Expand All @@ -172,15 +173,14 @@ bootloader_init:
load_bootloader_input 0, 1, 1, {NUM_PAGES_INDEX};
add_wrap 1, 0, 0, 1;

// Constant 1
add_wrap 0, 0, 1, 5;

// Initialize memory hash
load_bootloader_input 0, 18, 1, {MEMORY_HASH_START_INDEX};
load_bootloader_input 0, 19, 1, {MEMORY_HASH_START_INDEX} + 1;
load_bootloader_input 0, 20, 1, {MEMORY_HASH_START_INDEX} + 2;
load_bootloader_input 0, 21, 1, {MEMORY_HASH_START_INDEX} + 3;
load_bootloader_input 0, 22, 1, {MEMORY_HASH_START_INDEX} + 4;
load_bootloader_input 0, 23, 1, {MEMORY_HASH_START_INDEX} + 5;
load_bootloader_input 0, 24, 1, {MEMORY_HASH_START_INDEX} + 6;
load_bootloader_input 0, 25, 1, {MEMORY_HASH_START_INDEX} + 7;

// Current page index
set_reg 2, 0;
Expand All @@ -201,59 +201,58 @@ fail;

page_number_ok:

// Store & hash {WORDS_PER_PAGE} page words. This is an unrolled loop that for each word:
// - Loads the word into the P{{(i % 4) + 4}} register
// - Stores the word at the address x3 * {PAGE_SIZE_BYTES} + i * {BYTES_PER_WORD}
// - If i % 4 == 3: Hashes mem[0, 96), storing the result in mem[0, 32)
// Build the Merkle Tree of a single page, where each leaf is made up of 8 words (32 bytes), for
// a total of 64 leaves.
//
// We hash one level at a time, saving the intermediate hash values in the scratch space.
// Each level will output this many words to the scratch space:
// - level 0: 256 field element words
// - level 1: 128 field element words
// - level 2: 64 field element words
// - level 3: 32 field element words
// - level 4: 16 field element words
// - level 5: 8 field element words
// - level 6: 4 field element words (final page hash)
//
// At the end of the loop, we'll have a linear hash of the page in [0, 32), using a Merkle-Damgård
// construction. The initial [0, 32) values are 0, and the capacity [64, 96) is 0 throughout the
// bootloader execution.

mstore_bootloader 0, 0, 0, 0;
mstore_bootloader 0, 0, 4, 0;
mstore_bootloader 0, 0, 8, 0;
mstore_bootloader 0, 0, 12, 0;
mstore_bootloader 0, 0, 16, 0;
mstore_bootloader 0, 0, 20, 0;
mstore_bootloader 0, 0, 24, 0;
mstore_bootloader 0, 0, 28, 0;
mstore_bootloader 0, 0, 32, 0;
mstore_bootloader 0, 0, 36, 0;
mstore_bootloader 0, 0, 40, 0;
mstore_bootloader 0, 0, 44, 0;
mstore_bootloader 0, 0, 48, 0;
mstore_bootloader 0, 0, 52, 0;
mstore_bootloader 0, 0, 56, 0;
mstore_bootloader 0, 0, 60, 0;
mstore_bootloader 0, 0, 64, 0;
mstore_bootloader 0, 0, 68, 0;
mstore_bootloader 0, 0, 72, 0;
mstore_bootloader 0, 0, 76, 0;
mstore_bootloader 0, 0, 80, 0;
mstore_bootloader 0, 0, 84, 0;
mstore_bootloader 0, 0, 88, 0;
mstore_bootloader 0, 0, 92, 0;
// Level 0 is special in that it hashes the leaves directly from the page location in
// memory, using unconstrained reads, where the witgen must supply the page data.
//
// The remaining levels are hashed with normal memory operations.

// Hashes for level 0

"#,
));
));

bootloader.push_str(&format!("affine 3, 90, {PAGE_SIZE_BYTES}, 0;\n"));
for i in 0..WORDS_PER_PAGE {
// Multiply the index by 8 to skip 2 words, 4 words,
// one used for the actual 32-bit word and a zero.
let idx = ((i % 4) + 4) * WORDS_PER_HASH;
// Level 0 is special because it reads from the pages themselves.
let hash_count = WORDS_PER_PAGE / 8;
for hash_idx in 0..hash_count {
// Read 8 words from val(x3) * PAGE_SIZE_BYTES + hash_idx * 8 * BYTES_PER_WORD,
let input_offset = hash_idx * 8 * BYTES_PER_WORD;
// Output 4 FE words packed from the start of the scratch space.
let output_addr = hash_idx * WORDS_PER_HASH * BYTES_PER_WORD;
bootloader.push_str(&format!(
"poseidon2_gl_bootloader 3, {PAGE_SIZE_BYTES}, {input_offset}, {output_addr};"
));
}

// The other levels read from and write to the scratch space.
// The number of hashes in this level is half the number of hashes in the previous level.
for (hash_count, level) in itertools::iterate(hash_count / 2, |x| x / 2)
.take_while(|x| *x >= 8)
.zip(1..)
{
bootloader.push_str(&format!(
r#"
load_bootloader_input 2, 91, {BOOTLOADER_INPUTS_PER_PAGE}, {PAGE_INPUTS_OFFSET} + 1 + {i};
mstore 0, 0, {idx}, 91;
// Hashes for level {level}

affine 3, 90, {PAGE_SIZE_BYTES}, {i} * {BYTES_PER_WORD};
mstore_bootloader 90, 0, 0, 91;"#
"#
));

// Hash if buffer is full
if i % 4 == 3 {
bootloader.push_str("poseidon_gl 0, 0;");
for hash_idx in 0..hash_count {
// Hash the 8 words at the input address and store 4 words result at the output address. Register x5 is
// a constant 1, selecting the first half of the full poseidon state as output.
bootloader.push_str(&format!("poseidon2_gl 0, 0, {hash_idx}, 5;\n"));
}
}

Expand All @@ -262,7 +261,7 @@ mstore_bootloader 90, 0, 0, 91;"#
// == Merkle proof validation ==
// We commit to the memory content by hashing it in pages of {WORDS_PER_PAGE} words each.
// These hashes are stored in a binary Merkle tree of depth {MERKLE_TREE_DEPTH}.
// At this point, the current page hash is in mem[0, 32).
// At this point, the current page hash is in mem[0, 16).
//
// Now, we re-compute the Merkle root twice, in two phases:
// - First using the current page hash, as computed by the bootloader. The prover provides
Expand Down
8 changes: 8 additions & 0 deletions riscv/src/large_field/code_gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,14 @@ fn memory(with_bootloader: bool) -> String {
r#"
std::machines::large_field::memory_with_bootloader_write::MemoryWithBootloaderWrite memory(byte2, MIN_DEGREE, MAIN_MAX_DEGREE);

// Runs a Poseidon2GL permutation on uninitialized memory at val(X) * Y + Z,
// and stores the first half of the final state at val(W).
//
// 8 memory words are read and 4 memory words are written.
instr poseidon2_gl_bootloader X, Y, Z, W
link ~> tmp1_col = regs.mload(X, STEP)
link ~> poseidon2_gl.permute(tmp1_col * Y + Z, 0, W, STEP + 1);

// Stores val(W) at address (V = val(X) - val(Z) + Y) % 2**32.
// V can be between 0 and 2**33.
instr mstore_bootloader X, Z, Y, W
Expand Down
25 changes: 21 additions & 4 deletions riscv/src/large_field/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -277,15 +277,32 @@ impl Runtime {
None,
"poseidon2_gl",
vec!["memory", "MIN_DEGREE", "LARGE_SUBMACHINES_MAX_DEGREE"],
[r#"instr poseidon2_gl X, Y, Z
// Poseidon2GL permutation applied to 8 words read from val(X) + Z * 32,
// with a 4- or 8-word result written to, respectively, val(Y) + Z * 16 or val(Y) + Z * 32,
// depending on val(W).
//
// If val(W) is 0, no output is written. If val(W) is 1 or 2, 4 words are written: the
// first or second half of the final state, respectively.
//
// If val(W) is 3, the full 8 words of the final state are written.
//
// Argument Z can be seen as a stride counter. It indexes both the input and output
// in arrays whose elements are one full input or output block.
[r#"instr poseidon2_gl X, Y, Z, W
link ~> tmp1_col = regs.mload(X, STEP)
link ~> tmp2_col = regs.mload(Y, STEP + 1)
link ~> tmp3_col = regs.mload(Z, STEP + 2)
link ~> poseidon2_gl.permute(tmp1_col, STEP, tmp2_col, STEP + 1, tmp3_col)
link ~> tmp3_col = regs.mload(W, STEP + 2)
link ~> poseidon2_gl.permute(tmp1_col + Z * 32, STEP, tmp2_col + Z * (16 + tmp4_col), STEP + 1, tmp3_col)
{
// make sure tmp1_col and tmp2_col are 4-byte aligned memory addresses
tmp1_col = 4 * (X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000),
tmp2_col = 4 * (Y_b5 + Y_b6 * 0x100 + Y_b7 * 0x10000 + Y_b8 * 0x1000000)

// Make tmp4_col = 0 if tmp3_col = 1 or 2, and tmp4_col = 16 if tmp3_col = 3
// this way we know the size of the output to be 16 or 32 bytes.
// if tmp3_col = 0, then tmp4_col doesn't matter. Any other value is forbidden
// by the Poseidon2GL machine.
tmp4_col = 8 * (tmp3_col - 1) * (tmp3_col - 2)
}
"#],
);
Expand Down Expand Up @@ -332,7 +349,7 @@ impl Runtime {
// they can overlap.
self.add_syscall(
Syscall::Poseidon2GL,
std::iter::once(format!("{} 10, 11, 12;", Syscall::Poseidon2GL.name())),
std::iter::once(format!("{} 10, 11, 0, 12;", Syscall::Poseidon2GL.name())),
);

self.add_syscall(
Expand Down
2 changes: 1 addition & 1 deletion std/machines/large_field/memory.asm
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ machine Memory(byte2: Byte2) with
// then the value is zero.
(1 - m_is_write') * m_change * m_value' = 0;

// change has to be 1 in the last row, so that a first read on row zero is constrained to return 0
// m_change has to be 1 in the last row, so that a first read on row zero is constrained to return 0
(1 - m_change) * LAST = 0;

// If the next line is a read and we stay at the same address, then the
Expand Down
22 changes: 9 additions & 13 deletions std/machines/large_field/memory_with_bootloader_write.asm
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,20 @@ use std::machines::range::Byte2;
use std::field::modulus;
use std::check::require_field_bits;

/// This machine is a slightly extended version of std::machines::memory::Memory,
/// where in addition to mstore, there is an mstore_bootloader operation. It behaves
/// just like mstore, except that the first access to each memory cell must come
/// from the mstore_bootloader operation.
/// This is a slightly different version of std::machines::memory::Memory,
/// where the initial read of an address is unconstrained instead of returning 0, but
/// this can only happen when m_step is 0 (i.e., from the initial memory verification,
/// inside the bootloader).
machine MemoryWithBootloaderWrite(byte2: Byte2) with
latch: LATCH,
operation_id: operation_id,
operation_id: m_is_write,
call_selectors: selectors,
{
require_field_bits(32, || "Memory requires a field that fits any 32-Bit value.");
// lower bound degree is 65536

operation mload<0> m_addr, m_step -> m_value;
operation mstore<1> m_addr, m_step, m_value ->;
operation mstore_bootloader<2> m_addr, m_step, m_value ->;

let LATCH = 1;

Expand All @@ -34,19 +33,16 @@ machine MemoryWithBootloaderWrite(byte2: Byte2) with

// Memory operation flags
let m_is_write;
let m_is_bootloader_write;
std::utils::force_bool(m_is_write);
std::utils::force_bool(m_is_bootloader_write);
col operation_id = m_is_write + 2 * m_is_bootloader_write;

// is_write can only be 1 if a selector is active
let is_mem_op = array::sum(selectors);
std::utils::force_bool(is_mem_op);
(1 - is_mem_op) * m_is_write = 0;
(1 - is_mem_op) * m_is_bootloader_write = 0;

// The first operation of a new address has to be a bootloader write
m_change * (1 - m_is_bootloader_write') = 0;
// The first operation of a new address has to either have
// m_step = 0 (to mean it is inside the bootloader) or be a write.
m_change * m_step * (1 - m_is_write)= 0;

// m_change has to be 1 in the last row, so that the above constraint is triggered.
// An exception to this when the last address is -1, which is only possible if there is
Expand All @@ -57,7 +53,7 @@ machine MemoryWithBootloaderWrite(byte2: Byte2) with

// If the next line is a read and we stay at the same address, then the
// value cannot change.
(1 - m_is_write' - m_is_bootloader_write') * (1 - m_change) * (m_value' - m_value) = 0;
(1 - m_is_write') * (1 - m_change) * (m_value' - m_value) = 0;

let m_diff_lower;
let m_diff_upper;
Expand Down
Loading