Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: use onnx convention when integer dividing #925

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions examples/onnx/integer_div/gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from torch import nn
import torch
import json
import numpy as np


class MyModel(nn.Module):
    """Minimal module whose forward pass floor-divides the input by 3."""

    def __init__(self):
        # No parameters or submodules; just initialize the nn.Module base.
        super().__init__()

    def forward(self, x):
        # Integer (floor) division by a fixed constant divisor.
        return x // 3


# Build the model and a random integer input, then export both the ONNX
# graph and the flattened input data for the downstream circuit tooling.
circuit = MyModel()

x = torch.randint(0, 10, (1, 2, 2, 8))

out = circuit(x)

# Show the input, the floor-divided output, and the true (float) quotient
# for manual comparison of rounding behavior.
print(x)
print(out)
print(x / 3)

# Trace the model and write it out as an ONNX file.
torch.onnx.export(
    circuit,
    x,
    "network.onnx",
    export_params=True,        # store the trained parameter weights inside the model file
    opset_version=17,          # the ONNX version to export the model to
    do_constant_folding=True,  # whether to execute constant folding for optimization
    input_names=['input'],     # the model's input names
    output_names=['output'],   # the model's output names
    dynamic_axes={
        'input': {0: 'batch_size'},   # variable length axes
        'output': {0: 'batch_size'},
    },
)

# Flatten the input tensor to a plain Python list for JSON serialization.
flattened = x.detach().numpy().reshape([-1]).tolist()

payload = dict(
    input_data=[flattened],
)

# Serialize data into file:
with open("input.json", 'w') as f:
    json.dump(payload, f)
1 change: 1 addition & 0 deletions examples/onnx/integer_div/input.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"input_data": [[3, 4, 0, 9, 2, 6, 2, 5, 1, 5, 3, 5, 5, 7, 0, 2, 6, 1, 4, 4, 1, 9, 7, 7, 5, 8, 2, 0, 1, 5, 9, 8]]}
Binary file added examples/onnx/integer_div/network.onnx
Binary file not shown.
12 changes: 8 additions & 4 deletions src/circuit/ops/layouts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,9 @@ pub(crate) fn div<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
// implicitly check if the prover provided output is within range
let claimed_output = identity(config, region, &[claimed_output], true)?;
// check if x is too large only if the decomp would support overflow in the previous op
if (IntegerRep::MAX).abs() < ((region.base() as i128).pow(region.legs() as u32)) - 1 {
if F::from_u128(IntegerRep::MAX as u128)
< F::from_u128(region.base() as u128).pow([region.legs() as u64]) - F::ONE
{
// here we decompose and extract the sign of the input
let sign = sign(config, region, &[claimed_output.clone()])?;

Expand Down Expand Up @@ -254,7 +256,9 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
)?;

// check if x is too large only if the decomp would support overflow in the previous op
if (IntegerRep::MAX).abs() < ((region.base() as i128).pow(region.legs() as u32)) - 1 {
if F::from_u128(IntegerRep::MAX as u128)
< F::from_u128(region.base() as u128).pow([region.legs() as u64]) - F::ONE
{
// here we decompose and extract the sign of the input
let sign = sign(config, region, &[masked_output.clone()])?;
let abs_value = pairwise(
Expand Down Expand Up @@ -2652,9 +2656,9 @@ pub fn mean_of_squares_axes<F: PrimeField + TensorType + PartialOrd + std::hash:
let squared = pow(config, region, values, 2)?;
let sum_squared = sum_axes(config, region, &[squared], axes)?;

let dividand: usize = values[0].len() / sum_squared.len();
let dividend: usize = values[0].len() / sum_squared.len();

let mean_squared = div(config, region, &[sum_squared], F::from(dividand as u64))?;
let mean_squared = div(config, region, &[sum_squared], F::from(dividend as u64))?;
Ok(mean_squared)
}

Expand Down
18 changes: 13 additions & 5 deletions src/graph/utilities.rs
Original file line number Diff line number Diff line change
Expand Up @@ -274,12 +274,10 @@ pub fn new_op_from_onnx(
symbol_values: &SymbolValues,
run_args: &crate::RunArgs,
) -> Result<(SupportedOp, Vec<usize>), GraphError> {
use crate::circuit::InputType;
use std::f64::consts::E;

use tract_onnx::tract_core::ops::array::Trilu;

use crate::circuit::InputType;

let input_scales = inputs
.iter()
.flat_map(|x| x.out_scales())
Expand Down Expand Up @@ -1274,9 +1272,19 @@ pub fn new_op_from_onnx(
// get the non constant index
let denom = c.raw_values[0];

SupportedOp::Hybrid(HybridOp::Div {
let op = SupportedOp::Hybrid(HybridOp::Div {
denom: denom.into(),
})
});

// if the input is scale 0 we re up to the max scale
if input_scales[0] == 0 {
SupportedOp::Rescaled(Rescaled {
inner: Box::new(op),
scale: vec![(0, scale_to_multiplier(scales.get_max()) as u128)],
})
} else {
op
}
} else {
return Err(GraphError::MisformedParams(
"only support non zero divisors of size 1".to_string(),
Expand Down
5 changes: 3 additions & 2 deletions tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ mod native_tests {
"1l_tiny_div",
];

const TESTS: [&str; 98] = [
const TESTS: [&str; 99] = [
"1l_mlp", //0
"1l_slice", //1
"1l_concat", //2
Expand Down Expand Up @@ -309,6 +309,7 @@ mod native_tests {
"log", // 95
"exp", // 96
"general_exp", // 97
"integer_div", // 98
];

const WASM_TESTS: [&str; 46] = [
Expand Down Expand Up @@ -547,7 +548,7 @@ mod native_tests {
}
});

seq!(N in 0..=97 {
seq!(N in 0..=98 {

#(#[test_case(TESTS[N])])*
#[ignore]
Expand Down
Loading