From 145c7a19f4354a8b8223879c4c794252172fbbc8 Mon Sep 17 00:00:00 2001
From: Alexander Camuto <45801863+alexander-camuto@users.noreply.github.com>
Date: Tue, 4 Feb 2025 13:57:35 +0000
Subject: [PATCH 1/4] fix: use onnx convention when integer dividing

---
 examples/onnx/integer_div/gen.py | 42 ++++++++++++++++++++++++++++++++
 src/circuit/ops/layouts.rs       | 12 ++++++---
 src/tensor/ops.rs                |  2 +-
 tests/integration_tests.rs       |  5 ++--
 4 files changed, 54 insertions(+), 7 deletions(-)
 create mode 100644 examples/onnx/integer_div/gen.py

diff --git a/examples/onnx/integer_div/gen.py b/examples/onnx/integer_div/gen.py
new file mode 100644
index 000000000..aaa505676
--- /dev/null
+++ b/examples/onnx/integer_div/gen.py
@@ -0,0 +1,42 @@
+from torch import nn
+import torch
+import json
+import numpy as np
+
+
+class MyModel(nn.Module):
+    def __init__(self):
+        super(MyModel, self).__init__()
+
+    def forward(self, x):
+        return x // 3
+
+
+circuit = MyModel()
+
+x = torch.randint(0, 10, (1, 2, 2, 8))
+
+out = circuit(x)
+
+print(x)
+print(out)
+print(x/3)
+
+torch.onnx.export(circuit, x, "network.onnx",
+                  export_params=True,        # store the trained parameter weights inside the model file
+                  opset_version=17,          # the ONNX version to export the model to
+                  do_constant_folding=True,  # whether to execute constant folding for optimization
+                  input_names=['input'],     # the model's input names
+                  output_names=['output'],   # the model's output names
+                  dynamic_axes={'input': {0: 'batch_size'},    # variable length axes
+                                'output': {0: 'batch_size'}})
+
+
+d1 = ((x).detach().numpy()).reshape([-1]).tolist()
+
+data = dict(
+    input_data=[d1],
+)
+
+# Serialize data into file:
+json.dump(data, open("input.json", 'w'))
diff --git a/src/circuit/ops/layouts.rs b/src/circuit/ops/layouts.rs
index 90a61b6ed..45db97783 100644
--- a/src/circuit/ops/layouts.rs
+++ b/src/circuit/ops/layouts.rs
@@ -157,7 +157,9 @@ pub(crate) fn div<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
     // implicitly check if the prover provided output is within range
     let claimed_output = identity(config, region, &[claimed_output], true)?;
     // check if x is too large only if the decomp would support overflow in the previous op
-    if (IntegerRep::MAX).abs() < ((region.base() as i128).pow(region.legs() as u32)) - 1 {
+    if F::from_u128(IntegerRep::MAX as u128)
+        < F::from_u128(region.base() as u128).pow([region.legs() as u64]) - F::ONE
+    {
         // here we decompose and extract the sign of the input
         let sign = sign(config, region, &[claimed_output.clone()])?;
 
@@ -254,7 +256,9 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
     )?;
 
     // check if x is too large only if the decomp would support overflow in the previous op
-    if (IntegerRep::MAX).abs() < ((region.base() as i128).pow(region.legs() as u32)) - 1 {
+    if F::from_u128(IntegerRep::MAX as u128)
+        < F::from_u128(region.base() as u128).pow([region.legs() as u64]) - F::ONE
+    {
         // here we decompose and extract the sign of the input
         let sign = sign(config, region, &[masked_output.clone()])?;
         let abs_value = pairwise(
@@ -2652,9 +2656,9 @@ pub fn mean_of_squares_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
[hunk body garbled in transit; omitted]
diff --git a/src/tensor/ops.rs b/src/tensor/ops.rs
index f0ebd1be0..51175dc53 100644
--- a/src/tensor/ops.rs
+++ b/src/tensor/ops.rs
@@ -2250,7 +2250,7 @@ pub mod nonlinearities {
 pub fn const_div(a: &Tensor<IntegerRep>, denom: f64) -> Tensor<IntegerRep> {
     a.par_enum_map(|_, a_i| {
         let d_inv_x = (a_i as f64) / (denom);
-        Ok::<_, TensorError>(d_inv_x.round() as IntegerRep)
+        Ok::<_, TensorError>(d_inv_x.floor() as IntegerRep)
     })
     .unwrap()
 }
diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs
index 67d5b731e..21d8dac42 100644
--- a/tests/integration_tests.rs
+++ b/tests/integration_tests.rs
@@ -206,7 +206,7 @@ mod native_tests {
         "1l_tiny_div",
     ];
 
-    const TESTS: [&str; 98] = [
+    const TESTS: [&str; 99] = [
         "1l_mlp", //0
         "1l_slice", //1
         "1l_concat", //2
@@ -309,6 +309,7 @@ mod native_tests {
         "log", // 95
         "exp", // 96
         "general_exp", // 97
+        "integer_div", // 98
     ];
 
     const WASM_TESTS: [&str; 46] = [
@@ -547,7 +548,7 @@ mod native_tests {
         }
     });
 
-    seq!(N in 0..=97 {
+    seq!(N in 0..=98 {
 
     #(#[test_case(TESTS[N])])*
    #[ignore]
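For context between patches: the behavioural fix above is that integer division in ONNX (and PyTorch's `//`, which gen.py exports) floors the quotient, while const_div previously rounded to nearest. Below is a minimal standalone Rust sketch contrasting the two conventions; the helper names const_div_round and const_div_floor are illustrative only, not ezkl's actual API.

    // Standalone sketch, not ezkl source: contrasts the old round-to-nearest
    // behaviour of const_div with the floor behaviour adopted in PATCH 1.
    fn const_div_round(a: i64, denom: f64) -> i64 {
        ((a as f64) / denom).round() as i64
    }

    fn const_div_floor(a: i64, denom: f64) -> i64 {
        ((a as f64) / denom).floor() as i64
    }

    fn main() {
        // For a = 5: 5 / 3 = 1.66..., so round gives 2 but `5 // 3` is 1;
        // the two conventions visibly disagree on a third of the inputs.
        for a in [0i64, 2, 5, 8, 9] {
            println!(
                "{:>2} / 3 -> round: {}, floor: {}",
                a,
                const_div_round(a, 3.0),
                const_div_floor(a, 3.0)
            );
        }
    }

On the test inputs here (non-negative integers from randint(0, 10)), floor and truncation coincide, so the only observable change is round vs. floor.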
From 71d6678a0b0e9e19e297fd886ac080a9fc911a46 Mon Sep 17 00:00:00 2001
From: Alexander Camuto <45801863+alexander-camuto@users.noreply.github.com>
Date: Tue, 4 Feb 2025 14:00:25 +0000
Subject: [PATCH 2/4] files

---
 examples/onnx/integer_div/input.json   |   1 +
 examples/onnx/integer_div/network.onnx | Bin 0 -> 357 bytes
 2 files changed, 1 insertion(+)
 create mode 100644 examples/onnx/integer_div/input.json
 create mode 100644 examples/onnx/integer_div/network.onnx

diff --git a/examples/onnx/integer_div/input.json b/examples/onnx/integer_div/input.json
new file mode 100644
index 000000000..f537d5a1d
--- /dev/null
+++ b/examples/onnx/integer_div/input.json
@@ -0,0 +1 @@
+{"input_data": [[3, 4, 0, 9, 2, 6, 2, 5, 1, 5, 3, 5, 5, 7, 0, 2, 6, 1, 4, 4, 1, 9, 7, 7, 5, 8, 2, 0, 1, 5, 9, 8]]}
\ No newline at end of file
diff --git a/examples/onnx/integer_div/network.onnx b/examples/onnx/integer_div/network.onnx
new file mode 100644
index 0000000000000000000000000000000000000000..b4ebb77e745230ea01ed58e46ecd1e1ad6539a04
GIT binary patch
literal 357
[357-byte base85 payload garbled in transit; omitted]

From: Alexander Camuto <45801863+alexander-camuto@users.noreply.github.com>
Date: Tue, 4 Feb 2025 14:05:55 +0000
Subject: [PATCH 3/4] Update ops.rs

---
 src/tensor/ops.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/tensor/ops.rs b/src/tensor/ops.rs
index 51175dc53..f0ebd1be0 100644
--- a/src/tensor/ops.rs
+++ b/src/tensor/ops.rs
@@ -2250,7 +2250,7 @@ pub mod nonlinearities {
 pub fn const_div(a: &Tensor<IntegerRep>, denom: f64) -> Tensor<IntegerRep> {
     a.par_enum_map(|_, a_i| {
         let d_inv_x = (a_i as f64) / (denom);
-        Ok::<_, TensorError>(d_inv_x.floor() as IntegerRep)
+        Ok::<_, TensorError>(d_inv_x.round() as IntegerRep)
     })
     .unwrap()
 }

From f44ffdb1cf48eaff362cdeaa6249f57f8656d3ea Mon Sep 17 00:00:00 2001
From: Alexander Camuto <45801863+alexander-camuto@users.noreply.github.com>
Date: Tue, 4 Feb 2025 14:31:05 +0000
Subject: [PATCH 4/4] Update utilities.rs

---
 src/graph/utilities.rs | 18 +++++++++++++-----
 1 file changed, 13 insertions(+), 5 deletions(-)

diff --git a/src/graph/utilities.rs b/src/graph/utilities.rs
index 3aadbc953..e70cbdb8a 100644
--- a/src/graph/utilities.rs
+++ b/src/graph/utilities.rs
@@ -274,12 +274,10 @@ pub fn new_op_from_onnx(
     symbol_values: &SymbolValues,
     run_args: &crate::RunArgs,
 ) -> Result<(SupportedOp, Vec<usize>), GraphError> {
+    use crate::circuit::InputType;
     use std::f64::consts::E;
-
     use tract_onnx::tract_core::ops::array::Trilu;
 
-    use crate::circuit::InputType;
-
     let input_scales = inputs
         .iter()
         .flat_map(|x| x.out_scales())
@@ -1274,9 +1272,19 @@ pub fn new_op_from_onnx(
             // get the non constant index
             let denom = c.raw_values[0];
 
-            SupportedOp::Hybrid(HybridOp::Div {
+            let op = SupportedOp::Hybrid(HybridOp::Div {
                 denom: denom.into(),
-            })
+            });
+
+            // if the input is scale 0 we re up to the max scale
+            if input_scales[0] == 0 {
+                SupportedOp::Rescaled(Rescaled {
+                    inner: Box::new(op),
+                    scale: vec![(0, scale_to_multiplier(scales.get_max()) as u128)],
+                })
+            } else {
+                op
+            }
         } else {
             return Err(GraphError::MisformedParams(
                 "only support non zero divisors of size 1".to_string(),
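PATCH 4's change only fires when the Div input arrives at scale 0 (a pure integer input, as in this example model): the op is wrapped in a Rescaled node that multiplies the input up to the graph's maximum scale before dividing. Below is a hypothetical standalone Rust sketch of that rule; it assumes, as in ezkl, that scales are base-2 exponents with scale_to_multiplier(s) = 2^s, and the function names and types are illustrative, not the crate's API.

    // Hypothetical sketch of the scale-0 rescaling rule from PATCH 4,
    // not ezkl's actual types. Assumption: scale_to_multiplier(s) = 2^s.
    fn scale_to_multiplier(scale: i32) -> f64 {
        2f64.powi(scale)
    }

    // A scale-0 (integer) input is first "re-upped" to the graph's max
    // scale so the quotient keeps fractional precision downstream.
    fn div_with_rescale(x: i64, denom: f64, input_scale: i32, max_scale: i32) -> (i64, i32) {
        let (x, scale) = if input_scale == 0 {
            // mirrors the SupportedOp::Rescaled wrapper around HybridOp::Div
            (x * scale_to_multiplier(max_scale) as i64, max_scale)
        } else {
            (x, input_scale)
        };
        // rounding to nearest, matching const_div after PATCH 3
        (((x as f64) / denom).round() as i64, scale)
    }

    fn main() {
        // 5 / 3 at scale 0 would collapse straight to an integer; re-upped
        // to scale 7 it becomes 640 / 3 = 213, i.e. ~1.664 fixed point (213/128).
        let (q, s) = div_with_rescale(5, 3.0, 0, 7);
        println!("quotient {} at scale {} (~{})", q, s, q as f64 / scale_to_multiplier(s));
    }

The design point is that without the rescale, dividing a scale-0 input would discard all fractional precision in one step; multiplying up to the max scale first lets later ops see the quotient at full fixed-point resolution.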