
Commit

refactor: lookup safety during cal should be selectable (#678)
alexander-camuto authored Jan 3, 2024
1 parent 9e68cb8 commit de48ae1
Showing 21 changed files with 142 additions and 7,395 deletions.
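In brief: the hard-coded RANGE_MULTIPLIER used to pad lookup tables during calibration is replaced by a user-selectable lookup_safety_margin argument (default 2), threaded from the CLI through calibrate, calc_min_logrows, and the renamed calc_safe_lookup_range. The commit also adds a largest flag to the TopK op (with matching updates to the 1l_topk ONNX example), introduces a Cast lookup op for rescaling, makes Concat require homogeneous input scales, and bounds-checks constant rescaling during node creation.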
1 change: 0 additions & 1 deletion examples/notebooks/ezkl_demo.ipynb
@@ -589,7 +589,6 @@
     " proof_path,\n",
     " settings_path,\n",
     " vk_path,\n",
-    " \n",
     " )\n",
     "\n",
     "assert res == True\n",
2 changes: 1 addition & 1 deletion examples/notebooks/svm.ipynb
@@ -220,7 +220,7 @@
 },
 {
  "cell_type": "code",
- "execution_count": null,
+ "execution_count": 10,
  "id": "b1c561a8",
  "metadata": {},
  "outputs": [],
10 changes: 6 additions & 4 deletions examples/onnx/1l_topk/gen.py
@@ -8,9 +8,11 @@ def __init__(self):
         super(MyModel, self).__init__()

     def forward(self, x):
-        topk = torch.topk(x, 4)
-        print(topk)
-        return [topk.values]
+        topk_largest = torch.topk(x, 4)
+        topk_smallest = torch.topk(x, 4, largest=False)
+        print(topk_largest)
+        print(topk_smallest)
+        return topk_largest.values + topk_smallest.values


 circuit = MyModel()
@@ -21,7 +23,7 @@ def forward(self, x):

 torch.onnx.export(circuit, x, "network.onnx",
                   export_params=True,        # store the trained parameter weights inside the model file
-                  opset_version=10,          # the ONNX version to export the model to
+                  opset_version=14,          # the ONNX version to export the model to
                   do_constant_folding=True,  # whether to execute constant folding for optimization
                   input_names=['input'],     # the model's input names
                   output_names=['output'],   # the model's output names
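Note on the opset bump: ONNX's TopK only gained its largest attribute in opset 11, so the old opset_version=10 export could not represent the new largest=False branch; the example now targets opset 14.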
2 changes: 1 addition & 1 deletion examples/onnx/1l_topk/input.json
@@ -1 +1 @@
{"input_data": [[2, 1, 5, 3, 3, 7]], "output_data": [[7, 5, 3, 3]]}
{"input_data": [[2, 1, 5, 4, 8, 2]], "output_data": [[9, 7, 6, 6]]}
Binary file modified examples/onnx/1l_topk/network.onnx
13 changes: 8 additions & 5 deletions src/circuit/ops/hybrid.rs
@@ -53,6 +53,7 @@ pub enum HybridOp {
     TopK {
         dim: usize,
         k: usize,
+        largest: bool,
     },
     OneHot {
         dim: usize,
@@ -151,8 +152,8 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
                 let res = tensor::ops::one_hot(&x, *num_classes, *dim)?;
                 (res.clone(), inter_equals)
             }
-            HybridOp::TopK { dim, k } => {
-                let res = tensor::ops::topk_axes(&x, *k, *dim)?;
+            HybridOp::TopK { dim, k, largest } => {
+                let res = tensor::ops::topk_axes(&x, *k, *dim, *largest)?;

                 let mut inter_equals = x
                     .clone()
@@ -302,7 +303,9 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
             HybridOp::LessEqual => "LESSEQUAL".into(),
             HybridOp::Equals => "EQUALS".into(),
             HybridOp::Gather { dim, .. } => format!("GATHER (dim={})", dim),
-            HybridOp::TopK { k, dim } => format!("TOPK (k={}, dim={})", k, dim),
+            HybridOp::TopK { k, dim, largest } => {
+                format!("TOPK (k={}, dim={}, largest={})", k, dim, largest)
+            }
             HybridOp::GatherElements { dim, .. } => format!("GATHERELEMENTS (dim={})", dim),
             HybridOp::ScatterElements { dim, .. } => format!("SCATTERELEMENTS (dim={})", dim),
             HybridOp::OneHot { dim, num_classes } => {
@@ -400,8 +403,8 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
             HybridOp::Less => layouts::less(config, region, values[..].try_into()?)?,
             HybridOp::LessEqual => layouts::less_equal(config, region, values[..].try_into()?)?,
             HybridOp::Equals => layouts::equals(config, region, values[..].try_into()?)?,
-            HybridOp::TopK { dim, k } => {
-                layouts::topk_axes(config, region, values[..].try_into()?, *k, *dim)?
+            HybridOp::TopK { dim, k, largest } => {
+                layouts::topk_axes(config, region, values[..].try_into()?, *k, *dim, *largest)?
             }
             HybridOp::OneHot { dim, num_classes } => {
                 layouts::one_hot_axis(config, region, values[..].try_into()?, *num_classes, *dim)?
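The new largest field mirrors the largest attribute on ONNX's TopK operator, so HybridOp::TopK can now represent both the k-largest and k-smallest variants rather than only the default descending selection.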
12 changes: 9 additions & 3 deletions src/circuit/ops/layouts.rs
@@ -478,7 +478,7 @@ fn _sort_ascending<F: PrimeField + TensorType + PartialOrd>(
             .get_int_evals()?
             .iter()
             .sorted_by(|a, b| a.cmp(b))
-            .map(|x| Ok(Value::known(input.get_felt_evals()?.get(&[*x as usize]))))
+            .map(|x| Ok(Value::known(i128_to_felt(*x))))
             .collect::<Result<Tensor<Value<F>>, Box<dyn Error>>>()?
     } else {
         Tensor::new(
@@ -544,8 +544,13 @@ fn _select_topk<F: PrimeField + TensorType + PartialOrd>(
     region: &mut RegionCtx<F>,
     values: &[ValTensor<F>; 1],
     k: usize,
+    largest: bool,
 ) -> Result<ValTensor<F>, Box<dyn Error>> {
-    let sorted = _sort_descending(config, region, values)?.get_slice(&[0..k])?;
+    let sorted = if largest {
+        _sort_descending(config, region, values)?.get_slice(&[0..k])?
+    } else {
+        _sort_ascending(config, region, values)?.get_slice(&[0..k])?
+    };
     Ok(sorted)
 }

@@ -556,12 +561,13 @@ pub fn topk_axes<F: PrimeField + TensorType + PartialOrd>(
     values: &[ValTensor<F>; 1],
     k: usize,
     dim: usize,
+    largest: bool,
 ) -> Result<ValTensor<F>, Box<dyn Error>> {
     let topk_at_k = move |config: &BaseConfig<F>,
                           region: &mut RegionCtx<F>,
                           values: &[ValTensor<F>; 1]|
      -> Result<ValTensor<F>, Box<dyn Error>> {
-        _select_topk(config, region, values, k)
+        _select_topk(config, region, values, k, largest)
     };

     let output: ValTensor<F> = multi_dim_axes_op(config, region, values, &[dim], topk_at_k)?;
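Design note: the k-smallest path is the mirror image of the existing k-largest path — _select_topk sorts ascending instead of descending and takes the same [0..k] slice, so no new constraint logic is introduced. A minimal out-of-circuit sketch of the strategy (illustrative names, plain i128 values rather than ValTensors):

    // Sort-then-slice top-k, as in _select_topk: ascending for the k smallest,
    // descending for the k largest, then keep the first k entries.
    fn select_topk(mut xs: Vec<i128>, k: usize, largest: bool) -> Vec<i128> {
        xs.sort(); // ascending
        if largest {
            xs.reverse(); // descending when the largest values are wanted
        }
        xs.truncate(k); // the [0..k] slice
        xs
    }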
15 changes: 9 additions & 6 deletions src/circuit/ops/lookup.rs
@@ -18,6 +18,7 @@ use halo2curves::ff::PrimeField;
 pub enum LookupOp {
     Abs,
     Div { denom: utils::F32 },
+    Cast { scale: utils::F32 },
     ReLU,
     Max { scale: utils::F32, a: utils::F32 },
     Min { scale: utils::F32, a: utils::F32 },
@@ -115,6 +116,10 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for LookupOp {
                 &x,
                 f32::from(*denom).into(),
             )),
+            LookupOp::Cast { scale } => Ok(tensor::ops::nonlinearities::const_div(
+                &x,
+                f32::from(*scale).into(),
+            )),
             LookupOp::Recip { scale } => Ok(tensor::ops::nonlinearities::recip(&x, scale.into())),
             LookupOp::ReLU => Ok(tensor::ops::nonlinearities::leakyrelu(&x, 0_f64)),
@@ -170,6 +175,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for LookupOp {
             LookupOp::LessThanEqual { .. } => "LESS_THAN_EQUAL".into(),
             LookupOp::Recip { scale, .. } => format!("RECIP(scale={})", scale),
             LookupOp::Div { denom, .. } => format!("DIV(denom={})", denom),
+            LookupOp::Cast { scale } => format!("CAST(scale={})", scale),
             LookupOp::Ln { scale } => format!("LN(scale={})", scale),
             LookupOp::ReLU => "RELU".to_string(),
             LookupOp::LeakyReLU { slope: a } => format!("L_RELU(slope={})", a),
@@ -210,12 +216,9 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for LookupOp {
     /// Returns the scale of the output of the operation.
     fn out_scale(&self, inputs_scale: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
         let scale = match self {
-            LookupOp::Div { denom } => {
-                let mut scale = inputs_scale[0];
-                if scale == 0 {
-                    scale += multiplier_to_scale(1. / denom.0 as f64);
-                }
-                scale
+            LookupOp::Cast { scale } => {
+                let in_scale = inputs_scale[0];
+                in_scale + multiplier_to_scale(1. / scale.0 as f64)
             }
             LookupOp::Recip { scale } => {
                 let mut out_scale = inputs_scale[0];
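A sketch of the new Cast out_scale rule, assuming ezkl's scales are log2-based (i.e. multiplier_to_scale(m) = log2(m), so a value v is represented as round(v * 2^scale)); the helper name below is illustrative:

    // Cast { scale } divides by `scale`, so the output scale drops by log2(scale).
    fn cast_out_scale(in_scale: i32, cast_scale: f64) -> i32 {
        in_scale + (1.0 / cast_scale).log2().round() as i32
    }
    // cast_out_scale(7, 2.0) == 6: dividing by 2 removes one bit of fixed-point precision.

Unlike the removed Div arm, which only adjusted the scale when the input scale was 0, Cast always propagates the rescaling into out_scale.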
5 changes: 4 additions & 1 deletion src/circuit/ops/poly.rs
@@ -406,7 +406,10 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
     }

     fn requires_homogenous_input_scales(&self) -> Vec<usize> {
-        if matches!(self, PolyOp::Add { .. } | PolyOp::Sub) {
+        if matches!(
+            self,
+            PolyOp::Add { .. } | PolyOp::Sub | PolyOp::Concat { .. }
+        ) {
             vec![0, 1]
         } else if matches!(self, PolyOp::Iff) {
             vec![1, 2]
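With Concat added alongside Add and Sub, inputs are now forced to a common scale before concatenation — presumably rescaled via the new Cast lookup op — instead of concatenating tensors whose fixed-point scales disagree.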
5 changes: 5 additions & 0 deletions src/commands.rs
@@ -71,6 +71,8 @@ pub const DEFAULT_OPTIMIZER_RUNS: &str = "1";
 pub const DEFAULT_FUZZ_RUNS: &str = "10";
 /// Default calibration file
 pub const DEFAULT_CALIBRATION_FILE: &str = "calibration.json";
+/// Default lookup safety margin
+pub const DEFAULT_LOOKUP_SAFETY_MARGIN: &str = "2";

 impl std::fmt::Display for TranscriptType {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
@@ -303,6 +305,9 @@ pub enum Commands {
         #[arg(long = "target", default_value = DEFAULT_CALIBRATION_TARGET)]
         /// Target for calibration. Set to "resources" to optimize for computational resources. Otherwise, set to "accuracy" to optimize for accuracy.
         target: CalibrationTarget,
+        /// The lookup safety margin to use for calibration. If the max lookup input is 2^k, the table is sized to cover 2^k * lookup_safety_margin. Larger = safer but slower.
+        #[arg(long, default_value = DEFAULT_LOOKUP_SAFETY_MARGIN)]
+        lookup_safety_margin: i128,
         /// Optional scales to specifically try for calibration. Example, --scales 0,4
         #[arg(long, value_delimiter = ',', allow_hyphen_values = true)]
         scales: Option<Vec<crate::Scale>>,
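A usage sketch (the subcommand name is not visible in this diff — in ezkl's CLI the calibration command is calibrate-settings — and clap derives the kebab-case flag from the field name):

    ezkl calibrate-settings --model network.onnx --data input.json --lookup-safety-margin 4

Raising the margin from the default of 2 widens the calibrated lookup range, trading larger tables (and slower proving) for more headroom against out-of-range lookups.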
16 changes: 13 additions & 3 deletions src/execute.rs
@@ -165,10 +165,19 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
             settings_path,
             data,
             target,
+            lookup_safety_margin,
             scales,
             max_logrows,
-        } => calibrate(model, data, settings_path, target, scales, max_logrows)
-            .map(|e| serde_json::to_string(&e).unwrap()),
+        } => calibrate(
+            model,
+            data,
+            settings_path,
+            target,
+            lookup_safety_margin,
+            scales,
+            max_logrows,
+        )
+        .map(|e| serde_json::to_string(&e).unwrap()),
         Commands::GenWitness {
             data,
             compiled_circuit,
@@ -624,6 +633,7 @@ pub(crate) fn calibrate(
     data: PathBuf,
     settings_path: PathBuf,
     target: CalibrationTarget,
+    lookup_safety_margin: i128,
     scales: Option<Vec<crate::Scale>>,
     max_logrows: Option<u32>,
 ) -> Result<GraphSettings, Box<dyn Error>> {
@@ -748,7 +758,7 @@
         .map_err(|e| format!("failed to load circuit inputs: {}", e))?;

     circuit
-        .calibrate(&data, max_logrows)
+        .calibrate(&data, max_logrows, lookup_safety_margin)
         .map_err(|e| format!("failed to calibrate: {}", e))?;

     let settings = circuit.settings().clone();
27 changes: 18 additions & 9 deletions src/graph/mod.rs
@@ -942,11 +942,16 @@ impl GraphCircuit {
         (ASSUMED_BLINDING_FACTORS + RESERVED_BLINDING_ROWS_PAD) as f64
     }

-    fn calc_safe_range(res: &GraphWitness) -> (i128, i128) {
-        (
-            RANGE_MULTIPLIER * res.min_lookup_inputs,
-            RANGE_MULTIPLIER * res.max_lookup_inputs,
-        )
+    fn calc_safe_lookup_range(res: &GraphWitness, lookup_safety_margin: i128) -> (i128, i128) {
+        let mut margin = (
+            lookup_safety_margin * res.min_lookup_inputs,
+            lookup_safety_margin * res.max_lookup_inputs,
+        );
+        if lookup_safety_margin == 1 {
+            margin.0 -= 1;
+            margin.1 += 1;
+        }
+        margin
     }

     fn calc_num_cols(safe_range: (i128, i128), max_logrows: u32) -> usize {
@@ -961,6 +966,7 @@ impl GraphCircuit {
         &mut self,
         res: &GraphWitness,
         max_logrows: Option<u32>,
+        lookup_safety_margin: i128,
     ) -> Result<(), Box<dyn std::error::Error>> {
         // load the max logrows
         let max_logrows = max_logrows.unwrap_or(MAX_PUBLIC_SRS);
@@ -969,14 +975,15 @@

         let reserved_blinding_rows = Self::reserved_blinding_rows();
         // check if has overflowed max lookup input
-        if res.max_lookup_inputs > MAX_LOOKUP_ABS / RANGE_MULTIPLIER
-            || res.min_lookup_inputs < -MAX_LOOKUP_ABS / RANGE_MULTIPLIER
+        if res.max_lookup_inputs > MAX_LOOKUP_ABS / lookup_safety_margin
+            || res.min_lookup_inputs < -MAX_LOOKUP_ABS / lookup_safety_margin
         {
             let err_string = format!("max lookup input ({}) is too large", res.max_lookup_inputs);
+            error!("{}", err_string);
             return Err(err_string.into());
         }

-        let safe_range = Self::calc_safe_range(res);
+        let safe_range = Self::calc_safe_lookup_range(res, lookup_safety_margin);
         let mut min_logrows = MIN_LOGROWS;
         // degrade the max logrows until the extended k is small enough
         while min_logrows < max_logrows
@@ -995,6 +1002,7 @@
                 "extended k is too large to accommodate the quotient polynomial with logrows {}",
                 min_logrows
             );
+            error!("{}", err_string);
             return Err(err_string.into());
         }

@@ -1100,9 +1108,10 @@
         &mut self,
         input: &[Tensor<Fp>],
         max_logrows: Option<u32>,
+        lookup_safety_margin: i128,
     ) -> Result<(), Box<dyn std::error::Error>> {
         let res = self.forward(&mut input.to_vec(), None, None)?;
-        self.calc_min_logrows(&res, max_logrows)
+        self.calc_min_logrows(&res, max_logrows, lookup_safety_margin)
     }

     /// Runs the forward pass of the model / graph of computations and any associated hashing.
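A standalone sketch of the new margin rule with illustrative numbers (mirrors calc_safe_lookup_range above):

    fn safe_lookup_range(min_input: i128, max_input: i128, margin: i128) -> (i128, i128) {
        let mut range = (margin * min_input, margin * max_input);
        if margin == 1 {
            // no multiplicative headroom: pad by one so observed values sit strictly inside the table
            range.0 -= 1;
            range.1 += 1;
        }
        range
    }

    // For witnessed lookup inputs spanning [-100, 200]:
    //   margin = 2 (the default) -> (-200, 400)
    //   margin = 1               -> (-101, 201)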
3 changes: 2 additions & 1 deletion src/graph/model.rs
@@ -607,7 +607,7 @@ impl Model {
             debug!("intermediate min lookup inputs: {}", min);
         }
         debug!(
-            "------------ output node int {}: {} \n ------------ float: {} \n ------------ max: {} \n ------------ min: {}",
+            "------------ output node int {}: {} \n ------------ float: {} \n ------------ max: {} \n ------------ min: {} ------------ scale: {}",
             idx,
             res.output.map(crate::fieldutils::felt_to_i32).show(),
             res.output
@@ -616,6 +616,7 @@
                 .show(),
             res.output.clone().into_iter().map(crate::fieldutils::felt_to_i128).max().unwrap_or(0),
             res.output.clone().into_iter().map(crate::fieldutils::felt_to_i128).min().unwrap_or(0),
+            n.out_scale
         );
         results.insert(idx, vec![res.output]);
     }
36 changes: 21 additions & 15 deletions src/graph/node.rs
@@ -528,6 +528,8 @@ impl Node {
         idx: usize,
         symbol_values: &SymbolValues,
     ) -> Result<Self, Box<dyn Error>> {
+        use log::warn;
+
         trace!("Create {:?}", node);
         trace!("Create op {:?}", node.op);
@@ -603,21 +605,25 @@
             .into_iter()
             .filter(|i| !deleted_indices.contains(i))
         {
-            let input_node = other_nodes
-                .get_mut(&inputs[input].idx())
-                .ok_or("input not found")?;
-            let input_opkind = &mut input_node.opkind();
-            if let Some(constant) = input_opkind.get_mutable_constant() {
-                rescale_const_with_single_use(
-                    constant,
-                    in_scales.clone(),
-                    param_visibility,
-                    input_node.num_uses(),
-                )?;
-                input_node.replace_opkind(constant.clone_dyn().into());
-                let out_scale = input_opkind.out_scale(vec![])?;
-                input_node.bump_scale(out_scale);
-                in_scales[input] = out_scale;
+            if inputs.len() > input {
+                let input_node = other_nodes
+                    .get_mut(&inputs[input].idx())
+                    .ok_or("input not found")?;
+                let input_opkind = &mut input_node.opkind();
+                if let Some(constant) = input_opkind.get_mutable_constant() {
+                    rescale_const_with_single_use(
+                        constant,
+                        in_scales.clone(),
+                        param_visibility,
+                        input_node.num_uses(),
+                    )?;
+                    input_node.replace_opkind(constant.clone_dyn().into());
+                    let out_scale = input_opkind.out_scale(vec![])?;
+                    input_node.bump_scale(out_scale);
+                    in_scales[input] = out_scale;
+                }
+            } else {
+                warn!("input {} not found for rescaling, skipping ...", input);
             }
         }
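The new inputs.len() > input guard appears to close a potential out-of-bounds panic: input indexes into inputs after deleted_indices filtering, and a pruned input could otherwise leave a dangling index; the new branch logs a warning and skips the rescale instead of aborting node creation.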
