Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mul goldilocks #1

Merged
merged 5 commits into from
Oct 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions risc0/zkvm/platform/io.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ constexpr size_t kGPIO_GetKey = 0x01F0010;
constexpr size_t kGPIO_SendRecvChannel = 0x01F00014;
constexpr size_t kGPIO_SendRecvSize = 0x01F00018;
constexpr size_t kGPIO_SendRecvAddr = 0x01F0001C;
constexpr size_t kGPIO_Mul = 0x01F00020;

// Standard ZKVM channels; must match zkvm/sdk/rust/platform/src/io.rs.

Expand Down Expand Up @@ -67,6 +68,15 @@ struct ShaDescriptor {
uint32_t digest;
};

struct MulDescriptor {
// Address of first byte of MUL data to process
// 64 bits for first operand and 64 bits for second
uint32_t source;

// 64 bit result
uint32_t result;
};

inline volatile ShaDescriptor* volatile* GPIO_SHA() {
return reinterpret_cast<volatile ShaDescriptor* volatile*>(kGPIO_SHA);
}
Expand Down
5 changes: 3 additions & 2 deletions risc0/zkvm/platform/memory.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,9 @@ MEM_REGION(Input, 0x01E00000, k1MB)
MEM_REGION(GPIO, 0x01F00000, k1MB)
MEM_REGION(Prog, 0x02000000, 10 * k1MB)
MEM_REGION(SHA, 0x02A00000, k1MB)
MEM_REGION(WOM, 0x02B00000, 21 * k1MB)
MEM_REGION(Output, 0x02B00000, 20 * k1MB)
MEM_REGION(MUL, 0x02B00000, k1MB)
MEM_REGION(WOM, 0x02C00000, 20 * k1MB)
MEM_REGION(Output, 0x02C00000, 19 * k1MB)
MEM_REGION(Commit, 0x03F00000, k1MB)
// clang-format on

Expand Down
3 changes: 2 additions & 1 deletion risc0/zkvm/platform/risc0.ld
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ MEMORY {
gpio : ORIGIN = 0x01F00000, LENGTH = 1M
prog (X) : ORIGIN = 0x02000000, LENGTH = 10M
sha : ORIGIN = 0x02A00000, LENGTH = 1M
wom : ORIGIN = 0x02B00000, LENGTH = 21M
mul : ORIGIN = 0x02B00000, LENGTH = 1M
wom : ORIGIN = 0x02C00000, LENGTH = 20M
}

SECTIONS {
Expand Down
34 changes: 34 additions & 0 deletions risc0/zkvm/prove/io_handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,33 @@ static void processSHA(MemoryState& mem, const ShaDescriptor& desc) {
}
}

static void processMul(MemoryState& mem, const MulDescriptor& desc) {
uint32_t a_hi = mem.load(desc.source);
LOG(1, "Input[" << hex(0, 2) << "]: " << hex(desc.source) << " -> " << hex(a_hi));
uint32_t a_lo = mem.load(desc.source + 4);
LOG(1, "Input[" << hex(1, 2) << "]: " << hex(desc.source + 4) << " -> " << hex(a_lo));
uint32_t b_hi = mem.load(desc.source + 8);
LOG(1, "Input[" << hex(2, 2) << "]: " << hex(desc.source + 8) << " -> " << hex(b_hi));
uint32_t b_lo = mem.load(desc.source + 12);
LOG(1, "Input[" << hex(3, 2) << "]: " << hex(desc.source + 12) << " -> " << hex(b_lo));

uint64_t first = a_lo | (uint64_t(a_hi) << 32);
uint64_t second = b_lo | (uint64_t(b_hi) << 32);

__uint128_t result = __uint128_t(first) * __uint128_t(second);

// goldilocks
uint64_t moded_result = result % 0xFFFFFFFF00000001;

uint32_t high = (uint32_t)((moded_result & 0xFFFFFFFF00000000LL) >> 32);
uint32_t low = (uint32_t)(moded_result & 0xFFFFFFFFLL);

LOG(1, "Output[" << hex(0, 2) << "]: " << hex(desc.result) << " <- " << hex(high));
mem.store(desc.result, high);
LOG(1, "Output[" << hex(1, 2) << "]: " << hex(desc.result + 4) << " <- " << hex(low));
mem.store(desc.result + 4, low);
}

