From 197951cad4ed2d91e9f5da08f416004b418571d6 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sat, 16 Nov 2024 08:28:17 -0800 Subject: [PATCH] fixes to sls --- src/ast/sls/sat_ddfw.cpp | 4 ++-- src/ast/sls/sat_ddfw.h | 3 ++- src/ast/sls/sls_arith_base.cpp | 18 ------------------ src/ast/sls/sls_smt_plugin.cpp | 19 +++++++++++-------- src/ast/sls/sls_smt_plugin.h | 4 ++-- src/ast/sls/sls_smt_solver.cpp | 11 ++++++----- src/smt/theory_sls.cpp | 19 ++++++++++--------- src/smt/theory_sls.h | 1 - 8 files changed, 33 insertions(+), 46 deletions(-) diff --git a/src/ast/sls/sat_ddfw.cpp b/src/ast/sls/sat_ddfw.cpp index 1a87524c05d..4a4c2715977 100644 --- a/src/ast/sls/sat_ddfw.cpp +++ b/src/ast/sls/sat_ddfw.cpp @@ -43,7 +43,7 @@ namespace sat { check_without_plugin(); remove_assumptions(); log(); - return m_min_sz == 0 ? l_true : l_undef; + return m_min_sz == 0 && m_limit.inc() ? m_last_result : l_undef; } void ddfw::check_without_plugin() { @@ -401,7 +401,7 @@ namespace sat { m_model[i] = to_lbool(value(i)); save_priorities(); if (m_plugin) - m_plugin->on_save_model(); + m_last_result = m_plugin->on_save_model(); } void ddfw::save_best_values() { diff --git a/src/ast/sls/sat_ddfw.h b/src/ast/sls/sat_ddfw.h index 468178981ea..6f3386a059f 100644 --- a/src/ast/sls/sat_ddfw.h +++ b/src/ast/sls/sat_ddfw.h @@ -39,7 +39,7 @@ namespace sat { //virtual void init_search() = 0; //virtual void finish_search() = 0; virtual void on_rescale() = 0; - virtual void on_save_model() = 0; + virtual lbool on_save_model() = 0; virtual void on_restart() = 0; }; @@ -90,6 +90,7 @@ namespace sat { unsigned_vector m_flat_use_list; unsigned_vector m_use_list_index; unsigned m_use_list_vars = 0, m_use_list_clauses = 0; + lbool m_last_result = l_true; indexed_uint_set m_unsat; indexed_uint_set m_unsat_vars; // set of variables that are in unsat clauses diff --git a/src/ast/sls/sls_arith_base.cpp b/src/ast/sls/sls_arith_base.cpp index 24ba4935a25..c201d9ef3bf 100644 --- a/src/ast/sls/sls_arith_base.cpp +++ b/src/ast/sls/sls_arith_base.cpp @@ -256,9 +256,6 @@ namespace sls { num_t eps(1); if (!is_int(x) && abs(rh - lh) <= eps) eps = abs(rh - lh) / num_t(2); -// verbose_stream() << a << " " << b << " " << c << "\n"; -// verbose_stream() << (-b - root) << " " << (2 * a) << " " << ll << " " << lh << "\n"; -// verbose_stream() << (-b + root) << " " << (2 * a) << " " << rl << " " << rh << "\n"; SASSERT(ll <= lh && ll + 1 >= lh); SASSERT(rl <= rh && rl + 1 >= rh); SASSERT(!is_square || ll != lh || a * ll * ll + b * ll + c == 0); @@ -571,8 +568,6 @@ namespace sls { template bool arith_base::repair(sat::literal lit) { - //verbose_stream() << "repair " << lit << " " << (ctx.is_unit(lit)?"unit":"") << " " << mk_bounded_pp(ctx.atom(lit.var()), m) << "\n"; - //verbose_stream() << *atom(lit.var()) << "\n"; m_last_literal = lit; if (find_nl_moves(lit)) return true; @@ -673,8 +668,6 @@ namespace sls { return false; } - // IF_VERBOSE(0, display(verbose_stream(), v) << " := " << new_value << "\n"); - #if 0 @@ -690,7 +683,6 @@ namespace sls { SASSERT(ctx.is_true(lit)); ineq.m_args_value += coeff * (new_value - old_value); num_t dtt_new = dtt(old_sign, ineq); - // verbose_stream() << "dtt " << lit << " " << ineq << " " << dtt_new << "\n"; if (dtt_new != 0) ctx.flip(bv); SASSERT(dtt(sign(bv), ineq) == 0); @@ -800,7 +792,6 @@ namespace sls { SASSERT(ctx.is_true(lit)); ineq.m_args_value += coeff * (new_value - old_value); num_t dtt_new = dtt(old_sign, ineq); - // verbose_stream() << "dtt " << lit << " " << ineq << " " << dtt_new << "\n"; if (dtt_new != 0) ctx.flip(bv); SASSERT(dtt(sign(bv), ineq) == 0); @@ -1006,7 +997,6 @@ namespace sls { typename arith_base::var_t arith_base::mk_var(expr* e) { var_t v = m_expr2var.get(e->get_id(), UINT_MAX); if (v == UINT_MAX) { - // verbose_stream() << "mk-var " << mk_bounded_pp(e, m) << "\n"; v = m_vars.size(); m_expr2var.setx(e->get_id(), v, UINT_MAX); m_vars.push_back(var_info(e, a.is_int(e) ? var_sort::INT : var_sort::REAL)); @@ -1245,7 +1235,6 @@ namespace sls { return false; flet _tabu(m_use_tabu, false); TRACE("sls", tout << "repair def " << mk_bounded_pp(vi.m_expr, m) << "\n"); - // verbose_stream() << "repair down " << mk_bounded_pp(e, m) << "\n"; switch (vi.m_op) { case arith_op_kind::LAST_ARITH_OP: break; @@ -1522,11 +1511,8 @@ namespace sls { if (old_value == sum) return true; - //display(verbose_stream() << "repair add v" << v << " ", ad) << " " << old_value << " sum " << sum << "\n"; m_updates.reset(); -// display(verbose_stream(), v) << " "; -// verbose_stream() << mk_bounded_pp(m_vars[v].m_expr, m) << " := " << old_value << " " << sum << "\n"; for (auto const& [coeff, w] : coeffs) { auto delta = divide(w, sum - old_value, coeff); @@ -1590,8 +1576,6 @@ namespace sls { } } - // verbose_stream() << "repair product v" << v << "\n"; - if (apply_update()) return eval_is_correct(v); @@ -1957,7 +1941,6 @@ namespace sls { } else new_value = val; - //verbose_stream() << v << " " << vi.m_value << " -> " << new_value << "\n"; vi.m_value = new_value; } else { @@ -2308,7 +2291,6 @@ namespace sls { num_t val = i.m_coeff; for (auto const& [c, v] : i.m_args) val += c * value(v); - //verbose_stream() << "invariant " << i << "\n"; if (val != i.m_args_value) verbose_stream() << i << "\n"; SASSERT(val == i.m_args_value); diff --git a/src/ast/sls/sls_smt_plugin.cpp b/src/ast/sls/sls_smt_plugin.cpp index 128f759302b..6d0cb0e0a45 100644 --- a/src/ast/sls/sls_smt_plugin.cpp +++ b/src/ast/sls/sls_smt_plugin.cpp @@ -156,10 +156,8 @@ namespace sls { if (!e) return false; bv_util bv(m); - if (bv.is_bit2bool(e, t) && m_shared_terms.contains(t->get_id())) { - verbose_stream() << "shared bit2bool " << mk_bounded_pp(e, ctx.get_manager()) << "\n"; - return true; - } + if (bv.is_bit2bool(e, t) && m_shared_terms.contains(t->get_id())) + return true; // if arith.is_le(e, s, t) && t is a numeral, s is shared-term.... return false; @@ -175,7 +173,6 @@ namespace sls { } void smt_plugin::add_unit(sat::literal lit) { - verbose_stream() << "add unit " << lit << " " << is_shared(lit) << "\n"; if (!is_shared(lit)) return; std::lock_guard lock(m_mutex); @@ -222,6 +219,7 @@ namespace sls { } void smt_plugin::smt_phase_to_sls() { + IF_VERBOSE(2, verbose_stream() << "SMT -> SLS phase\n"); for (auto v : m_shared_bool_vars) { auto w = m_smt_bool_var2sls_bool_var[v]; auto phase = ctx.get_best_phase(v); @@ -232,6 +230,7 @@ namespace sls { } void smt_plugin::smt_values_to_sls() { + IF_VERBOSE(2, verbose_stream() << "SMT -> SLS values\n"); for (auto const& [t, t_sync] : m_smt2sync_uninterp) { expr_ref val_t(m); if (!ctx.get_smt_value(t, val_t)) @@ -261,7 +260,9 @@ namespace sls { if (m_shared_bool_vars.contains(v)) { auto w = m_smt_bool_var2sls_bool_var[v]; sat::literal sls_lit(w, lit.sign()); - IF_VERBOSE(2, verbose_stream() << "unit " << sls_lit << "\n"); + if (m_context.is_unit(sls_lit)) + continue; + IF_VERBOSE(3, verbose_stream() << "unit " << sls_lit << "\n"); m_ddfw->add(1, &sls_lit); } else { @@ -350,15 +351,17 @@ namespace sls { m_sls2sync_uninterp.insert(sls_t, sync_t); } - void smt_plugin::on_save_model() { + lbool smt_plugin::on_save_model() { TRACE("sls", display(tout)); + lbool r = l_true; while (unsat().empty()) { - m_context.check(); + r = m_context.check(); if (!m_new_clause_added) break; m_ddfw->reinit(); m_new_clause_added = false; } export_from_sls(); + return r; } } diff --git a/src/ast/sls/sls_smt_plugin.h b/src/ast/sls/sls_smt_plugin.h index 43a97c05551..616fd801c01 100644 --- a/src/ast/sls/sls_smt_plugin.h +++ b/src/ast/sls/sls_smt_plugin.h @@ -124,10 +124,10 @@ namespace sls { m_ddfw->reinit(); } - void on_save_model() override; + lbool on_save_model() override; void on_model(model_ref& mdl) override { - IF_VERBOSE(3, verbose_stream() << "on-model " << "\n"); + IF_VERBOSE(2, verbose_stream() << "on-model " << "\n"); m_sls_model = mdl; } diff --git a/src/ast/sls/sls_smt_solver.cpp b/src/ast/sls/sls_smt_solver.cpp index 0960ba6bd5a..87115bdd609 100644 --- a/src/ast/sls/sls_smt_solver.cpp +++ b/src/ast/sls/sls_smt_solver.cpp @@ -49,13 +49,14 @@ namespace sls { } bool m_on_save_model = false; - void on_save_model() override { + lbool on_save_model() override { + lbool r = l_true; if (m_on_save_model) - return; + return r; flet _on_save_model(m_on_save_model, true); CTRACE("sls", unsat().empty(), display(tout)); while (unsat().empty()) { - m_context.check(); + r = m_context.check(); if (!m_new_constraint) break; TRACE("sls", display(tout)); @@ -63,10 +64,10 @@ namespace sls { m_ddfw.reinit(); m_new_constraint = false; } + return r; } - void on_model(model_ref& mdl) override { - IF_VERBOSE(1, verbose_stream() << "on-model " << "\n"); + void on_model(model_ref& mdl) override { m_model = mdl; } diff --git a/src/smt/theory_sls.cpp b/src/smt/theory_sls.cpp index bc3cb015b85..ee9a8d9f3c2 100644 --- a/src/smt/theory_sls.cpp +++ b/src/smt/theory_sls.cpp @@ -105,10 +105,9 @@ namespace smt { if (!m_smt_plugin) return; - unsigned scope_lvl = ctx.get_scope_level(); - if (ctx.get_search_level() == scope_lvl - n) { + if (ctx.get_search_level() == ctx.get_scope_level() - n) { auto& lits = ctx.assigned_literals(); - for (; m_trail_lim < lits.size() && ctx.get_assign_level(lits[m_trail_lim]) == scope_lvl; ++m_trail_lim) + for (; m_trail_lim < lits.size() && ctx.get_assign_level(lits[m_trail_lim]) == ctx.get_search_level(); ++m_trail_lim) m_smt_plugin->add_unit(lits[m_trail_lim]); } @@ -141,7 +140,8 @@ namespace smt { m_threshold *= 2; m_smt_plugin->smt_units_to_sls(); bounded_run(m_restart_ls_steps); - m_smt_plugin->sls_activity_to_smt(); + if (m_smt_plugin) + m_smt_plugin->sls_activity_to_smt(); } m_difference_score = 0; m_difference_score_threshold = 1; @@ -170,11 +170,12 @@ namespace smt { m_smt_plugin->smt_values_to_sls(); bounded_run(m_final_check_ls_steps); dec_final_check_ls_steps(); - m_smt_plugin->sls_phase_to_smt(); - m_smt_plugin->sls_values_to_smt(); - if (m_num_guided_sls % 20 == 0) - m_smt_plugin->sls_activity_to_smt(); - + if (m_smt_plugin) { + m_smt_plugin->sls_phase_to_smt(); + m_smt_plugin->sls_values_to_smt(); + if (m_num_guided_sls % 20 == 0) + m_smt_plugin->sls_activity_to_smt(); + } return FC_DONE; } diff --git a/src/smt/theory_sls.h b/src/smt/theory_sls.h index db1aa204d16..d7e3c7ba176 100644 --- a/src/smt/theory_sls.h +++ b/src/smt/theory_sls.h @@ -56,7 +56,6 @@ namespace smt { bool m_checking = false; bool m_parallel_mode = true; unsigned m_threshold = 1; - unsigned m_restart_sls_count = 0; unsigned m_difference_score = 0; unsigned m_difference_score_threshold = 0; unsigned m_num_guided_sls = 0;