Skip to content

Commit

Permalink
Sum instruction
Browse files Browse the repository at this point in the history
  • Loading branch information
quackzar committed Aug 23, 2024
1 parent 74d39ad commit 021efc1
Show file tree
Hide file tree
Showing 7 changed files with 357 additions and 58 deletions.
2 changes: 1 addition & 1 deletion src/testing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ impl Cluster<Script<Element32>> {
let mut fueltanks = beaver::BeaverTriple::fake_many(&context, shared_rng, 2000);
let mut engine = Engine::<_, S, _>::new(context, network, private_rng);
engine.add_fuel(&mut fueltanks[context.me.0]);
engine.execute(&script).await
engine.execute(&script).await.unwrap_single()
})
.await
}
Expand Down
97 changes: 86 additions & 11 deletions src/vm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ pub mod parsing;
use std::future::Future;

Check failure on line 3 in src/vm/mod.rs

View workflow job for this annotation

GitHub Actions / cargo test

unused import: `std::future::Future`

use ff::Field;
use itertools::Itertools;
use itertools::{Either, Itertools};
use rand::RngCore;

use crate::{
algebra::math::Vector,
net::{network::Network, Id, SplitChannel},
net::{agency::Broadcast, network::Network, Id, SplitChannel},
protocols::beaver::{beaver_multiply, BeaverTriple},
schemes::interactive::{InteractiveShared, InteractiveSharedMany},
};
Expand All @@ -19,6 +19,29 @@ pub enum Value<F> {
Vector(Vector<F>),
}

impl<F> Value<F> {
pub fn unwrap_single(self) -> F {
match self {
Value::Single(v) => v,
_ => panic!("Was vector and not a single!"),
}
}

pub fn unwrap_vector(self) -> Vector<F> {
match self {
Value::Vector(v) => v,
_ => panic!("Was single and not a vector!"),
}
}

pub fn map<U>(self, func: impl Fn(F) -> U) -> Value<U> {
match self {
Value::Single(a) => Value::Single(func(a)),
Value::Vector(a) => Value::Vector(a.into_iter().map(func).collect()),
}
}
}

impl<F> From<F> for Value<F> {
fn from(value: F) -> Self {
Value::Single(value)
Expand Down Expand Up @@ -69,6 +92,7 @@ pub enum Instruction {
Add,
Mul,
Sub,
Sum(usize),
}

#[derive(Clone, Debug)]
Expand Down Expand Up @@ -127,6 +151,31 @@ impl<S: InteractiveSharedMany> Stack<S> {
_ => panic!("no valid value found"),
}
}

pub fn take_singles(&mut self, n: usize) -> impl Iterator<Item = S> + '_ {
self.stack.drain(0..n).map(|v| match v {
SharedValue::Single(v) => v,
_ => panic!(),
})
}

pub fn take_vectors(&mut self, n: usize) -> impl Iterator<Item = S::VectorShare> + '_ {
self.stack.drain(0..n).map(|v| match v {
SharedValue::Vector(v) => v,
_ => panic!(),
})
}

pub fn take(
&mut self,
n: usize,
) -> Either<impl Iterator<Item = S> + '_, impl Iterator<Item = S::VectorShare> + '_> {
match self.stack.last() {
Some(SharedValue::Single(_)) => Either::Left(self.take_singles(n)),
Some(SharedValue::Vector(_)) => Either::Right(self.take_vectors(n)),
None => panic!(),
}
}
}

impl<C, S, R, F> Engine<C, S, R>
Expand All @@ -151,9 +200,9 @@ where

// TODO: Superscalar execution when awaiting.