void IoHandler::onFault(const std::string& msg) {
throw std::runtime_error(msg);
}
Expand All @@ -63,6 +90,13 @@ void MemoryHandler::onInit(MemoryState& mem) {
void MemoryHandler::onWrite(MemoryState& mem, uint32_t cycle, uint32_t addr, uint32_t value) {
LOG(2, "MemoryHandler::onWrite> " << hex(addr) << ": " << hex(value));
switch (addr) {
case kGPIO_Mul: {
LOG(1, "MemoryHandler::onWrite> GPIO_MUL");
MulDescriptor desc;
mem.loadRegion(value, &desc, sizeof(desc));
processMul(mem, desc);
break;
}
case kGPIO_SHA: {
LOG(1, "MemoryHandler::onWrite> GPIO_SHA");
ShaDescriptor desc;
Expand Down
3 changes: 3 additions & 0 deletions risc0/zkvm/sdk/rust/guest/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ pub mod env;
/// Functions for computing SHA-256 hashes.
pub mod sha;

/// mul
pub mod mul;

/// Functions for handling input and output
pub mod io;

Expand Down
70 changes: 70 additions & 0 deletions risc0/zkvm/sdk/rust/guest/src/mul.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
use core::{cell::UnsafeCell, mem};

use crate::env::log;
use _alloc::format;
use _alloc::{boxed::Box, vec::Vec};
use risc0_zkvm::platform::{
io::{MulDescriptor, GPIO_MUL},
memory,
};

// Current sha descriptor index.
struct CurOutput(UnsafeCell<usize>);

// SAFETY: single threaded environment
unsafe impl Sync for CurOutput {}

static CUR_OUTPUT: CurOutput = CurOutput(UnsafeCell::new(0));

/// Result of multiply goldilocks
pub struct MulGoldilocks([u32; 2]);

impl MulGoldilocks {
/// Get the result as u64
pub fn get_u64(&self) -> u64 {
(self.0[1] as u64) | ((self.0[0] as u64) << 32)
}
}

fn alloc_output() -> *mut MulDescriptor {
// SAFETY: Single threaded and this is the only place we use CUR_DESC.
unsafe {
let cur_desc = CUR_OUTPUT.0.get();
let ptr = (memory::MUL.start() as *mut MulDescriptor).add(*cur_desc);
*cur_desc += 1;
ptr
}
}

/// Multiply goldilocks oracle, verification is done separately
pub fn mul_goldilocks(a: &u64, b: &u64) -> &'static MulGoldilocks {
let a_hi = ((a & 0xFFFFFFFF00000000) >> 32) as u32;
let a_lo = (a & 0xFFFFFFFF) as u32;

let b_hi = ((b & 0xFFFFFFFF00000000) >> 32) as u32;
let b_lo = (b & 0xFFFFFFFF) as u32;

let buf = [a_hi, a_lo, b_hi, b_lo];

unsafe {
let alloced = Box::<mem::MaybeUninit<MulGoldilocks>>::new(
mem::MaybeUninit::<MulGoldilocks>::uninit(),
);
let output = (*Box::into_raw(alloced)).as_mut_ptr();
mul_raw(&buf[..], output);
&*output
}
}

pub(crate) unsafe fn mul_raw(data: &[u32], result: *mut MulGoldilocks) {
let output_ptr = alloc_output();

let ptr = data.as_ptr();
super::memory_barrier(ptr);
output_ptr.write_volatile(MulDescriptor {
source: ptr as usize,
result: result as usize,
});

GPIO_MUL.as_ptr().write_volatile(output_ptr);
}
10 changes: 10 additions & 0 deletions risc0/zkvm/sdk/rust/platform/src/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ pub const GPIO_SENDRECV_CHANNEL: Gpio<u32> = Gpio::new(0x01F0_0014);
pub const GPIO_SENDRECV_SIZE: Gpio<usize> = Gpio::new(0x01F0_0018);
pub const GPIO_SENDRECV_ADDR: Gpio<*const u8> = Gpio::new(0x01F0_001C);

pub const GPIO_MUL: Gpio<*const MulDescriptor> = Gpio::new(0x01F0_0020);

pub mod addr {
pub const GPIO_SHA: u32 = super::GPIO_SHA.addr();
pub const GPIO_COMMIT: u32 = super::GPIO_COMMIT.addr();
Expand All @@ -59,6 +61,8 @@ pub mod addr {
pub const GPIO_SENDRECV_CHANNEL: u32 = super::GPIO_SENDRECV_CHANNEL.addr();
pub const GPIO_SENDRECV_SIZE: u32 = super::GPIO_SENDRECV_SIZE.addr();
pub const GPIO_SENDRECV_ADDR: u32 = super::GPIO_SENDRECV_ADDR.addr();

pub const GPIO_MUL: u32 = super::GPIO_MUL.addr();
}

#[repr(C)]
Expand All @@ -75,6 +79,12 @@ pub struct SHADescriptor {
pub digest: usize,
}

#[repr(C)]
pub struct MulDescriptor {
pub source: usize,
pub result: usize,
}

#[repr(C)]
pub struct GetKeyDescriptor {
pub name: u32,
Expand Down
5 changes: 3 additions & 2 deletions risc0/zkvm/sdk/rust/platform/src/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ pub const INPUT: Region = Region::new(0x01E0_0000, mb(1));
pub const GPIO: Region = Region::new(0x01F0_0000, mb(1));
pub const PROG: Region = Region::new(0x0200_0000, mb(10));
pub const SHA: Region = Region::new(0x02A0_0000, mb(1));
pub const WOM: Region = Region::new(0x02B0_0000, mb(21));
pub const OUTPUT: Region = Region::new(0x02B0_0000, mb(20));
pub const MUL: Region = Region::new(0x02B0_0000, mb(1));
pub const WOM: Region = Region::new(0x02C0_0000, mb(20));
pub const OUTPUT: Region = Region::new(0x02C0_0000, mb(19));
pub const COMMIT: Region = Region::new(0x03F0_0000, mb(1));