From d0a59f37404c7ccf0d2e13eda40dd8da3f6884d5 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sat, 16 Dec 2023 15:12:57 -0800 Subject: [PATCH] intblast with lazy expansion of shl, ashr, lshr Signed-off-by: Nikolaj Bjorner --- src/ast/arith_decl_plugin.cpp | 21 +++- src/ast/arith_decl_plugin.h | 31 ++++-- src/ast/rewriter/arith_rewriter.cpp | 103 +++++++++++++++++++ src/ast/rewriter/arith_rewriter.h | 3 + src/math/lp/lp_api.h | 4 +- src/sat/smt/arith_axioms.cpp | 149 ++++++++++++++++++++++------ src/sat/smt/arith_internalize.cpp | 8 +- src/sat/smt/arith_solver.cpp | 2 +- src/sat/smt/arith_solver.h | 8 +- src/sat/smt/intblast_solver.cpp | 75 ++++++++------ 10 files changed, 321 insertions(+), 83 deletions(-) diff --git a/src/ast/arith_decl_plugin.cpp b/src/ast/arith_decl_plugin.cpp index 8317b37c39b..f09daaf7541 100644 --- a/src/ast/arith_decl_plugin.cpp +++ b/src/ast/arith_decl_plugin.cpp @@ -508,6 +508,19 @@ static bool is_const_op(decl_kind k) { //k == OP_0_PW_0_REAL; } +symbol arith_decl_plugin::bv_symbol(decl_kind k) const { + switch (k) { + case OP_ARITH_BAND: return symbol("band"); + case OP_ARITH_SHL: return symbol("shl"); + case OP_ARITH_ASHR: return symbol("ashr"); + case OP_ARITH_LSHR: return symbol("lshr"); + default: + UNREACHABLE(); + } + return symbol(); +} + + func_decl * arith_decl_plugin::mk_func_decl(decl_kind k, unsigned num_parameters, parameter const * parameters, unsigned arity, sort * const * domain, sort * range) { if (k == OP_NUM) @@ -523,10 +536,10 @@ func_decl * arith_decl_plugin::mk_func_decl(decl_kind k, unsigned num_parameters return m_manager->mk_func_decl(symbol("divisible"), 1, &m_int_decl, m_manager->mk_bool_sort(), func_decl_info(m_family_id, k, num_parameters, parameters)); } - if (k == OP_ARITH_BAND) { + if (k == OP_ARITH_BAND || k == OP_ARITH_SHL || k == OP_ARITH_ASHR || k == OP_ARITH_LSHR) { if (arity != 2 || domain[0] != m_int_decl || domain[1] != m_int_decl || num_parameters != 1 || !parameters[0].is_int()) m_manager->raise_exception("invalid bitwise and application. Expects integer parameter and two arguments of sort integer"); - return m_manager->mk_func_decl(symbol("band"), 2, domain, m_int_decl, + return m_manager->mk_func_decl(bv_symbol(k), 2, domain, m_int_decl, func_decl_info(m_family_id, k, num_parameters, parameters)); } @@ -554,11 +567,11 @@ func_decl * arith_decl_plugin::mk_func_decl(decl_kind k, unsigned num_parameters return m_manager->mk_func_decl(symbol("divisible"), 1, &m_int_decl, m_manager->mk_bool_sort(), func_decl_info(m_family_id, k, num_parameters, parameters)); } - if (k == OP_ARITH_BAND) { + if (k == OP_ARITH_BAND || k == OP_ARITH_SHL || k == OP_ARITH_ASHR || k == OP_ARITH_LSHR) { if (num_args != 2 || args[0]->get_sort() != m_int_decl || args[1]->get_sort() != m_int_decl || num_parameters != 1 || !parameters[0].is_int()) m_manager->raise_exception("invalid bitwise and application. Expects integer parameter and two arguments of sort integer"); sort* domain[2] = { m_int_decl, m_int_decl }; - return m_manager->mk_func_decl(symbol("band"), 2, domain, m_int_decl, + return m_manager->mk_func_decl(bv_symbol(k), 2, domain, m_int_decl, func_decl_info(m_family_id, k, num_parameters, parameters)); } diff --git a/src/ast/arith_decl_plugin.h b/src/ast/arith_decl_plugin.h index 25c4977e9f2..308bc1326aa 100644 --- a/src/ast/arith_decl_plugin.h +++ b/src/ast/arith_decl_plugin.h @@ -72,6 +72,9 @@ enum arith_op_kind { OP_ATANH, // Bit-vector functions OP_ARITH_BAND, + OP_ARITH_SHL, + OP_ARITH_ASHR, + OP_ARITH_LSHR, // constants OP_PI, OP_E, @@ -150,6 +153,8 @@ class arith_decl_plugin : public decl_plugin { bool m_convert_int_numerals_to_real; + symbol bv_symbol(decl_kind k) const; + func_decl * mk_func_decl(decl_kind k, bool is_real); void set_manager(ast_manager * m, family_id id) override; decl_kind fix_kind(decl_kind k, unsigned arity); @@ -233,6 +238,14 @@ class arith_decl_plugin : public decl_plugin { executed in different threads. */ class arith_recognizers { + bool is_arith_op(expr const* n, decl_kind k, unsigned& sz, expr*& x, expr*& y) { + if (!is_app_of(n, arith_family_id, k)) + return false; + x = to_app(n)->get_arg(0); + y = to_app(n)->get_arg(1); + sz = to_app(n)->get_parameter(0).get_int(); + return true; + } public: family_id get_family_id() const { return arith_family_id; } @@ -296,14 +309,13 @@ class arith_recognizers { bool is_int_real(expr const * n) const { return is_int_real(n->get_sort()); } bool is_band(expr const* n) const { return is_app_of(n, arith_family_id, OP_ARITH_BAND); } - bool is_band(expr const* n, unsigned& sz, expr*& x, expr*& y) { - if (!is_band(n)) - return false; - x = to_app(n)->get_arg(0); - y = to_app(n)->get_arg(1); - sz = to_app(n)->get_parameter(0).get_int(); - return true; - } + bool is_band(expr const* n, unsigned& sz, expr*& x, expr*& y) { return is_arith_op(n, OP_ARITH_BAND, sz, x, y); } + bool is_shl(expr const* n) const { return is_app_of(n, arith_family_id, OP_ARITH_SHL); } + bool is_shl(expr const* n, unsigned& sz, expr*& x, expr*& y) { return is_arith_op(n, OP_ARITH_SHL, sz, x, y); } + bool is_lshr(expr const* n) const { return is_app_of(n, arith_family_id, OP_ARITH_LSHR); } + bool is_lshr(expr const* n, unsigned& sz, expr*& x, expr*& y) { return is_arith_op(n, OP_ARITH_LSHR, sz, x, y); } + bool is_ashr(expr const* n) const { return is_app_of(n, arith_family_id, OP_ARITH_ASHR); } + bool is_ashr(expr const* n, unsigned& sz, expr*& x, expr*& y) { return is_arith_op(n, OP_ARITH_ASHR, sz, x, y); } bool is_sin(expr const* n) const { return is_app_of(n, arith_family_id, OP_SIN); } bool is_cos(expr const* n) const { return is_app_of(n, arith_family_id, OP_COS); } @@ -487,6 +499,9 @@ class arith_util : public arith_recognizers { app * mk_power0(expr* arg1, expr* arg2) { return m_manager.mk_app(arith_family_id, OP_POWER0, arg1, arg2); } app* mk_band(unsigned n, expr* arg1, expr* arg2) { parameter p(n); expr* args[2] = { arg1, arg2 }; return m_manager.mk_app(arith_family_id, OP_ARITH_BAND, 1, &p, 2, args); } + app* mk_shl(unsigned n, expr* arg1, expr* arg2) { parameter p(n); expr* args[2] = { arg1, arg2 }; return m_manager.mk_app(arith_family_id, OP_ARITH_SHL, 1, &p, 2, args); } + app* mk_ashr(unsigned n, expr* arg1, expr* arg2) { parameter p(n); expr* args[2] = { arg1, arg2 }; return m_manager.mk_app(arith_family_id, OP_ARITH_ASHR, 1, &p, 2, args); } + app* mk_lshr(unsigned n, expr* arg1, expr* arg2) { parameter p(n); expr* args[2] = { arg1, arg2 }; return m_manager.mk_app(arith_family_id, OP_ARITH_LSHR, 1, &p, 2, args); } app * mk_sin(expr * arg) { return m_manager.mk_app(arith_family_id, OP_SIN, arg); } app * mk_cos(expr * arg) { return m_manager.mk_app(arith_family_id, OP_COS, arg); } diff --git a/src/ast/rewriter/arith_rewriter.cpp b/src/ast/rewriter/arith_rewriter.cpp index ddfabed8380..d8a06ada65b 100644 --- a/src/ast/rewriter/arith_rewriter.cpp +++ b/src/ast/rewriter/arith_rewriter.cpp @@ -92,6 +92,9 @@ br_status arith_rewriter::mk_app_core(func_decl * f, unsigned num_args, expr * c case OP_COSH: SASSERT(num_args == 1); st = mk_cosh_core(args[0], result); break; case OP_TANH: SASSERT(num_args == 1); st = mk_tanh_core(args[0], result); break; case OP_ARITH_BAND: SASSERT(num_args == 2); st = mk_band_core(f->get_parameter(0).get_int(), args[0], args[1], result); break; + case OP_ARITH_SHL: SASSERT(num_args == 2); st = mk_shl_core(f->get_parameter(0).get_int(), args[0], args[1], result); break; + case OP_ARITH_ASHR: SASSERT(num_args == 2); st = mk_ashr_core(f->get_parameter(0).get_int(), args[0], args[1], result); break; + case OP_ARITH_LSHR: SASSERT(num_args == 2); st = mk_lshr_core(f->get_parameter(0).get_int(), args[0], args[1], result); break; default: st = BR_FAILED; break; } CTRACE("arith_rewriter", st != BR_FAILED, tout << st << ": " << mk_pp(f, m); @@ -1350,6 +1353,98 @@ app* arith_rewriter_core::mk_power(expr* x, rational const& r, sort* s) { return y; } +br_status arith_rewriter::mk_shl_core(unsigned sz, expr* arg1, expr* arg2, expr_ref& result) { + numeral x, y, N; + bool is_num_x = m_util.is_numeral(arg1, x); + bool is_num_y = m_util.is_numeral(arg2, y); + N = rational::power_of_two(sz); + if (is_num_x) + x = mod(x, N); + if (is_num_y) + y = mod(y, N); + if (is_num_x && is_num_y) { + if (y >= sz) + result = m_util.mk_int(0); + else + result = m_util.mk_int(mod(x * rational::power_of_two(y.get_unsigned()), N)); + return BR_DONE; + } + if (is_num_y) { + if (y >= sz) + result = m_util.mk_int(0); + else + result = m_util.mk_mod(m_util.mk_mul(arg1, m_util.mk_int(rational::power_of_two(y.get_unsigned()))), m_util.mk_int(N)); + return BR_REWRITE1; + } + if (is_num_x && x == 0) { + result = m_util.mk_int(0); + return BR_DONE; + } + return BR_FAILED; +} +br_status arith_rewriter::mk_ashr_core(unsigned sz, expr* arg1, expr* arg2, expr_ref& result) { + numeral x, y, N; + bool is_num_x = m_util.is_numeral(arg1, x); + bool is_num_y = m_util.is_numeral(arg2, y); + N = rational::power_of_two(sz); + if (is_num_x) + x = mod(x, N); + if (is_num_y) + y = mod(y, N); + if (is_num_x && x == 0) { + result = m_util.mk_int(0); + return BR_DONE; + } + if (is_num_x && is_num_y) { + bool signx = x >= N/2; + rational d = div(x, rational::power_of_two(y.get_unsigned())); + SASSERT(y >= 0); + if (signx) { + if (y >= sz) + result = m_util.mk_int(N-1); + else + result = m_util.mk_int(d); + } + else { + if (y >= sz) + result = m_util.mk_int(0); + else + result = m_util.mk_int(mod(d - rational::power_of_two(sz - y.get_unsigned()), N)); + } + return BR_DONE; + } + return BR_FAILED; +} + +br_status arith_rewriter::mk_lshr_core(unsigned sz, expr* arg1, expr* arg2, expr_ref& result) { + numeral x, y, N; + bool is_num_x = m_util.is_numeral(arg1, x); + bool is_num_y = m_util.is_numeral(arg2, y); + N = rational::power_of_two(sz); + if (is_num_x) + x = mod(x, N); + if (is_num_y) + y = mod(y, N); + if (is_num_x && x == 0) { + result = m_util.mk_int(0); + return BR_DONE; + } + if (is_num_y && y == 0) { + result = arg1; + return BR_DONE; + } + if (is_num_x && is_num_y) { + if (y >= sz) + result = m_util.mk_int(N-1); + else { + rational d = div(x, rational::power_of_two(y.get_unsigned())); + result = m_util.mk_int(d); + } + return BR_DONE; + } + return BR_FAILED; +} + br_status arith_rewriter::mk_band_core(unsigned sz, expr* arg1, expr* arg2, expr_ref& result) { numeral x, y, N; bool is_num_x = m_util.is_numeral(arg1, x); @@ -1375,6 +1470,14 @@ br_status arith_rewriter::mk_band_core(unsigned sz, expr* arg1, expr* arg2, expr result = m_util.mk_int(r); return BR_DONE; } + if (is_num_x && (x + 1).is_power_of_two()) { + result = m_util.mk_mod(arg2, m_util.mk_int(x + 1)); + return BR_REWRITE1; + } + if (is_num_y && (y + 1).is_power_of_two()) { + result = m_util.mk_mod(arg1, m_util.mk_int(y + 1)); + return BR_REWRITE1; + } return BR_FAILED; } diff --git a/src/ast/rewriter/arith_rewriter.h b/src/ast/rewriter/arith_rewriter.h index 548ab80dbed..6066c9eb419 100644 --- a/src/ast/rewriter/arith_rewriter.h +++ b/src/ast/rewriter/arith_rewriter.h @@ -160,6 +160,9 @@ class arith_rewriter : public poly_rewriter { br_status mk_rem_core(expr * arg1, expr * arg2, expr_ref & result); br_status mk_power_core(expr* arg1, expr* arg2, expr_ref & result); br_status mk_band_core(unsigned sz, expr* arg1, expr* arg2, expr_ref& result); + br_status mk_shl_core(unsigned sz, expr* arg1, expr* arg2, expr_ref& result); + br_status mk_lshr_core(unsigned sz, expr* arg1, expr* arg2, expr_ref& result); + br_status mk_ashr_core(unsigned sz, expr* arg1, expr* arg2, expr_ref& result); void mk_div(expr * arg1, expr * arg2, expr_ref & result) { if (mk_div_core(arg1, arg2, result) == BR_FAILED) result = m.mk_app(get_fid(), OP_DIV, arg1, arg2); diff --git a/src/math/lp/lp_api.h b/src/math/lp/lp_api.h index 0eb8b6b3713..021501ecd25 100644 --- a/src/math/lp/lp_api.h +++ b/src/math/lp/lp_api.h @@ -108,7 +108,7 @@ namespace lp_api { unsigned m_gomory_cuts; unsigned m_assume_eqs; unsigned m_branch; - unsigned m_band_axioms; + unsigned m_bv_axioms; stats() { reset(); } void reset() { memset(this, 0, sizeof(*this)); @@ -129,7 +129,7 @@ namespace lp_api { st.update("arith-gomory-cuts", m_gomory_cuts); st.update("arith-assume-eqs", m_assume_eqs); st.update("arith-branch", m_branch); - st.update("arith-band-axioms", m_band_axioms); + st.update("arith-bv-axioms", m_bv_axioms); } }; diff --git a/src/sat/smt/arith_axioms.cpp b/src/sat/smt/arith_axioms.cpp index f004422a652..ae67783ebd8 100644 --- a/src/sat/smt/arith_axioms.cpp +++ b/src/sat/smt/arith_axioms.cpp @@ -205,58 +205,117 @@ namespace arith { add_clause(dgez, neg); } - bool solver::check_band_term(app* n) { + bool solver::check_bv_term(app* n) { unsigned sz; - expr* x, * y; + expr* _x, * _y; if (!ctx.is_relevant(expr2enode(n))) return true; - VERIFY(a.is_band(n, sz, x, y)); expr_ref vx(m), vy(m),vn(m); - if (!get_value(expr2enode(x), vx) || !get_value(expr2enode(y), vy) || !get_value(expr2enode(n), vn)) { + rational valn, valx, valy; + bool is_int; + VERIFY(a.is_band(n, sz, _x, _y) || a.is_shl(n, sz, _x, _y) || a.is_ashr(n, sz, _x, _y) || a.is_lshr(n, sz, _x, _y)); + if (!get_value(expr2enode(_x), vx) || !get_value(expr2enode(_y), vy) || !get_value(expr2enode(n), vn)) { IF_VERBOSE(2, verbose_stream() << "could not get value of " << mk_pp(n, m) << "\n"); found_unsupported(n); return true; } - rational valn, valx, valy; - bool is_int; if (!a.is_numeral(vn, valn, is_int) || !is_int || !a.is_numeral(vx, valx, is_int) || !is_int || !a.is_numeral(vy, valy, is_int) || !is_int) { IF_VERBOSE(2, verbose_stream() << "could not get value of " << mk_pp(n, m) << "\n"); found_unsupported(n); return true; } - // verbose_stream() << "band: " << mk_pp(n, m) << " " << valn << " := " << valx << "&" << valy << "\n"; rational N = rational::power_of_two(sz); valx = mod(valx, N); valy = mod(valy, N); + expr_ref x(a.mk_mod(_x, a.mk_int(N)), m); + expr_ref y(a.mk_mod(_y, a.mk_int(N)), m); SASSERT(0 <= valn && valn < N); - + // x mod 2^{i + 1} >= 2^i means the i'th bit is 1. auto bitof = [&](expr* x, unsigned i) { expr_ref r(m); r = a.mk_ge(a.mk_mod(x, a.mk_int(rational::power_of_two(i+1))), a.mk_int(rational::power_of_two(i))); return mk_literal(r); }; - for (unsigned i = 0; i < sz; ++i) { - bool xb = valx.get_bit(i); - bool yb = valy.get_bit(i); - bool nb = valn.get_bit(i); - if (xb && yb && !nb) - add_clause(~bitof(x, i), ~bitof(y, i), bitof(n, i)); - else if (nb && !xb) - add_clause(~bitof(n, i), bitof(x, i)); - else if (nb && !yb) - add_clause(~bitof(n, i), bitof(y, i)); - else - continue; + + if (a.is_band(n)) { + IF_VERBOSE(2, verbose_stream() << "band: " << mk_bounded_pp(n, m) << " " << valn << " := " << valx << "&" << valy << "\n"); + for (unsigned i = 0; i < sz; ++i) { + bool xb = valx.get_bit(i); + bool yb = valy.get_bit(i); + bool nb = valn.get_bit(i); + if (xb && yb && !nb) + add_clause(~bitof(x, i), ~bitof(y, i), bitof(n, i)); + else if (nb && !xb) + add_clause(~bitof(n, i), bitof(x, i)); + else if (nb && !yb) + add_clause(~bitof(n, i), bitof(y, i)); + else + continue; + return false; + } + } + if (a.is_shl(n)) { + SASSERT(valy >= 0); + if (valy >= sz || valy == 0) + return true; + unsigned k = valy.get_unsigned(); + sat::literal eq = eq_internalize(n, a.mk_mod(a.mk_mul(_x, a.mk_int(rational::power_of_two(k))), a.mk_int(N))); + if (s().value(eq) == l_true) + return true; + add_clause(~eq_internalize(y, a.mk_int(k)), eq); + IF_VERBOSE(2, verbose_stream() << "shl: " << mk_bounded_pp(n, m) << " " << valn << " := " << valx << " << " << valy << "\n"); + return false; + } + if (a.is_lshr(n)) { + SASSERT(valy >= 0); + if (valy >= sz || valy == 0) + return true; + unsigned k = valy.get_unsigned(); + sat::literal eq = eq_internalize(n, a.mk_idiv(x, a.mk_int(rational::power_of_two(k)))); + if (s().value(eq) == l_true) + return true; + add_clause(~eq_internalize(y, a.mk_int(k)), eq); + IF_VERBOSE(2, verbose_stream() << "lshr: " << mk_bounded_pp(n, m) << " " << valn << " := " << valx << " >>l " << valy << "\n"); + return false; + } + if (a.is_ashr(n)) { + SASSERT(valy >= 0); + if (valy >= sz || valy == 0) + return true; + unsigned k = valy.get_unsigned(); + sat::literal signx = mk_literal(a.mk_ge(x, a.mk_int(N/2))); + sat::literal eq; + expr* xdiv2k; + switch (s().value(signx)) { + case l_true: + // x < 0 & y = k -> n = (x div 2^k - 2^{N-k}) mod 2^N + xdiv2k = a.mk_idiv(x, a.mk_int(rational::power_of_two(k))); + eq = eq_internalize(n, a.mk_mod(a.mk_add(xdiv2k, a.mk_int(-rational::power_of_two(sz - k))), a.mk_int(N))); + if (s().value(eq) == l_true) + return true; + break; + case l_false: + // x >= 0 & y = k -> n = x div 2^k + xdiv2k = a.mk_idiv(x, a.mk_int(rational::power_of_two(k))); + eq = eq_internalize(n, xdiv2k); + if (s().value(eq) == l_true) + return true; + break; + case l_undef: + ctx.mark_relevant(signx); + return false; + } + add_clause(~eq_internalize(y, a.mk_int(k)), ~signx, eq); return false; } return true; } - bool solver::check_band_terms() { - for (app* n : m_band_terms) { - if (!check_band_term(n)) { - ++m_stats.m_band_axioms; + bool solver::check_bv_terms() { + for (app* n : m_bv_terms) { + if (!check_bv_term(n)) { + ++m_stats.m_bv_axioms; return false; } } @@ -268,15 +327,43 @@ namespace arith { * x&y <= x * x&y <= y */ - void solver::mk_band_axiom(app* n) { + void solver::mk_bv_axiom(app* n) { unsigned sz; - expr* x, * y; - VERIFY(a.is_band(n, sz, x, y)); + expr* _x, * _y; + VERIFY(a.is_band(n, sz, _x, _y) || a.is_shl(n, sz, _x, _y) || a.is_ashr(n, sz, _x, _y) || a.is_lshr(n, sz, _x, _y)); rational N = rational::power_of_two(sz); - add_clause(mk_literal(a.mk_ge(n, a.mk_int(0)))); - add_clause(mk_literal(a.mk_le(n, a.mk_int(N - 1)))); - add_clause(mk_literal(a.mk_le(n, a.mk_mod(x, a.mk_int(N))))); - add_clause(mk_literal(a.mk_le(n, a.mk_mod(y, a.mk_int(N))))); + expr_ref x(a.mk_mod(_x, a.mk_int(N)), m); + expr_ref y(a.mk_mod(_y, a.mk_int(N)), m); + + if (a.is_band(n)) { + add_clause(mk_literal(a.mk_ge(n, a.mk_int(0)))); + add_clause(mk_literal(a.mk_le(n, a.mk_int(N - 1)))); + add_clause(mk_literal(a.mk_le(n, x))); + add_clause(mk_literal(a.mk_le(n, y))); + } + else if (a.is_shl(n)) { + // y >= sz => n = 0 + // y = 0 => n = x + add_clause(~mk_literal(a.mk_ge(y, a.mk_int(sz))), mk_literal(m.mk_eq(n, a.mk_int(0)))); + add_clause(~mk_literal(a.mk_eq(y, a.mk_int(0))), mk_literal(m.mk_eq(n, x))); + } + else if (a.is_lshr(n)) { + // y >= sz => n = 0 + // y = 0 => n = x + add_clause(~mk_literal(a.mk_ge(y, a.mk_int(sz))), mk_literal(m.mk_eq(n, a.mk_int(0)))); + add_clause(~mk_literal(a.mk_eq(y, a.mk_int(0))), mk_literal(m.mk_eq(n, x))); + } + else if (a.is_ashr(n)) { + // y >= sz & x < 2^{sz-1} => n = 0 + // y >= sz & x >= 2^{sz-1} => n = -1 + // y = 0 => n = x + auto signx = mk_literal(a.mk_ge(x, a.mk_int(N/2))); + add_clause(~mk_literal(a.mk_ge(a.mk_mod(y, a.mk_int(N)), a.mk_int(sz))), signx, mk_literal(m.mk_eq(n, a.mk_int(0)))); + add_clause(~mk_literal(a.mk_ge(a.mk_mod(y, a.mk_int(N)), a.mk_int(sz))), ~signx, mk_literal(m.mk_eq(n, a.mk_int(N-1)))); + add_clause(~mk_literal(a.mk_eq(a.mk_mod(y, a.mk_int(N)), a.mk_int(0))), mk_literal(m.mk_eq(n, x))); + } + else + UNREACHABLE(); } void solver::mk_bound_axioms(api_bound& b) { diff --git a/src/sat/smt/arith_internalize.cpp b/src/sat/smt/arith_internalize.cpp index decd49019e4..ed49092fd6a 100644 --- a/src/sat/smt/arith_internalize.cpp +++ b/src/sat/smt/arith_internalize.cpp @@ -252,10 +252,10 @@ namespace arith { st.to_ensure_var().push_back(n1); st.to_ensure_var().push_back(n2); } - else if (a.is_band(n)) { - m_band_terms.push_back(to_app(n)); - mk_band_axiom(to_app(n)); - ctx.push(push_back_vector(m_band_terms)); + else if (a.is_band(n) || a.is_shl(n) || a.is_ashr(n) || a.is_lshr(n)) { + m_bv_terms.push_back(to_app(n)); + ctx.push(push_back_vector(m_bv_terms)); + mk_bv_axiom(to_app(n)); ensure_arg_vars(to_app(n)); } else if (!a.is_div0(n) && !a.is_mod0(n) && !a.is_idiv0(n) && !a.is_rem0(n) && !a.is_power0(n)) { diff --git a/src/sat/smt/arith_solver.cpp b/src/sat/smt/arith_solver.cpp index eff25bc4a3b..078515184c1 100644 --- a/src/sat/smt/arith_solver.cpp +++ b/src/sat/smt/arith_solver.cpp @@ -1053,7 +1053,7 @@ namespace arith { if (!check_delayed_eqs()) return sat::check_result::CR_CONTINUE; - if (!int_undef && !check_band_terms()) + if (!int_undef && !check_bv_terms()) return sat::check_result::CR_CONTINUE; if (ctx.get_config().m_arith_ignore_int && int_undef) diff --git a/src/sat/smt/arith_solver.h b/src/sat/smt/arith_solver.h index 022dbeaead6..cbf4206a9cf 100644 --- a/src/sat/smt/arith_solver.h +++ b/src/sat/smt/arith_solver.h @@ -214,7 +214,7 @@ namespace arith { expr* m_not_handled = nullptr; ptr_vector m_underspecified; ptr_vector m_idiv_terms; - ptr_vector m_band_terms; + ptr_vector m_bv_terms; vector > m_use_list; // bounds where variables are used. // attributes for incremental version: @@ -318,7 +318,7 @@ namespace arith { void mk_bound_axioms(api_bound& b); void mk_bound_axiom(api_bound& b1, api_bound& b2); void mk_power0_axioms(app* t, app* n); - void mk_band_axiom(app* n); + void mk_bv_axiom(app* n); void flush_bound_axioms(); void add_farkas_clause(sat::literal l1, sat::literal l2); @@ -410,8 +410,8 @@ namespace arith { bool check_delayed_eqs(); lbool check_lia(); lbool check_nla(); - bool check_band_terms(); - bool check_band_term(app* n); + bool check_bv_terms(); + bool check_bv_term(app* n); void add_lemmas(); void propagate_nla(); void add_equality(lpvar v, rational const& k, lp::explanation const& exp); diff --git a/src/sat/smt/intblast_solver.cpp b/src/sat/smt/intblast_solver.cpp index 9d03d0ad086..fed43e2173c 100644 --- a/src/sat/smt/intblast_solver.cpp +++ b/src/sat/smt/intblast_solver.cpp @@ -656,24 +656,58 @@ namespace intblast { break; } case OP_BSHL: { - expr* x = arg(0), * y = umod(e, 1); - r = a.mk_int(0); - for (unsigned i = 0; i < bv.get_bv_size(e); ++i) - r = m.mk_ite(m.mk_eq(y, a.mk_int(i)), a.mk_mul(x, a.mk_int(rational::power_of_two(i))), r); + if (!a.is_numeral(arg(0)) && !a.is_numeral(arg(1))) + r = a.mk_shl(bv.get_bv_size(e), arg(0),arg(1)); + else { + expr* x = arg(0), * y = umod(e, 1); + r = a.mk_int(0); + IF_VERBOSE(2, verbose_stream() << "shl " << mk_bounded_pp(e, m) << " " << bv.get_bv_size(e) << "\n"); + for (unsigned i = 0; i < bv.get_bv_size(e); ++i) + r = m.mk_ite(m.mk_eq(y, a.mk_int(i)), a.mk_mul(x, a.mk_int(rational::power_of_two(i))), r); + } break; } case OP_BNOT: r = bnot(arg(0)); break; - case OP_BLSHR: { - expr* x = arg(0), * y = umod(e, 1); - r = a.mk_int(0); - for (unsigned i = 0; i < bv.get_bv_size(e); ++i) - r = m.mk_ite(m.mk_eq(y, a.mk_int(i)), a.mk_idiv(x, a.mk_int(rational::power_of_two(i))), r); + case OP_BLSHR: + if (!a.is_numeral(arg(0)) && !a.is_numeral(arg(1))) + r = a.mk_lshr(bv.get_bv_size(e), arg(0), arg(1)); + else { + expr* x = arg(0), * y = umod(e, 1); + r = a.mk_int(0); + IF_VERBOSE(2, verbose_stream() << "lshr " << mk_bounded_pp(e, m) << " " << bv.get_bv_size(e) << "\n"); + for (unsigned i = 0; i < bv.get_bv_size(e); ++i) + r = m.mk_ite(m.mk_eq(y, a.mk_int(i)), a.mk_idiv(x, a.mk_int(rational::power_of_two(i))), r); + } + break; + case OP_BASHR: + if (!a.is_numeral(arg(1))) + r = a.mk_ashr(bv.get_bv_size(e), arg(0), arg(1)); + else { + + // + // ashr(x, y) + // if y = k & x >= 0 -> x / 2^k + // if y = k & x < 0 -> (x / 2^k) - 2^{N-k} + // + unsigned sz = bv.get_bv_size(e); + rational N = bv_size(e); + expr* x = umod(e, 0), *y = umod(e, 1); + expr* signx = a.mk_ge(x, a.mk_int(N / 2)); + r = m.mk_ite(signx, a.mk_int(- 1), a.mk_int(0)); + IF_VERBOSE(1, verbose_stream() << "ashr " << mk_bounded_pp(e, m) << " " << bv.get_bv_size(e) << "\n"); + for (unsigned i = 0; i < sz; ++i) { + expr* d = a.mk_idiv(x, a.mk_int(rational::power_of_two(i))); + r = m.mk_ite(m.mk_eq(y, a.mk_int(i)), + m.mk_ite(signx, a.mk_add(d, a.mk_int(- rational::power_of_two(sz-i))), d), + r); + } + } break; - } case OP_BOR: { // p | q := (p + q) - band(p, q) + IF_VERBOSE(2, verbose_stream() << "bor " << mk_bounded_pp(e, m) << " " << bv.get_bv_size(e) << "\n"); r = arg(0); for (unsigned i = 1; i < args.size(); ++i) r = a.mk_sub(a.mk_add(r, arg(i)), a.mk_band(bv.get_bv_size(e), r, arg(i))); @@ -683,12 +717,14 @@ namespace intblast { r = bnot(band(args)); break; case OP_BAND: + IF_VERBOSE(2, verbose_stream() << "band " << mk_bounded_pp(e, m) << " " << bv.get_bv_size(e) << "\n"); r = band(args); break; case OP_BXNOR: case OP_BXOR: { // p ^ q := (p + q) - 2*band(p, q); unsigned sz = bv.get_bv_size(e); + IF_VERBOSE(2, verbose_stream() << "bxor " << bv.get_bv_size(e) << "\n"); r = arg(0); for (unsigned i = 1; i < args.size(); ++i) { expr* q = arg(i); @@ -698,25 +734,6 @@ namespace intblast { r = bnot(r); break; } - case OP_BASHR: { - // - // ashr(x, y) - // if y = k & x >= 0 -> x / 2^k - // if y = k & x < 0 -> (x / 2^k) - 1 + 2^{N-k} - // - unsigned sz = bv.get_bv_size(e); - rational N = bv_size(e); - expr* x = umod(e, 0), *y = umod(e, 1); - expr* signx = a.mk_ge(x, a.mk_int(N / 2)); - r = m.mk_ite(signx, a.mk_int(- 1), a.mk_int(0)); - for (unsigned i = 0; i < sz; ++i) { - expr* d = a.mk_idiv(x, a.mk_int(rational::power_of_two(i))); - r = m.mk_ite(m.mk_eq(y, a.mk_int(i)), - m.mk_ite(signx, a.mk_add(d, a.mk_int(- rational::power_of_two(sz-i))), d), - r); - } - break; - } case OP_ZERO_EXT: bv_expr = e->get_arg(0); r = umod(bv_expr, 0);