Skip to content

Commit 8125b6a

Browse files
committed
Merge branch 'solve_bit_decomposition_with_negative_numbers' into add_sub_test
2 parents 5bc5347 + 85a3a4d commit 8125b6a

10 files changed

+387
-264
lines changed

executor/src/witgen/jit/affine_symbolic_expression.rs

+76-65
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@ use std::{
66

77
use itertools::Itertools;
88
use num_traits::Zero;
9-
use powdr_number::FieldElement;
9+
use powdr_number::{log2_exact, FieldElement};
1010

1111
use crate::witgen::jit::effect::Assertion;
1212

1313
use super::{
14-
super::range_constraints::RangeConstraint, effect::Effect,
14+
super::range_constraints::RangeConstraint,
15+
effect::{BitDecomposition, BitDecompositionComponent, Effect},
1516
symbolic_expression::SymbolicExpression,
1617
};
1718

@@ -215,23 +216,19 @@ impl<T: FieldElement, V: Ord + Clone + Display> AffineSymbolicExpression<T, V> {
215216
}
216217
_ => {
217218
let r = self.solve_bit_decomposition()?;
219+
218220
if r.complete {
219221
r
220222
} else {
221223
let negated = -self;
222-
let r = negated.solve_bit_decomposition()?;
223-
if r.complete {
224-
r
225-
} else {
226-
let effects = self
227-
.transfer_constraints()
228-
.into_iter()
229-
.chain(negated.transfer_constraints())
230-
.collect();
231-
ProcessResult {
232-
effects,
233-
complete: false,
234-
}
224+
let effects = self
225+
.transfer_constraints()
226+
.into_iter()
227+
.chain(negated.transfer_constraints())
228+
.collect();
229+
ProcessResult {
230+
effects,
231+
complete: false,
235232
}
236233
}
237234
}
@@ -257,41 +254,55 @@ impl<T: FieldElement, V: Ord + Clone + Display> AffineSymbolicExpression<T, V> {
257254

258255
// Check if they are mutually exclusive and compute assignments.
259256
let mut covered_bits: <T as FieldElement>::Integer = 0.into();
260-
let mut effects = vec![];
261-
for (var, coeff, constraint) in constrained_coefficients {
262-
let mask = *constraint.multiple(coeff).mask();
263-
if !(mask & covered_bits).is_zero() {
257+
let mut components = vec![];
258+
for (variable, coeff, constraint) in constrained_coefficients {
259+
let is_negative = !coeff.is_in_lower_half();
260+
let coeff_abs = if is_negative { -coeff } else { coeff };
261+
let Some(exponent) = log2_exact(coeff_abs.to_arbitrary_integer()) else {
262+
// We could work with non-powers of two, but it would require
263+
// division instead of shifts.
264+
return Ok(ProcessResult::empty());
265+
};
266+
let bit_mask = *constraint.multiple(coeff_abs).mask();
267+
if !(bit_mask & covered_bits).is_zero() {
264268
// Overlapping range constraints.
265269
return Ok(ProcessResult::empty());
266270
} else {
267-
covered_bits |= mask;
271+
covered_bits |= bit_mask;
268272
}
269-
let masked = -&self.offset & mask;
270-
effects.push(Effect::Assignment(
271-
var.clone(),
272-
masked.integer_div(&coeff.into()),
273-
));
273+
components.push(BitDecompositionComponent {
274+
variable,
275+
// We negate here because we are solving
276+
// c_1 * x_1 + c_2 * x_2 + ... + offset = 0,
277+
// instead of
278+
// c_1 * x_1 + c_2 * x_2 + ... = offset.
279+
is_negative: !is_negative,
280+
exponent: exponent as u64,
281+
bit_mask,
282+
});
274283
}
275284

276285
if covered_bits >= T::modulus() {
277286
return Ok(ProcessResult::empty());
278287
}
279288

280-
// We need to assert that the masks cover "-offset",
281-
// otherwise the equation is not solvable.
282-
// We assert -offset & !masks == 0
283-
if let Some(offset) = self.offset.try_to_number() {
284-
if (-offset).to_integer() & !covered_bits != 0.into() {
285-
return Err(Error::ConflictingRangeConstraints);
289+
if !components.iter().any(|c| c.is_negative) {
290+
// If all coefficients are positive and the offset is known, we can check
291+
// that all bits are covered. If not, then there is no way to extract
292+
// the components and thus we have a conflict.
293+
if let Some(offset) = self.offset.try_to_number() {
294+
if offset.to_integer() & !covered_bits != 0.into() {
295+
return Err(Error::ConflictingRangeConstraints);
296+
}
286297
}
287-
} else {
288-
effects.push(Assertion::assert_eq(
289-
-&self.offset & !covered_bits,
290-
T::from(0).into(),
291-
));
292298
}
293299

294-
Ok(ProcessResult::complete(effects))
300+
Ok(ProcessResult::complete(vec![Effect::BitDecomposition(
301+
BitDecomposition {
302+
value: self.offset.clone(),
303+
components,
304+
},
305+
)]))
295306
}
296307

297308
fn transfer_constraints(&self) -> Option<Effect<T, V>> {
@@ -521,10 +532,9 @@ mod test {
521532
let b = Ase::from_unknown_variable("b", rc.clone());
522533
let c = Ase::from_unknown_variable("c", rc.clone());
523534
let z = Ase::from_known_symbol("Z", Default::default());
524-
// a * 0x100 + b * 0x10000 + c * 0x1000000 + 10 + Z = 0
535+
// a * 0x100 - b * 0x10000 + c * 0x1000000 + 10 + Z = 0
525536
let ten = from_number(10);
526-
let constr = mul(&a, &from_number(0x100))
527-
+ mul(&b, &from_number(0x10000))
537+
let constr = mul(&a, &from_number(0x100)) - mul(&b, &from_number(0x10000))
528538
+ mul(&c, &from_number(0x1000000))
529539
+ ten.clone()
530540
+ z.clone();
@@ -533,38 +543,39 @@ mod test {
533543
assert!(!result.complete && result.effects.is_empty());
534544
// Now add the range constraint on a, it should be solvable.
535545
let a = Ase::from_unknown_variable("a", rc.clone());
536-
let constr = mul(&a, &from_number(0x100))
537-
+ mul(&b, &from_number(0x10000))
546+
let constr = mul(&a, &from_number(0x100)) - mul(&b, &from_number(0x10000))
538547
+ mul(&c, &from_number(0x1000000))
539548
+ ten.clone()
540549
+ z;
541550
let result = constr.solve().unwrap();
542551
assert!(result.complete);
543-
let effects = result
544-
.effects
545-
.into_iter()
546-
.map(|effect| match effect {
547-
Effect::Assignment(v, expr) => format!("{v} = {expr};\n"),
548-
Effect::Assertion(Assertion {
549-
lhs,
550-
rhs,
551-
expected_equal,
552-
}) => {
553-
format!(
554-
"assert {lhs} {} {rhs};\n",
555-
if expected_equal { "==" } else { "!=" }
556-
)
557-
}
558-
_ => panic!(),
552+
553+
let [effect] = &result.effects[..] else {
554+
panic!();
555+
};
556+
let Effect::BitDecomposition(BitDecomposition { value, components }) = effect else {
557+
panic!();
558+
};
559+
assert_eq!(format!("{value}"), "(10 + Z)");
560+
let formatted = components
561+
.iter()
562+
.map(|c| {
563+
format!(
564+
"{} = (({value} & 0x{:0x}) >> {}){};\n",
565+
c.variable,
566+
c.bit_mask,
567+
c.exponent,
568+
if c.is_negative { " [negative]" } else { "" }
569+
)
559570
})
560-
.format("")
561-
.to_string();
571+
.join("");
572+
562573
assert_eq!(
563-
effects,
564-
"a = ((-(10 + Z) & 0xff00) // 256);
565-
b = ((-(10 + Z) & 0xff0000) // 65536);
566-
c = ((-(10 + Z) & 0xff000000) // 16777216);
567-
assert (-(10 + Z) & 0xffffffff000000ff) == 0;
574+
formatted,
575+
"\
576+
a = (((10 + Z) & 0xff00) >> 8) [negative];
577+
b = (((10 + Z) & 0xff0000) >> 16);
578+
c = (((10 + Z) & 0xff000000) >> 24) [negative];
568579
"
569580
);
570581
}

executor/src/witgen/jit/block_machine_processor.rs

+36-21
Original file line numberDiff line numberDiff line change
@@ -521,31 +521,19 @@ main_binary::operation_id_next[0] = main_binary::operation_id[1];
521521
call_var(9, 0, 0) = main_binary::operation_id_next[0];
522522
main_binary::operation_id_next[1] = main_binary::operation_id[2];
523523
call_var(9, 1, 0) = main_binary::operation_id_next[1];
524-
main_binary::A_byte[2] = ((main_binary::A[3] & 0xff000000) // 16777216);
525-
main_binary::A[2] = (main_binary::A[3] & 0xffffff);
526-
assert (main_binary::A[3] & 0xffffffff00000000) == 0;
524+
2**24 * main_binary::A_byte[2] + 2**0 * main_binary::A[2] := main_binary::A[3];
527525
call_var(9, 2, 1) = main_binary::A_byte[2];
528-
main_binary::A_byte[1] = ((main_binary::A[2] & 0xff0000) // 65536);
529-
main_binary::A[1] = (main_binary::A[2] & 0xffff);
530-
assert (main_binary::A[2] & 0xffffffffff000000) == 0;
526+
2**16 * main_binary::A_byte[1] + 2**0 * main_binary::A[1] := main_binary::A[2];
531527
call_var(9, 1, 1) = main_binary::A_byte[1];
532-
main_binary::A_byte[0] = ((main_binary::A[1] & 0xff00) // 256);
533-
main_binary::A[0] = (main_binary::A[1] & 0xff);
534-
assert (main_binary::A[1] & 0xffffffffffff0000) == 0;
528+
2**8 * main_binary::A_byte[0] + 2**0 * main_binary::A[0] := main_binary::A[1];
535529
call_var(9, 0, 1) = main_binary::A_byte[0];
536530
main_binary::A_byte[-1] = main_binary::A[0];
537531
call_var(9, -1, 1) = main_binary::A_byte[-1];
538-
main_binary::B_byte[2] = ((main_binary::B[3] & 0xff000000) // 16777216);
539-
main_binary::B[2] = (main_binary::B[3] & 0xffffff);
540-
assert (main_binary::B[3] & 0xffffffff00000000) == 0;
532+
2**24 * main_binary::B_byte[2] + 2**0 * main_binary::B[2] := main_binary::B[3];
541533
call_var(9, 2, 2) = main_binary::B_byte[2];
542-
main_binary::B_byte[1] = ((main_binary::B[2] & 0xff0000) // 65536);
543-
main_binary::B[1] = (main_binary::B[2] & 0xffff);
544-
assert (main_binary::B[2] & 0xffffffffff000000) == 0;
534+
2**16 * main_binary::B_byte[1] + 2**0 * main_binary::B[1] := main_binary::B[2];
545535
call_var(9, 1, 2) = main_binary::B_byte[1];
546-
main_binary::B_byte[0] = ((main_binary::B[1] & 0xff00) // 256);
547-
main_binary::B[0] = (main_binary::B[1] & 0xff);
548-
assert (main_binary::B[1] & 0xffffffffffff0000) == 0;
536+
2**8 * main_binary::B_byte[0] + 2**0 * main_binary::B[0] := main_binary::B[1];
549537
call_var(9, 0, 2) = main_binary::B_byte[0];
550538
main_binary::B_byte[-1] = main_binary::B[0];
551539
call_var(9, -1, 2) = main_binary::B_byte[-1];
@@ -618,9 +606,7 @@ params[1] = Sub::b[0];"
618606
assert_eq!(
619607
format_code(&code),
620608
"SubM::a[0] = params[0];
621-
SubM::b[0] = ((SubM::a[0] & 0xff00) // 256);
622-
SubM::c[0] = (SubM::a[0] & 0xff);
623-
assert (SubM::a[0] & 0xffffffffffff0000) == 0;
609+
2**8 * SubM::b[0] + 2**0 * SubM::c[0] := SubM::a[0];
624610
params[1] = SubM::b[0];
625611
params[2] = SubM::c[0];
626612
call_var(1, 0, 0) = SubM::c[0];
@@ -800,4 +786,33 @@ S::Z[0] = ((S::Zi[0][0] + S::Zi[1][0]) + S::Zi[2][0]);
800786
params[2] = S::Z[0];"
801787
);
802788
}
789+
790+
#[test]
791+
fn bit_decomp_negative() {
792+
let input = "
793+
namespace Main(256);
794+
col witness a, b, c;
795+
[a, b, c] is [S.Y, S.Z, S.carry];
796+
namespace S(256);
797+
let BYTE: col = |i| i & 0xff;
798+
let X;
799+
let Y;
800+
let Z;
801+
let carry;
802+
carry * (1 - carry) = 0;
803+
[ X ] in [ BYTE ];
804+
[ Y ] in [ BYTE ];
805+
[ Z ] in [ BYTE ];
806+
X + Y = Z + 256 * carry;
807+
";
808+
let code = format_code(&generate_for_block_machine(input, "S", 2, 1).unwrap().code);
809+
assert_eq!(
810+
code,
811+
"\
812+
S::Y[0] = params[0];
813+
S::Z[0] = params[1];
814+
-2**0 * S::X[0] + 2**8 * S::carry[0] := (S::Y[0] + -S::Z[0]);
815+
params[2] = S::carry[0];"
816+
);
817+
}
803818
}

executor/src/witgen/jit/code_cleaner.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ fn optional_vars_in_effect<T: FieldElement>(
5050
required: &mut HashSet<Variable>,
5151
) -> HashSet<Variable> {
5252
let needed = match &effect {
53-
Effect::Assignment(..) | Effect::ProverFunctionCall(..) => {
53+
Effect::Assignment(..) | Effect::ProverFunctionCall(..) | Effect::BitDecomposition(_) => {
5454
effect.written_vars().any(|(v, _)| required.contains(v))
5555
}
5656
Effect::Assertion(_) => false,

0 commit comments

Comments
 (0)