diff --git a/src/solver.rs b/src/solver.rs index 922a8fcd..93298967 100644 --- a/src/solver.rs +++ b/src/solver.rs @@ -5,7 +5,7 @@ use crate::symbolic_state::{get_operands, BVOperator, Formula, Node, OperandSide use crate::ternary::*; use log::{debug, log_enabled, trace, Level}; use petgraph::{visit::EdgeRef, Direction}; -use rand::{random, Rng}; +use rand::{distributions::Uniform, random, thread_rng, Rng}; pub type Assignment = Vec; @@ -81,10 +81,40 @@ fn is_invertable( && (s == BitVector(0) || (!s.odd() || x.mcb(t * s.modinverse().unwrap()))) && (s.odd() || (x << s.ctz()).mcb(y(s, t) << s.ctz())) } + BVOperator::Sltu => { + // (x= s.0 { + return false; + } + } + if t == BitVector(0) && s.0 > x.hi().0 { + return false; + } + true + // (s s.0 { + return false; + } + true + } + } BVOperator::Not => x.mcb(!t), BVOperator::BitwiseAnd => { let c = !(x.lo() ^ x.hi()); - (t & s == t) && ((s & x.hi() & c) == (t & c)) } BVOperator::Equals => { @@ -95,7 +125,7 @@ fn is_invertable( } } -fn is_consistent(op: BVOperator, x: TernaryBitVector, t: BitVector) -> bool { +fn is_consistent(op: BVOperator, x: TernaryBitVector, t: BitVector, d: OperandSide) -> bool { match op { BVOperator::Add | BVOperator::Sub | BVOperator::Equals => true, BVOperator::Not => x.mcb(!t), @@ -111,6 +141,19 @@ fn is_consistent(op: BVOperator, x: TernaryBitVector, t: BitVector) -> bool { && implies(t.odd(), || x.hi().lsb() != 0) && implies(!t.odd(), || value_exists(x, t)) } + BVOperator::Sltu => { + if d == OperandSide::Lhs { + if t == BitVector(1) && x.lo() == BitVector::ones() { + return false; + } + true + } else { + if t == BitVector(1) && x.hi() == BitVector(0) { + return false; + } + true + } + } _ => unimplemented!("can not check consistency for operator: {:?}", op), } } @@ -236,6 +279,31 @@ fn compute_inverse_value( (result_with_arbitrary & !x.constant_bit_mask()) | x.constant_bits() } + BVOperator::Sltu => { + if d == OperandSide::Lhs { + if t == BitVector(0) { + // x=x>=s + BitVector(thread_rng().sample(Uniform::new_inclusive(s.0, x.hi().0))) + } else { + // xx>=x.0 + BitVector(thread_rng().sample(Uniform::new(x.lo().0, s.0))) + } + } else if t == BitVector(0) { + // s=x>=x.0 + if s < x.hi() { + BitVector(thread_rng().sample(Uniform::new_inclusive(x.lo().0, s.0))) + } else { + BitVector(thread_rng().sample(Uniform::new_inclusive(x.lo().0, x.hi().0))) + } + } else { + // s=x>s + BitVector(thread_rng().sample(Uniform::new_inclusive(s.0 + 1, x.hi().0))) + } + } BVOperator::BitwiseAnd => { let fixed_bit_mask = x.constant_bit_mask() | s; let fixed_bits = x.constant_bits() | (s & t); @@ -262,7 +330,7 @@ fn compute_consistent_value( op: BVOperator, x: TernaryBitVector, t: BitVector, - _d: OperandSide, + d: OperandSide, ) -> BitVector { match op { BVOperator::Add | BVOperator::Sub | BVOperator::Equals => { @@ -286,6 +354,27 @@ fn compute_consistent_value( r as u64 } }), + BVOperator::Sltu => { + if d == OperandSide::Lhs { + if t == BitVector(0) { + // x { (BitVector(random::()) & !x.constant_bit_mask()) | x.constant_bits() | t } @@ -448,6 +537,13 @@ fn propagate_assignment(f: &Formula, ab: &mut Assignment, n: SymbolId BVOperator::Sub => update_binary(f, ab, n, "-", |l, r| l - r), BVOperator::Mul => update_binary(f, ab, n, "*", |l, r| l * r), BVOperator::BitwiseAnd => update_binary(f, ab, n, "&", |l, r| l & r), + BVOperator::Sltu => update_binary(f, ab, n, "<", |l, r| { + if l < r { + BitVector(1) + } else { + BitVector(0) + } + }), BVOperator::Equals => update_binary(f, ab, n, "=", |l, r| { if l == r { BitVector(1) @@ -519,7 +615,7 @@ fn sat( let x = at[nx.index()]; - if !is_consistent(op, x, t) { + if !is_consistent(op, x, t, side) { trace!( "not consistent: op={:?} x={:?} t={:?} -> aborting", op, @@ -574,6 +670,7 @@ fn sat( #[cfg(test)] mod tests { use super::*; + //use crate::engine::SyscallId::Openat; fn create_formula_with_input() -> (Formula, SymbolId) { let mut formula = Formula::new(); @@ -670,7 +767,7 @@ mod tests { ); } else { assert!( - is_invertable(op, x, s, t, side) == result, + is_invertable(op, x, s, t, OperandSide::Rhs) == result, "{:?} {:?} {:?} == {:?} {}", s, op, @@ -681,11 +778,18 @@ mod tests { } } - fn test_consistence(op: BVOperator, x: &'static str, t: u64, result: bool, msg: &'static str) { + fn test_consistence( + op: BVOperator, + x: &'static str, + t: u64, + result: bool, + msg: &'static str, + d: OperandSide, + ) { let x = TernaryBitVector::lit(x); let t = BitVector(t); assert!( - is_consistent(op, x, t) == result, + is_consistent(op, x, t, d) == result, "{:?} {:?} s == {:?} {}", x, op, @@ -711,15 +815,27 @@ mod tests { let computed = compute_inverse_value(op, x, s, t, d); // prove: computed <> s == t where <> is the binary operator - assert_eq!( - f(computed, s), - t, - "{:?} {:?} {:?} == {:?}", - computed, - op, - s, - t - ); + if d == OperandSide::Rhs { + assert_eq!( + f(s, computed), + t, + "{:?} {:?} {:?} == {:?}", + computed, + op, + s, + t + ); + } else { + assert_eq!( + f(computed, s), + t, + "{:?} {:?} {:?} == {:?}", + computed, + op, + s, + t + ); + } } fn test_consistent_value_computation( @@ -755,12 +871,14 @@ mod tests { compute_inverse_value(op, x, computed, t, d.other()) } + BVOperator::Sltu => compute_inverse_value(op, x, computed, t, d), + _ => unimplemented!(), }; if d == OperandSide::Lhs { assert_eq!( - f(computed, inverse), + f(inverse, computed), t, "{:?} {:?} {:?} == {:?}", computed, @@ -770,7 +888,7 @@ mod tests { ); } else { assert_eq!( - f(inverse, computed), + f(computed, inverse), t, "{:?} {:?} {:?} == {:?}", inverse, @@ -785,6 +903,7 @@ mod tests { // TODO: add tests for SUB const MUL: BVOperator = BVOperator::Mul; + const SLTU: BVOperator = BVOperator::Sltu; #[test] fn check_invertability_condition_for_mul() { @@ -838,6 +957,49 @@ mod tests { ); } + #[test] + fn check_invertability_condition_for_sltu() { + let mut side = OperandSide::Lhs; + + test_invertability(SLTU, "1", 2, 1, side, true, "trivial sltu v1"); + test_invertability(SLTU, "10", 1, 0, side, true, "trivial sltu v2"); + test_invertability(SLTU, "1", 1, 0, side, true, "trivial sltu v3"); + + test_invertability(SLTU, "1", 2, 0, side, false, "trivial sltu v4"); + test_invertability(SLTU, "10", 1, 1, side, false, "trivial sltu v5"); + test_invertability(SLTU, "1", 1, 1, side, false, "trivial sltu v6"); + + side = OperandSide::Rhs; + + test_invertability(SLTU, "1", 2, 1, side, false, "trivial sltu v7"); + test_invertability(SLTU, "10", 1, 0, side, false, "trivial sltu v8"); + test_invertability(SLTU, "1", 1, 0, side, true, "trivial sltu v9"); + + test_invertability(SLTU, "1", 2, 0, side, true, "trivial sltu v10"); + test_invertability(SLTU, "10", 1, 1, side, true, "trivial sltu v11"); + test_invertability(SLTU, "1", 1, 1, side, false, "trivial sltu v12"); + + side = OperandSide::Lhs; + + test_invertability(SLTU, "*", 2, 1, side, true, "nontrivial sltu v1"); + test_invertability(SLTU, "**", 1, 0, side, true, "nontrivial sltu v2"); + test_invertability(SLTU, "1*", 1, 0, side, true, "nontrivial sltu v3"); + + test_invertability(SLTU, "*", 2, 0, side, false, "nontrivial sltu v4"); + test_invertability(SLTU, "**", 1, 1, side, true, "nontrivial sltu v5"); + test_invertability(SLTU, "1*", 1, 1, side, false, "nontrivial sltu v6"); + + side = OperandSide::Rhs; + + test_invertability(SLTU, "*", 2, 1, side, false, "nontrivial sltu v7"); + test_invertability(SLTU, "**", 1, 0, side, true, "nontrivial sltu v8"); + test_invertability(SLTU, "1*", 1, 0, side, false, "nontrivial sltu v9"); + + test_invertability(SLTU, "*", 2, 0, side, true, "nontrivial sltu v10"); + test_invertability(SLTU, "**", 1, 1, side, true, "nontrivial sltu v11"); + test_invertability(SLTU, "1*", 1, 1, side, true, "nontrivial sltu v12"); + } + #[test] fn compute_inverse_values_for_mul() { let side = OperandSide::Lhs; @@ -853,23 +1015,66 @@ mod tests { test_inverse_value_computation(MUL, "**0", 0b10, 0b1100, side, f); } + #[test] + fn compute_inverse_values_for_sltu() { + let mut side = OperandSide::Lhs; + + fn f(l: BitVector, r: BitVector) -> BitVector { + if l < r { + BitVector(1) + } else { + BitVector(0) + } + } + + // test only for values which are actually invertable + test_inverse_value_computation(SLTU, "1", 2, 1, side, f); + test_inverse_value_computation(SLTU, "*", 1, 0, side, f); + test_inverse_value_computation(SLTU, "***", 2, 1, side, f); + + side = OperandSide::Rhs; + + test_inverse_value_computation(SLTU, "*", 2, 0, side, f); + test_inverse_value_computation(SLTU, "**", 1, 1, side, f); + test_inverse_value_computation(SLTU, "***", 6, 1, side, f); + } + #[test] fn check_consistency_condition_for_mul() { + let side = OperandSide::Lhs; let condition = "t != 0 => x^hi != 0"; - test_consistence(MUL, "1", 0b110, true, condition); - test_consistence(MUL, "0", 0b110, false, condition); + test_consistence(MUL, "1", 0b110, true, condition, side); + test_consistence(MUL, "0", 0b110, false, condition, side); let condition = "odd(t) => x^hi[lsb] != 0"; - test_consistence(MUL, "*00", 0b101, false, condition); - test_consistence(MUL, "*01", 0b101, true, condition); - test_consistence(MUL, "*0*", 0b11, true, condition); + test_consistence(MUL, "*00", 0b101, false, condition, side); + test_consistence(MUL, "*01", 0b101, true, condition, side); + test_consistence(MUL, "*0*", 0b11, true, condition, side); let condition = "Ey.(mcb(x, y) && ctz(t) >= ctz(y))"; - test_consistence(MUL, "*00", 0b100, true, condition); - test_consistence(MUL, "*00", 0b10, false, condition); + test_consistence(MUL, "*00", 0b100, true, condition, side); + test_consistence(MUL, "*00", 0b10, false, condition, side); + } + + #[test] + fn check_consistency_condition_for_sltu() { + let mut side = OperandSide::Lhs; + let condition = "t=1 and x=ones is F"; + test_consistence( + SLTU, + "1111111111111111111111111111111111111111111111111111111111111111", + 1, + false, + condition, + side, + ); + + side = OperandSide::Rhs; + let condition = "t=0 and x=0 is F"; + test_consistence(SLTU, "0", 1, false, condition, side); } #[test] @@ -886,4 +1091,29 @@ mod tests { test_consistent_value_computation(MUL, "*0*", 0b11, side, f); test_consistent_value_computation(MUL, "*00", 0b100, side, f); } + + //TODO + #[test] + fn compute_consistent_values_for_sltu() { + let mut side = OperandSide::Lhs; + + fn f(l: BitVector, r: BitVector) -> BitVector { + if l < r { + BitVector(1) + } else { + BitVector(0) + } + } + + // test only for values which actually have a consistent value + test_consistent_value_computation(SLTU, "1111111", 0, side, f); + test_consistent_value_computation(SLTU, "*****", 1, side, f); + + side = OperandSide::Rhs; + + // test only for values which actually have a consistent value + test_consistent_value_computation(SLTU, "1", 0, side, f); + test_consistent_value_computation(SLTU, "*******", 0, side, f); + test_consistent_value_computation(SLTU, "*******", 1, side, f); + } } diff --git a/src/symbolic_state.rs b/src/symbolic_state.rs index c62ab873..b012e232 100644 --- a/src/symbolic_state.rs +++ b/src/symbolic_state.rs @@ -31,6 +31,7 @@ pub enum BVOperator { Sub, Mul, Divu, + Sltu, Not, Equals, BitwiseAnd, @@ -55,6 +56,7 @@ impl fmt::Display for BVOperator { BVOperator::Not => "!", BVOperator::Equals => "=", BVOperator::BitwiseAnd => "&", + BVOperator::Sltu => "<", } ) } @@ -82,6 +84,7 @@ fn instruction_to_bv_operator(instruction: Instruction) -> BVOperator { Instruction::Sub(_) => BVOperator::Sub, Instruction::Mul(_) => BVOperator::Mul, Instruction::Divu(_) => BVOperator::Divu, + Instruction::Sltu(_) => BVOperator::Sltu, _ => unimplemented!("can not translate {:?} to Operator", instruction), } }