pub async fn execute(&mut self, script: &Script<F>) -> F {
pub async fn execute(&mut self, script: &Script<F>) -> Value<F> {
let mut stack = Stack::new();
let mut results = vec![];
let mut results: Vec<Value<_>> = vec![];
let constants = &script.constants;

for opcode in script.instructions.iter() {
Expand All @@ -168,7 +217,7 @@ where
async fn step(
&mut self,
stack: &mut Stack<S>,
results: &mut Vec<F>,
results: &mut Vec<Value<F>>,
constants: &[Value<F>],
opcode: &Instruction,
) -> Result<(), S::Error> {
Expand Down Expand Up @@ -214,9 +263,16 @@ where
stack.push_vector(share)
}
Instruction::Recombine => {
let share = stack.pop_single();
let f = S::recombine(ctx, share, &mut coms).await?;
results.push(f);
match stack.pop() {
SharedValue::Single(share) => {
let f = S::recombine(ctx, share, &mut coms).await?;
results.push(Value::Single(f));
}
SharedValue::Vector(share) => {
let f = S::recombine_many(ctx, share, &mut coms).await?;
results.push(Value::Vector(f));
}
};
}
Instruction::Add => {
let a = stack.pop();
Expand Down Expand Up @@ -268,13 +324,32 @@ where
(SharedValue::Single(_), SharedValue::Vector(_)) => todo!(),
};
}
Instruction::Sum(size) => {
// Zero is a sentinal value that represents the party size.
let size = if *size == 0 {
self.network.size()
} else {
*size
};
let res = match stack.take(size) {
Either::Left(iter) => {
let res = iter.reduce(|s, acc| acc + s).unwrap();
SharedValue::Single(res)
}
Either::Right(iter) => {
let res = iter.reduce(|s, acc| acc + &s).unwrap();
SharedValue::Vector(res)
}
};
stack.push(res)
}
}
Ok(())
}

pub async fn raw<Func, Out>(&mut self, routine: Func) -> Out
where
Func: async Fn(&mut Network<C>, &S::Context, &mut R) -> Out,
Func: async Fn(&mut Network<C>, &mut S::Context, &mut R) -> Out,
{
// TODO: Add other resources.
routine(&mut self.network, &mut self.context, &mut self.rng).await
Expand Down Expand Up @@ -327,7 +402,7 @@ mod tests {
Recombine,
],
};
let res: u32 = engine.execute(&script).await.into();
let res: u32 = engine.execute(&script).await.unwrap_single().into();
engine.network.shutdown().await.unwrap();
res
})
Expand Down Expand Up @@ -364,7 +439,7 @@ mod tests {
.with_args([a, b])
.run_with_args(|net, script| async move {
let mut engine = dumb_engine(net);
let res: u32 = engine.execute(&script).await.into();
let res: u32 = engine.execute(&script).await.unwrap_single().into();
engine.network.shutdown().await.unwrap();
res
})
Expand Down
135 changes: 135 additions & 0 deletions src/vm/parsing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use itertools::Itertools;
/// Pseudo-parsing direct to an array-backed AST which just is a bytecode stack.
use std::{
array,
iter::Sum,
ops::{Add, Mul, Sub},
};

Expand All @@ -18,6 +19,13 @@ pub struct Exp<F> {
instructions: Vec<Instruction>,
}

// A dynamicly sized list of expressions.
#[derive(Debug)]
pub struct ExpList<F> {
constant: Value<F>,
}

// An opened expression (last step)
#[derive(Debug)]
pub struct Opened<F>(Exp<F>);

Expand Down Expand Up @@ -49,6 +57,20 @@ impl<F> Exp<F> {
}
}

// This is slighty cursed.
pub fn symmetric_share(secret: impl Into<F>) -> ExpList<F> {
ExpList {
constant: Value::Single(secret.into()),
}
}

// This is slighty cursed.
pub fn symmetric_share_vec(secret: impl Into<Vector<F>>) -> ExpList<F> {
ExpList {
constant: Value::Vector(secret.into()),
}
}

