Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Solve bit decomposition with negative coefficients. #2554

Open
wants to merge 33 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
5f9c452
Solve bit decomposition with negative coefficients.
chriseth Mar 14, 2025
bf43c73
Make bit decomposition its own effect.
chriseth Mar 17, 2025
99558c2
fixes
chriseth Mar 17, 2025
f4f8a28
Remove negative case since this is now also covered.
chriseth Mar 17, 2025
f5fe3f5
Update test expectations.
chriseth Mar 17, 2025
1a378c3
Implement missing functions.
chriseth Mar 17, 2025
fcc5190
Codegen for bit decomposition.
chriseth Mar 18, 2025
7b4780e
clippy
chriseth Mar 18, 2025
94a8cc4
Test with negative value.
chriseth Mar 18, 2025
44bdd08
Use shift instead of division.
chriseth Mar 18, 2025
d2980dd
Fix typo.
chriseth Mar 18, 2025
a3b210c
Bit decomposition interpreter.
chriseth Mar 18, 2025
2e7d3a7
Jump through some hoops.
chriseth Mar 19, 2025
4ddae64
Tests and fixes.
chriseth Mar 19, 2025
721cd36
Merge remote-tracking branch 'origin/main' into bit_decomp_interpreter
chriseth Mar 19, 2025
589963b
Cover check.
chriseth Mar 20, 2025
977f106
Update executor/src/witgen/jit/compiler.rs
chriseth Mar 20, 2025
14647d1
Review.
chriseth Mar 20, 2025
2e87944
Remove bitand_signed_negated.
chriseth Mar 20, 2025
0906258
debugging
chriseth Mar 20, 2025
1059fa6
Add comment.
chriseth Mar 20, 2025
5cf50ee
Make all variables to be assigned in the branch mutable.
chriseth Mar 20, 2025
00750c8
Revert "debugging"
chriseth Mar 20, 2025
f71190c
remove debug
chriseth Mar 20, 2025
d398c25
Fix expectation.
chriseth Mar 20, 2025
85a3a4d
Merge remote-tracking branch 'origin/main' into solve_bit_decompositi…
chriseth Mar 21, 2025
1c1eaa2
Merge remote-tracking branch 'origin/main' into solve_bit_decompositi…
chriseth Mar 24, 2025
ea0a974
Merge remote-tracking branch 'origin/main' into solve_bit_decompositi…
chriseth Mar 24, 2025
cbf290f
Merge remote-tracking branch 'origin/main' into solve_bit_decompositi…
chriseth Mar 25, 2025
73f11ce
fix merge.
chriseth Mar 25, 2025
84d05ce
Merge remote-tracking branch 'origin/main' into solve_bit_decompositi…
chriseth Mar 25, 2025
392a6d1
Merge remote-tracking branch 'origin/main' into bit_decomp_interpreter
chriseth Mar 25, 2025
8275be7
Merge branch 'bit_decomp_interpreter' into solve_bit_decomposition_wi…
chriseth Mar 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 76 additions & 65 deletions executor/src/witgen/jit/affine_symbolic_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@ use std::{

use itertools::Itertools;
use num_traits::Zero;
use powdr_number::FieldElement;
use powdr_number::{log2_exact, FieldElement};

use crate::witgen::jit::effect::Assertion;

use super::{
super::range_constraints::RangeConstraint, effect::Effect,
super::range_constraints::RangeConstraint,
effect::{BitDecomposition, BitDecompositionComponent, Effect},
symbolic_expression::SymbolicExpression,
};

Expand Down Expand Up @@ -215,23 +216,19 @@ impl<T: FieldElement, V: Ord + Clone + Display> AffineSymbolicExpression<T, V> {
}
_ => {
let r = self.solve_bit_decomposition()?;

if r.complete {
r
} else {
let negated = -self;
let r = negated.solve_bit_decomposition()?;
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The negated version is now covered already by the non-negated check.

if r.complete {
r
} else {
let effects = self
.transfer_constraints()
.into_iter()
.chain(negated.transfer_constraints())
.collect();
ProcessResult {
effects,
complete: false,
}
let effects = self
.transfer_constraints()
.into_iter()
.chain(negated.transfer_constraints())
.collect();
ProcessResult {
effects,
complete: false,
}
}
}
Expand All @@ -257,41 +254,55 @@ impl<T: FieldElement, V: Ord + Clone + Display> AffineSymbolicExpression<T, V> {

// Check if they are mutually exclusive and compute assignments.
let mut covered_bits: <T as FieldElement>::Integer = 0.into();
let mut effects = vec![];
for (var, coeff, constraint) in constrained_coefficients {
let mask = *constraint.multiple(coeff).mask();
if !(mask & covered_bits).is_zero() {
let mut components = vec![];
for (variable, coeff, constraint) in constrained_coefficients {
let is_negative = !coeff.is_in_lower_half();
let coeff_abs = if is_negative { -coeff } else { coeff };
let Some(exponent) = log2_exact(coeff_abs.to_arbitrary_integer()) else {
// We could work with non-powers of two, but it would require
// division instead of shifts.
return Ok(ProcessResult::empty());
};
let bit_mask = *constraint.multiple(coeff_abs).mask();
if !(bit_mask & covered_bits).is_zero() {
// Overlapping range constraints.
return Ok(ProcessResult::empty());
} else {
covered_bits |= mask;
covered_bits |= bit_mask;
}
let masked = -&self.offset & mask;
effects.push(Effect::Assignment(
var.clone(),
masked.integer_div(&coeff.into()),
));
components.push(BitDecompositionComponent {
variable,
// We negate here because we are solving
// c_1 * x_1 + c_2 * x_2 + ... + offset = 0,
// instead of
// c_1 * x_1 + c_2 * x_2 + ... = offset.
is_negative: !is_negative,
exponent: exponent as u64,
bit_mask,
});
}

if covered_bits >= T::modulus() {
return Ok(ProcessResult::empty());
}

// We need to assert that the masks cover "-offset",
// otherwise the equation is not solvable.
// We assert -offset & !masks == 0
if let Some(offset) = self.offset.try_to_number() {
if (-offset).to_integer() & !covered_bits != 0.into() {
return Err(Error::ConflictingRangeConstraints);
if !components.iter().any(|c| c.is_negative) {
// If all coefficients are positive and the offset is known, we can check
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean if the offset is a known number, we can actually determine that in all cases, but the ones with negative coefficients are more complicated...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, if the offset is a known constant, we could just "run" the algorithm and make all the components known constants as well. Let's do that if we need it.

// that all bits are covered. If not, then there is no way to extract
// the components and thus we have a conflict.
if let Some(offset) = self.offset.try_to_number() {
if offset.to_integer() & !covered_bits != 0.into() {
return Err(Error::ConflictingRangeConstraints);
}
}
} else {
effects.push(Assertion::assert_eq(
-&self.offset & !covered_bits,
T::from(0).into(),
));
}

Ok(ProcessResult::complete(effects))
Ok(ProcessResult::complete(vec![Effect::BitDecomposition(
BitDecomposition {
value: self.offset.clone(),
components,
},
)]))
}

fn transfer_constraints(&self) -> Option<Effect<T, V>> {
Expand Down Expand Up @@ -521,10 +532,9 @@ mod test {
let b = Ase::from_unknown_variable("b", rc.clone());
let c = Ase::from_unknown_variable("c", rc.clone());
let z = Ase::from_known_symbol("Z", Default::default());
// a * 0x100 + b * 0x10000 + c * 0x1000000 + 10 + Z = 0
// a * 0x100 - b * 0x10000 + c * 0x1000000 + 10 + Z = 0
let ten = from_number(10);
let constr = mul(&a, &from_number(0x100))
+ mul(&b, &from_number(0x10000))
let constr = mul(&a, &from_number(0x100)) - mul(&b, &from_number(0x10000))
+ mul(&c, &from_number(0x1000000))
+ ten.clone()
+ z.clone();
Expand All @@ -533,38 +543,39 @@ mod test {
assert!(!result.complete && result.effects.is_empty());
// Now add the range constraint on a, it should be solvable.
let a = Ase::from_unknown_variable("a", rc.clone());
let constr = mul(&a, &from_number(0x100))
+ mul(&b, &from_number(0x10000))
let constr = mul(&a, &from_number(0x100)) - mul(&b, &from_number(0x10000))
+ mul(&c, &from_number(0x1000000))
+ ten.clone()
+ z;
let result = constr.solve().unwrap();
assert!(result.complete);
let effects = result
.effects
.into_iter()
.map(|effect| match effect {
Effect::Assignment(v, expr) => format!("{v} = {expr};\n"),
Effect::Assertion(Assertion {
lhs,
rhs,
expected_equal,
}) => {
format!(
"assert {lhs} {} {rhs};\n",
if expected_equal { "==" } else { "!=" }
)
}
_ => panic!(),

let [effect] = &result.effects[..] else {
panic!();
};
let Effect::BitDecomposition(BitDecomposition { value, components }) = effect else {
panic!();
};
assert_eq!(format!("{value}"), "(10 + Z)");
let formatted = components
.iter()
.map(|c| {
format!(
"{} = (({value} & 0x{:0x}) >> {}){};\n",
c.variable,
c.bit_mask,
c.exponent,
if c.is_negative { " [negative]" } else { "" }
)
})
.format("")
.to_string();
.join("");

assert_eq!(
effects,
"a = ((-(10 + Z) & 0xff00) // 256);
b = ((-(10 + Z) & 0xff0000) // 65536);
c = ((-(10 + Z) & 0xff000000) // 16777216);
assert (-(10 + Z) & 0xffffffff000000ff) == 0;
formatted,
"\
a = (((10 + Z) & 0xff00) >> 8) [negative];
b = (((10 + Z) & 0xff0000) >> 16);
c = (((10 + Z) & 0xff000000) >> 24) [negative];
"
);
}
Expand Down
57 changes: 36 additions & 21 deletions executor/src/witgen/jit/block_machine_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -521,31 +521,19 @@ main_binary::operation_id_next[0] = main_binary::operation_id[1];
call_var(9, 0, 0) = main_binary::operation_id_next[0];
main_binary::operation_id_next[1] = main_binary::operation_id[2];
call_var(9, 1, 0) = main_binary::operation_id_next[1];
main_binary::A_byte[2] = ((main_binary::A[3] & 0xff000000) // 16777216);
main_binary::A[2] = (main_binary::A[3] & 0xffffff);
assert (main_binary::A[3] & 0xffffffff00000000) == 0;
2**24 * main_binary::A_byte[2] + 2**0 * main_binary::A[2] := main_binary::A[3];
call_var(9, 2, 1) = main_binary::A_byte[2];
main_binary::A_byte[1] = ((main_binary::A[2] & 0xff0000) // 65536);
main_binary::A[1] = (main_binary::A[2] & 0xffff);
assert (main_binary::A[2] & 0xffffffffff000000) == 0;
2**16 * main_binary::A_byte[1] + 2**0 * main_binary::A[1] := main_binary::A[2];
call_var(9, 1, 1) = main_binary::A_byte[1];
main_binary::A_byte[0] = ((main_binary::A[1] & 0xff00) // 256);
main_binary::A[0] = (main_binary::A[1] & 0xff);
assert (main_binary::A[1] & 0xffffffffffff0000) == 0;
2**8 * main_binary::A_byte[0] + 2**0 * main_binary::A[0] := main_binary::A[1];
call_var(9, 0, 1) = main_binary::A_byte[0];
main_binary::A_byte[-1] = main_binary::A[0];
call_var(9, -1, 1) = main_binary::A_byte[-1];
main_binary::B_byte[2] = ((main_binary::B[3] & 0xff000000) // 16777216);
main_binary::B[2] = (main_binary::B[3] & 0xffffff);
assert (main_binary::B[3] & 0xffffffff00000000) == 0;
2**24 * main_binary::B_byte[2] + 2**0 * main_binary::B[2] := main_binary::B[3];
call_var(9, 2, 2) = main_binary::B_byte[2];
main_binary::B_byte[1] = ((main_binary::B[2] & 0xff0000) // 65536);
main_binary::B[1] = (main_binary::B[2] & 0xffff);
assert (main_binary::B[2] & 0xffffffffff000000) == 0;
2**16 * main_binary::B_byte[1] + 2**0 * main_binary::B[1] := main_binary::B[2];
call_var(9, 1, 2) = main_binary::B_byte[1];
main_binary::B_byte[0] = ((main_binary::B[1] & 0xff00) // 256);
main_binary::B[0] = (main_binary::B[1] & 0xff);
assert (main_binary::B[1] & 0xffffffffffff0000) == 0;
2**8 * main_binary::B_byte[0] + 2**0 * main_binary::B[0] := main_binary::B[1];
call_var(9, 0, 2) = main_binary::B_byte[0];
main_binary::B_byte[-1] = main_binary::B[0];
call_var(9, -1, 2) = main_binary::B_byte[-1];
Expand Down Expand Up @@ -620,9 +608,7 @@ params[1] = Sub::b[0];"
assert_eq!(
format_code(&code),
"SubM::a[0] = params[0];
SubM::b[0] = ((SubM::a[0] & 0xff00) // 256);
SubM::c[0] = (SubM::a[0] & 0xff);
assert (SubM::a[0] & 0xffffffffffff0000) == 0;
2**8 * SubM::b[0] + 2**0 * SubM::c[0] := SubM::a[0];
params[1] = SubM::b[0];
params[2] = SubM::c[0];
call_var(1, 0, 0) = SubM::c[0];
Expand Down Expand Up @@ -802,4 +788,33 @@ S::Z[0] = ((S::Zi[0][0] + S::Zi[1][0]) + S::Zi[2][0]);
params[2] = S::Z[0];"
);
}

#[test]
fn bit_decomp_negative() {
let input = "
namespace Main(256);
col witness a, b, c;
[a, b, c] is [S.Y, S.Z, S.carry];
namespace S(256);
let BYTE: col = |i| i & 0xff;
let X;
let Y;
let Z;
let carry;
carry * (1 - carry) = 0;
[ X ] in [ BYTE ];
[ Y ] in [ BYTE ];
[ Z ] in [ BYTE ];
X + Y = Z + 256 * carry;
";
let code = format_code(&generate_for_block_machine(input, "S", 2, 1).unwrap().code);
assert_eq!(
code,
"\
S::Y[0] = params[0];
S::Z[0] = params[1];
-2**0 * S::X[0] + 2**8 * S::carry[0] := (S::Y[0] + -S::Z[0]);
params[2] = S::carry[0];"
);
}
}
2 changes: 1 addition & 1 deletion executor/src/witgen/jit/code_cleaner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ fn optional_vars_in_effect<T: FieldElement>(
required: &mut HashSet<Variable>,
) -> HashSet<Variable> {
let needed = match &effect {
Effect::Assignment(..) | Effect::ProverFunctionCall(..) => {
Effect::Assignment(..) | Effect::ProverFunctionCall(..) | Effect::BitDecomposition(_) => {
effect.written_vars().any(|(v, _)| required.contains(v))
}
Effect::Assertion(_) => false,
Expand Down
Loading
Loading