/// Secret share into a field value
///
/// * `secret`: value to secret share
Expand Down Expand Up @@ -109,6 +131,7 @@ impl<F> Exp<F> {
| Instruction::RecvVec(_)
| Instruction::Recombine
| Instruction::Add
| Instruction::Sum(_)
| Instruction::Mul
| Instruction::Sub => (),
}
Expand Down Expand Up @@ -152,6 +175,35 @@ impl<T> Opened<T> {
}
}

impl<F> ExpList<F> {
/// Promise that the explist is `size` long
///
/// This will then assume that there a `size` on the stack when executing.
pub fn concrete(self, own: usize, size: usize) -> Vec<Exp<F>> {
let mut me = Some(Exp {
constants: vec![self.constant],
instructions: vec![Instruction::SymShare(Const(0))],
});
(0..size)
.map(|id| {
if id == own {
me.take().unwrap()
} else {
Exp::empty()
}
})
.collect()
}

pub fn sum(self) -> Exp<F> {
use Instruction as I;
Exp {
constants: vec![self.constant],
instructions: vec![I::SymShare(Const(0)), I::Sum(0)],
}
}
}

impl<F> Add for Exp<F> {
type Output = Self;

Expand Down Expand Up @@ -192,6 +244,17 @@ impl<F> Sub for Exp<F> {
}
}

impl<F> Sum for Exp<F> {
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
let (mut exp, size) = iter.fold((Exp::empty(), 0usize), |(mut acc, count), exp| {
acc.append(exp);
(acc, count + 1)
});
exp.instructions.push(Instruction::Sum(size));
exp
}
}

#[cfg(test)]
mod test {
use crate::{
Expand Down Expand Up @@ -274,4 +337,76 @@ mod test {
// 1 + 2 * 3 = 7
assert_eq!(res, vec![7u32.into(), 7u32.into(), 7u32.into()]);
}

#[tokio::test]
async fn explist() {
let inputs = [1, 2, 3u32];
let res = Cluster::new(3)
.with_args(
(0..3)
.map(|id| {
let me = Id(id);
type E = Exp<Element32>;
let exp = E::symmetric_share(inputs[id]);
let [a, b, c]: [E; 3] = exp.concrete(id, 3).try_into().unwrap();
let sum = a + b * c; // no need to implement precedence!
let res = sum.open();
res.finalize()
})
.collect::<Vec<_>>(),
)
.execute_mock()
.await
.unwrap();

// 1 + 2 * 3 = 7
assert_eq!(res, vec![7u32.into(), 7u32.into(), 7u32.into()]);
}

#[tokio::test]
async fn sum() {
let inputs = [1, 2, 3u32];
let res = Cluster::new(3)
.with_args(
(0..3)
.map(|id| {
let me = Id(id);
type E = Exp<Element32>;
let [a, b, c] = E::share_and_receive(inputs[id], me);
let sum: E = [a, b, c].into_iter().sum();
let res = sum.open();
res.finalize()
})
.collect::<Vec<_>>(),
)
.execute_mock()
.await
.unwrap();

// 1 + 2 + 3 = 6
assert_eq!(res, vec![6u32.into(), 6u32.into(), 6u32.into()]);
}

#[tokio::test]
async fn sum_explist() {
let inputs = [1, 2, 3u32];
let res = Cluster::new(3)
.with_args(
(0..3)
.map(|id| {
type E = Exp<Element32>;
let exp = E::symmetric_share(inputs[id]);
let sum = exp.sum();
let res = sum.open();
res.finalize()
})
.collect::<Vec<_>>(),
)
.execute_mock()
.await
.unwrap();

// 1 + 2 + 3 = 6
assert_eq!(res, vec![6u32.into(), 6u32.into(), 6u32.into()]);
}
}
1 change: 1 addition & 0 deletions wecare/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions wecare/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ rand = "0.8.5"
fixed = "2.0.0-alpha.11"
enum_dispatch = "0.3.13"
ff = "0.13.0"
castaway = "0.2.3"

[dev-dependencies]
criterion = "0.5.1"
Expand Down
Loading

0 comments on commit 021efc1

Please sign in to comment.