From 91dc02d8628ef791b7e10c21318de49db98572d8 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sat, 2 Nov 2024 12:32:48 -0700 Subject: [PATCH] Sls (#7439) * reorg sls * sls * na * split into base and plugin * move sat_params to params directory, add op_def repair options * move sat_ddfw to sls, initiate sls-bv-plugin * porting bv-sls * adding basic plugin * na Signed-off-by: Nikolaj Bjorner * add sls-sms solver * bv updates * updated dependencies Signed-off-by: Nikolaj Bjorner * updated dependencies Signed-off-by: Nikolaj Bjorner * use portable ptr-initializer Signed-off-by: Nikolaj Bjorner * move definitions to cpp Signed-off-by: Nikolaj Bjorner * use template<> syntax Signed-off-by: Nikolaj Bjorner * fix compiler errors for gcc Signed-off-by: Nikolaj Bjorner * Bump docker/build-push-action from 6.0.0 to 6.1.0 (#7265) Bumps [docker/build-push-action](https://github.com/docker/build-push-action) from 6.0.0 to 6.1.0. - [Release notes](https://github.com/docker/build-push-action/releases) - [Commits](https://github.com/docker/build-push-action/compare/v6.0.0...v6.1.0) --- updated-dependencies: - dependency-name: docker/build-push-action dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * set clean shutdown for local search and re-enable local search when it parallelizes with PB solver Signed-off-by: Nikolaj Bjorner * Bump docker/build-push-action from 6.1.0 to 6.2.0 (#7269) Bumps [docker/build-push-action](https://github.com/docker/build-push-action) from 6.1.0 to 6.2.0. - [Release notes](https://github.com/docker/build-push-action/releases) - [Commits](https://github.com/docker/build-push-action/compare/v6.1.0...v6.2.0) --- updated-dependencies: - dependency-name: docker/build-push-action dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * Fix a comment for Z3_solver_from_string (#7271) Z3_solver_from_string accepts a string buffer with solver assertions, not a string buffer with filename. * trigger the build with a comment change Signed-off-by: Lev Nachmanson * remove macro distinction #7270 * fix #7268 * kludge to address #7232, probably superseeded by planned revision to setup/pypi Signed-off-by: Nikolaj Bjorner * add new ema invariant (#7288) * Bump docker/build-push-action from 6.2.0 to 6.3.0 (#7280) Bumps [docker/build-push-action](https://github.com/docker/build-push-action) from 6.2.0 to 6.3.0. - [Release notes](https://github.com/docker/build-push-action/releases) - [Commits](https://github.com/docker/build-push-action/compare/v6.2.0...v6.3.0) --- updated-dependencies: - dependency-name: docker/build-push-action dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * merge Signed-off-by: Nikolaj Bjorner * fix unit test build Signed-off-by: Nikolaj Bjorner * remove shared attribute Signed-off-by: Nikolaj Bjorner * remove stale files Signed-off-by: Nikolaj Bjorner * fix build of unit test Signed-off-by: Nikolaj Bjorner * fixes and rename sls-cc to sls-euf-plugin Signed-off-by: Nikolaj Bjorner * na Signed-off-by: Nikolaj Bjorner * testing / debugging arithmetic * updates to repair logic, mainly arithmetic * fixes to sls * evolve sls arith * bugfixes in sls-arith * fix typo Signed-off-by: Nikolaj Bjorner * bug fixes * Update sls_test.cpp * fixes * fixes Signed-off-by: Nikolaj Bjorner * fix build Signed-off-by: Nikolaj Bjorner * refactor basic plugin and clause generation Signed-off-by: Nikolaj Bjorner * fixes to ite and other Signed-off-by: Nikolaj Bjorner * updates * update Signed-off-by: Nikolaj Bjorner * fix division by 0 Signed-off-by: Nikolaj Bjorner * disable fail restart Signed-off-by: Nikolaj Bjorner * disable tabu when using reset moves Signed-off-by: Nikolaj Bjorner * update sls_test Signed-off-by: Nikolaj Bjorner * add factoring Signed-off-by: Nikolaj Bjorner * fixes to semantics Signed-off-by: Nikolaj Bjorner * re-add tabu override Signed-off-by: Nikolaj Bjorner * generalize factoring Signed-off-by: Nikolaj Bjorner * fix bug Signed-off-by: Nikolaj Bjorner * remove restart Signed-off-by: Nikolaj Bjorner * disable tabu in fallback modes Signed-off-by: Nikolaj Bjorner * localize impact of factoring Signed-off-by: Nikolaj Bjorner * delay factoring Signed-off-by: Nikolaj Bjorner * flatten products Signed-off-by: Nikolaj Bjorner * perform lookahead update + nested mul Signed-off-by: Nikolaj Bjorner * disable nested mul Signed-off-by: Nikolaj Bjorner * disable nested mul, use non-lookahead Signed-off-by: Nikolaj Bjorner * make reset updates recursive Signed-off-by: Nikolaj Bjorner * include linear moves Signed-off-by: Nikolaj Bjorner * include 5% reset probability Signed-off-by: Nikolaj Bjorner * separate linear update Signed-off-by: Nikolaj Bjorner * separate linear update remove 20% threshold Signed-off-by: Nikolaj Bjorner * remove linear opt Signed-off-by: Nikolaj Bjorner * enable multiplier expansion, enable linear move Signed-off-by: Nikolaj Bjorner * use unit coefficients for muls Signed-off-by: Nikolaj Bjorner * disable non-tabu version of find_nl_moves Signed-off-by: Nikolaj Bjorner * remove coefficient from multiplication definition Signed-off-by: Nikolaj Bjorner * reorg monomials Signed-off-by: Nikolaj Bjorner * add smt params to path Signed-off-by: Nikolaj Bjorner * avoid negative reward Signed-off-by: Nikolaj Bjorner * use reward as proxy for score Signed-off-by: Nikolaj Bjorner * use reward as proxy for score Signed-off-by: Nikolaj Bjorner * use exponential decay with breaks Signed-off-by: Nikolaj Bjorner * use std::pow Signed-off-by: Nikolaj Bjorner * fixes to bv Signed-off-by: Nikolaj Bjorner * fixes to fixed Signed-off-by: Nikolaj Bjorner * fixup repairs Signed-off-by: Nikolaj Bjorner * reserve for multiplication Signed-off-by: Nikolaj Bjorner * fixing repair Signed-off-by: Nikolaj Bjorner * include bounds checks in set random * na * fixes to mul Signed-off-by: Nikolaj Bjorner * fix mul inverse Signed-off-by: Nikolaj Bjorner * fixes to handling signed operators Signed-off-by: Nikolaj Bjorner * logging and fixes Signed-off-by: Nikolaj Bjorner * gcm Signed-off-by: Nikolaj Bjorner * peli Signed-off-by: Nikolaj Bjorner * Add .env to gitignore to prevent environment files from being tracked * Add m_num_pelis counter to stats in sls_context * Remove m_num_pelis member from stats struct in sls_context * Enhance bv_sls_eval with improved repair and logging, refine is_bv_predicate in sls_bv_plugin * Remove verbose logging in register_term function of sls_basic_plugin and fix formatting in sls_context * Rename source files for consistency in `src/ast/sls` directory * Refactor bv_sls files to sls_bv with namespace and class name adjustments * Remove typename from member declarations in bv_fixed class * fixing conca Signed-off-by: Nikolaj Bjorner * Add initial implementation of bit-vector SLS evaluation module in bv_sls_eval.cpp * Remove bv_sls_eval.cpp as part of code cleanup and refactoring * Refactor alignment of member variables in bv_plugin of sls namespace * Rename SLS engine related files to reflect their specific use for bit-vectors * Refactor SLS engine and evaluator components for bit-vector specifics and adjust memory manager alignment * Enhance bv_eval with use_current, lookahead strategies, and randomization improvements in SLS module * Refactor verbose logging and fix logic in range adjustment functions in sls bv modules * Remove commented verbose output in sls_bv_plugin.cpp during repair process * Add early return after setting fixed subterms in sls_bv_fixed.cpp * Remove redundant return statement in sls_bv_fixed.cpp * fixes to new value propagation Signed-off-by: Nikolaj Bjorner * Refactor sls bv evaluation and fix logic checks for bit operations * Add array plugin support and update bv_eval in ast_sls module * Add array, model value, and user sort plugins to SLS module with enhancements in array propagation logic * Refactor array_plugin in sls to improve handling of select expressions with multiple arguments * Enhance array plugin with early termination and propagation verification, and improve euf and user sort plugins with propagation adjustments and debugging enhancements * Add support for handling 'distinct' expressions in SLS context and user sort plugin * Remove model value and user sort plugins from SLS theory * replace user plugin by euf plugin Signed-off-by: Nikolaj Bjorner * remove extra file Signed-off-by: Nikolaj Bjorner * Refactor handling of term registration and enhance distinct handling in sls_euf_plugin * Add TODO list for enhancements in sls_euf_plugin.cpp * add incremental mode * updated package * fix sls build Signed-off-by: Nikolaj Bjorner * break sls build Signed-off-by: Nikolaj Bjorner * fix build * break build again * fix build Signed-off-by: Nikolaj Bjorner * fixes Signed-off-by: Nikolaj Bjorner * fixing incremental Signed-off-by: Nikolaj Bjorner * avoid units Signed-off-by: Nikolaj Bjorner * fixup handling of disequality propagation Signed-off-by: Nikolaj Bjorner * fx Signed-off-by: Nikolaj Bjorner * recover shift-weight loop Signed-off-by: Nikolaj Bjorner * alternate Signed-off-by: Nikolaj Bjorner * throttle save model Signed-off-by: Nikolaj Bjorner * allow for alternating Signed-off-by: Nikolaj Bjorner * fix test for new signature of flip Signed-off-by: Nikolaj Bjorner * bug fixes Signed-off-by: Nikolaj Bjorner * restore use of value_hash Signed-off-by: Nikolaj Bjorner * fixes Signed-off-by: Nikolaj Bjorner * adding dt plugin Signed-off-by: Nikolaj Bjorner * adt Signed-off-by: Nikolaj Bjorner * dt updates Signed-off-by: Nikolaj Bjorner * added cycle detection Signed-off-by: Nikolaj Bjorner * updated sls-datatype Signed-off-by: Nikolaj Bjorner * Refactor context management, improve datatype handling, and enhance logging in sls plugins. * axiomatize dt Signed-off-by: Nikolaj Bjorner * add missing factory plugins to model Signed-off-by: Nikolaj Bjorner * fixup finite domain search Signed-off-by: Nikolaj Bjorner * fixup finite domain search Signed-off-by: Nikolaj Bjorner * fixes Signed-off-by: Nikolaj Bjorner * redo dfs Signed-off-by: Nikolaj Bjorner * fixing model construction for underspecified operators Signed-off-by: Nikolaj Bjorner * fixes to occurs check Signed-off-by: Nikolaj Bjorner * fixup interpretation building Signed-off-by: Nikolaj Bjorner * saturate worklist Signed-off-by: Nikolaj Bjorner * delay distinct axiom Signed-off-by: Nikolaj Bjorner * adding model-based sls for datatatypes * update the interface in sls_solver to transfer phase between SAT and SLS * add value transfer option Signed-off-by: Nikolaj Bjorner * rename aux functions * Track shared variables using a unit set * debugging parallel integration * fix dirty flag setting * update log level * add plugin to smt_context, factor out sls_smt_plugin functionality. * bug fixes * fixes * use common infrastructure for sls-smt * fix build Signed-off-by: Nikolaj Bjorner * fix build Signed-off-by: Nikolaj Bjorner * remove declaration of context Signed-off-by: Nikolaj Bjorner * add virtual destructor Signed-off-by: Nikolaj Bjorner * build warnings Signed-off-by: Nikolaj Bjorner * reorder inclusion order to define smt_context before theory_sls Signed-off-by: Nikolaj Bjorner * change namespace for single threaded Signed-off-by: Nikolaj Bjorner * check delayed eqs before nla Signed-off-by: Nikolaj Bjorner * use independent completion flag for sls to avoid conflating with genuine cancelation * validate sls-arith lemmas Signed-off-by: Nikolaj Bjorner * bugfixes Signed-off-by: Nikolaj Bjorner * add intblast to legacy SMT solver * fixup model generation for theory_intblast Signed-off-by: Nikolaj Bjorner * na Signed-off-by: Nikolaj Bjorner * mk_value needs to accept more cases where integer expression doesn't evalate Signed-off-by: Nikolaj Bjorner * use th-axioms to track origins of assertions Signed-off-by: Nikolaj Bjorner * add missing operator handling for bitwise operators Signed-off-by: Nikolaj Bjorner * add missing operator handling for bitwise operators Signed-off-by: Nikolaj Bjorner * normalizing inequality Signed-off-by: Nikolaj Bjorner * add virtual destructor Signed-off-by: Nikolaj Bjorner * rework elim_unconstrained * fix non-termination Signed-off-by: Nikolaj Bjorner * use glue as computed without adjustment * update model generation to fix model bug Signed-off-by: Nikolaj Bjorner * fixes to model construction * remove package and package lock Signed-off-by: Nikolaj Bjorner * fix build warning Signed-off-by: Nikolaj Bjorner * use original gai Signed-off-by: Nikolaj Bjorner --------- Signed-off-by: Nikolaj Bjorner Signed-off-by: dependabot[bot] Signed-off-by: Lev Nachmanson Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Sergey Bronnikov Co-authored-by: Lev Nachmanson Co-authored-by: LiviaSun <33578456+ChuyueSun@users.noreply.github.com> --- .gitignore | 4 + scripts/mk_project.py | 14 +- src/CMakeLists.txt | 2 +- src/ast/arith_decl_plugin.h | 1 + src/ast/ast.cpp | 2 +- src/ast/bv_decl_plugin.cpp | 4 +- src/ast/bv_decl_plugin.h | 4 +- src/ast/datatype_decl_plugin.h | 2 + src/ast/euf/euf_egraph.cpp | 17 +- src/ast/euf/euf_egraph.h | 3 +- src/ast/rewriter/CMakeLists.txt | 1 + src/ast/rewriter/arith_rewriter.cpp | 125 +- src/ast/rewriter/arith_rewriter.h | 6 + src/ast/rewriter/bv2int_translator.cpp | 693 +++++ src/ast/rewriter/bv2int_translator.h | 84 + src/ast/rewriter/bv_rewriter.h | 10 +- src/ast/simplifiers/elim_unconstrained.cpp | 419 ++- src/ast/simplifiers/elim_unconstrained.h | 83 +- src/ast/sls/CMakeLists.txt | 23 +- src/ast/sls/bv_sls.cpp | 364 --- src/ast/sls/bv_sls.h | 129 - src/ast/sls/bv_sls_terms.cpp | 229 -- src/ast/sls/bv_sls_terms.h | 79 - src/ast/sls/bvsls_opt_engine.h | 2 +- src/{sat => ast/sls}/sat_ddfw.cpp | 350 ++- src/{sat => ast/sls}/sat_ddfw.h | 161 +- src/ast/sls/sls_arith_base.cpp | 2326 +++++++++++++++++ src/ast/sls/sls_arith_base.h | 292 +++ src/ast/sls/sls_arith_plugin.cpp | 131 + src/ast/sls/sls_arith_plugin.h | 52 + src/ast/sls/sls_array_plugin.cpp | 277 ++ src/ast/sls/sls_array_plugin.h | 90 + src/ast/sls/sls_basic_plugin.cpp | 210 ++ src/ast/sls/sls_basic_plugin.h | 58 + .../sls/{sls_engine.cpp => sls_bv_engine.cpp} | 2 +- src/ast/sls/{sls_engine.h => sls_bv_engine.h} | 4 +- .../sls/{bv_sls_eval.cpp => sls_bv_eval.cpp} | 1030 +++++--- src/ast/sls/{bv_sls_eval.h => sls_bv_eval.h} | 145 +- .../{sls_evaluator.h => sls_bv_evaluator.h} | 2 +- .../{bv_sls_fixed.cpp => sls_bv_fixed.cpp} | 341 ++- .../sls/{bv_sls_fixed.h => sls_bv_fixed.h} | 34 +- src/ast/sls/sls_bv_plugin.cpp | 206 ++ src/ast/sls/sls_bv_plugin.h | 62 + src/ast/sls/sls_bv_terms.cpp | 143 + src/ast/sls/sls_bv_terms.h | 54 + .../sls/{sls_tracker.h => sls_bv_tracker.h} | 0 ...sls_valuation.cpp => sls_bv_valuation.cpp} | 298 ++- .../{sls_valuation.h => sls_bv_valuation.h} | 25 +- src/ast/sls/sls_context.cpp | 654 +++++ src/ast/sls/sls_context.h | 212 ++ src/ast/sls/sls_datatype_plugin.cpp | 956 +++++++ src/ast/sls/sls_datatype_plugin.h | 107 + src/ast/sls/sls_euf_plugin.cpp | 489 ++++ src/ast/sls/sls_euf_plugin.h | 96 + src/ast/sls/sls_smt_plugin.cpp | 315 +++ src/ast/sls/sls_smt_plugin.h | 158 ++ src/ast/sls/sls_smt_solver.cpp | 171 ++ src/ast/sls/sls_smt_solver.h | 44 + src/math/hilbert/hilbert_basis.h | 2 +- src/model/model.cpp | 3 + src/opt/opt_context.cpp | 2 +- src/opt/opt_lns.cpp | 2 +- src/params/CMakeLists.txt | 1 + src/{sat => params}/sat_params.pyg | 0 src/params/sls_params.pyg | 2 + src/qe/qe_mbp.cpp | 15 +- src/sat/CMakeLists.txt | 3 +- src/sat/sat_config.cpp | 2 +- src/sat/sat_ddfw_wrapper.cpp | 85 + src/sat/sat_ddfw_wrapper.h | 89 + src/sat/sat_local_search.cpp | 2 +- src/sat/sat_solver.cpp | 9 +- src/sat/sat_solver.h | 6 +- src/sat/sat_solver/inc_sat_solver.cpp | 2 +- src/sat/sat_solver/sat_smt_solver.cpp | 7 +- src/sat/smt/CMakeLists.txt | 2 - src/sat/smt/arith_axioms.cpp | 1 + src/sat/smt/arith_sls.cpp | 642 ----- src/sat/smt/arith_sls.h | 169 -- src/sat/smt/arith_solver.cpp | 29 +- src/sat/smt/arith_solver.h | 7 +- src/sat/smt/bv_ackerman.cpp | 4 +- src/sat/smt/euf_internalize.cpp | 6 +- src/sat/smt/euf_local_search.cpp | 50 - src/sat/smt/euf_proof_checker.cpp | 2 +- src/sat/smt/euf_solver.cpp | 1 - src/sat/smt/euf_solver.h | 19 +- src/sat/smt/intblast_solver.cpp | 708 +---- src/sat/smt/intblast_solver.h | 52 +- src/sat/smt/sat_th.h | 5 - src/sat/smt/sls_solver.cpp | 184 +- src/sat/smt/sls_solver.h | 40 +- src/sat/tactic/goal2sat.cpp | 15 +- src/sat/tactic/sat2goal.cpp | 2 +- src/sat/tactic/sat_tactic.cpp | 3 +- src/shell/dimacs_frontend.cpp | 2 +- src/smt/CMakeLists.txt | 2 + src/smt/smt_context.cpp | 23 +- src/smt/smt_context.h | 11 + src/smt/smt_internalizer.cpp | 4 + src/smt/smt_setup.cpp | 22 +- src/smt/smt_setup.h | 1 + src/smt/theory_intblast.cpp | 191 ++ src/smt/theory_intblast.h | 73 + src/smt/theory_lra.cpp | 191 +- src/smt/theory_sls.cpp | 133 + src/smt/theory_sls.h | 93 + src/tactic/portfolio/smt_strategic_solver.cpp | 2 +- src/tactic/sls/sls_tactic.cpp | 200 +- src/tactic/sls/sls_tactic.h | 10 +- src/tactic/smtlogics/smt_tactic.cpp | 6 +- src/test/dlist.cpp | 4 +- src/test/sls_test.cpp | 56 +- src/util/checked_int64.h | 178 +- src/util/mpz.cpp | 1 + src/util/rlimit.cpp | 13 + src/util/rlimit.h | 1 + src/util/sat_sls.h | 41 + src/util/util.h | 10 + src/util/vector.h | 4 + 120 files changed, 11132 insertions(+), 4108 deletions(-) create mode 100644 src/ast/rewriter/bv2int_translator.cpp create mode 100644 src/ast/rewriter/bv2int_translator.h delete mode 100644 src/ast/sls/bv_sls.cpp delete mode 100644 src/ast/sls/bv_sls.h delete mode 100644 src/ast/sls/bv_sls_terms.cpp delete mode 100644 src/ast/sls/bv_sls_terms.h rename src/{sat => ast/sls}/sat_ddfw.cpp (71%) rename src/{sat => ast/sls}/sat_ddfw.h (60%) create mode 100644 src/ast/sls/sls_arith_base.cpp create mode 100644 src/ast/sls/sls_arith_base.h create mode 100644 src/ast/sls/sls_arith_plugin.cpp create mode 100644 src/ast/sls/sls_arith_plugin.h create mode 100644 src/ast/sls/sls_array_plugin.cpp create mode 100644 src/ast/sls/sls_array_plugin.h create mode 100644 src/ast/sls/sls_basic_plugin.cpp create mode 100644 src/ast/sls/sls_basic_plugin.h rename src/ast/sls/{sls_engine.cpp => sls_bv_engine.cpp} (99%) rename src/ast/sls/{sls_engine.h => sls_bv_engine.h} (97%) rename src/ast/sls/{bv_sls_eval.cpp => sls_bv_eval.cpp} (66%) rename src/ast/sls/{bv_sls_eval.h => sls_bv_eval.h} (62%) rename src/ast/sls/{sls_evaluator.h => sls_bv_evaluator.h} (99%) rename src/ast/sls/{bv_sls_fixed.cpp => sls_bv_fixed.cpp} (60%) rename src/ast/sls/{bv_sls_fixed.h => sls_bv_fixed.h} (57%) create mode 100644 src/ast/sls/sls_bv_plugin.cpp create mode 100644 src/ast/sls/sls_bv_plugin.h create mode 100644 src/ast/sls/sls_bv_terms.cpp create mode 100644 src/ast/sls/sls_bv_terms.h rename src/ast/sls/{sls_tracker.h => sls_bv_tracker.h} (100%) rename src/ast/sls/{sls_valuation.cpp => sls_bv_valuation.cpp} (70%) rename src/ast/sls/{sls_valuation.h => sls_bv_valuation.h} (94%) create mode 100644 src/ast/sls/sls_context.cpp create mode 100644 src/ast/sls/sls_context.h create mode 100644 src/ast/sls/sls_datatype_plugin.cpp create mode 100644 src/ast/sls/sls_datatype_plugin.h create mode 100644 src/ast/sls/sls_euf_plugin.cpp create mode 100644 src/ast/sls/sls_euf_plugin.h create mode 100644 src/ast/sls/sls_smt_plugin.cpp create mode 100644 src/ast/sls/sls_smt_plugin.h create mode 100644 src/ast/sls/sls_smt_solver.cpp create mode 100644 src/ast/sls/sls_smt_solver.h rename src/{sat => params}/sat_params.pyg (100%) create mode 100644 src/sat/sat_ddfw_wrapper.cpp create mode 100644 src/sat/sat_ddfw_wrapper.h delete mode 100644 src/sat/smt/arith_sls.cpp delete mode 100644 src/sat/smt/arith_sls.h delete mode 100644 src/sat/smt/euf_local_search.cpp create mode 100644 src/smt/theory_intblast.cpp create mode 100644 src/smt/theory_intblast.h create mode 100644 src/smt/theory_sls.cpp create mode 100644 src/smt/theory_sls.h create mode 100644 src/util/sat_sls.h diff --git a/.gitignore b/.gitignore index cb3d0190c24..47c6e392326 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ rebase.cmd callgrind.out.* # .hpp files are automatically generated *.hpp +.env .z3-trace .env .genaiscript @@ -28,6 +29,8 @@ ocamlz3 # Emacs temp files \#*\# # Directories with generated code and documentation +node_modules/* +.genaiscript/* release/* build/* trace/* @@ -105,3 +108,4 @@ CMakeSettings.json .DS_Store dbg/** *.wsp +CppProperties.json diff --git a/scripts/mk_project.py b/scripts/mk_project.py index e0fd260ba34..6399552e8db 100644 --- a/scripts/mk_project.py +++ b/scripts/mk_project.py @@ -28,19 +28,19 @@ def init_project_def(): add_lib('parser_util', ['ast'], 'parsers/util') add_lib('euf', ['ast'], 'ast/euf') add_lib('grobner', ['ast', 'dd', 'simplex'], 'math/grobner') - add_lib('sat', ['params', 'util', 'dd', 'grobner']) - add_lib('nlsat', ['polynomial', 'sat']) - add_lib('lp', ['util', 'nlsat', 'grobner', 'interval', 'smt_params'], 'math/lp') add_lib('rewriter', ['ast', 'polynomial', 'interval', 'automata', 'params'], 'ast/rewriter') - add_lib('bit_blaster', ['rewriter'], 'ast/rewriter/bit_blaster') add_lib('normal_forms', ['rewriter'], 'ast/normal_forms') - add_lib('substitution', ['rewriter'], 'ast/substitution') - add_lib('proofs', ['rewriter'], 'ast/proofs') add_lib('macros', ['rewriter'], 'ast/macros') add_lib('model', ['macros']) add_lib('converters', ['model'], 'ast/converters') + add_lib('ast_sls', ['ast','normal_forms','converters','smt_params','euf'], 'ast/sls') + add_lib('sat', ['params', 'util', 'dd', 'ast_sls', 'grobner']) + add_lib('nlsat', ['polynomial', 'sat']) + add_lib('lp', ['util', 'nlsat', 'grobner', 'interval', 'smt_params'], 'math/lp') + add_lib('bit_blaster', ['rewriter'], 'ast/rewriter/bit_blaster') + add_lib('substitution', ['rewriter'], 'ast/substitution') + add_lib('proofs', ['rewriter'], 'ast/proofs') add_lib('simplifiers', ['euf', 'normal_forms', 'bit_blaster', 'converters', 'substitution'], 'ast/simplifiers') - add_lib('ast_sls', ['ast','normal_forms','converters'], 'ast/sls') add_lib('tactic', ['simplifiers']) add_lib('mbp', ['model', 'simplex'], 'qe/mbp') add_lib('qe_lite', ['tactic', 'mbp'], 'qe/lite') diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 4c09f31aa86..5faede21ff7 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -54,7 +54,6 @@ add_subdirectory(ast/euf) add_subdirectory(ast/converters) add_subdirectory(ast/substitution) add_subdirectory(ast/simplifiers) -add_subdirectory(ast/sls) add_subdirectory(tactic) add_subdirectory(qe/mbp) add_subdirectory(qe/lite) @@ -74,6 +73,7 @@ add_subdirectory(parsers/smt2) add_subdirectory(solver/assertions) add_subdirectory(ast/pattern) add_subdirectory(math/lp) +add_subdirectory(ast/sls) add_subdirectory(sat/smt) add_subdirectory(sat/tactic) add_subdirectory(nlsat/tactic) diff --git a/src/ast/arith_decl_plugin.h b/src/ast/arith_decl_plugin.h index 3f094d43f08..275d39cf10e 100644 --- a/src/ast/arith_decl_plugin.h +++ b/src/ast/arith_decl_plugin.h @@ -365,6 +365,7 @@ class arith_recognizers { MATCH_BINARY(is_div0); MATCH_BINARY(is_idiv0); MATCH_BINARY(is_power); + MATCH_BINARY(is_power0); MATCH_UNARY(is_sin); MATCH_UNARY(is_asin); diff --git a/src/ast/ast.cpp b/src/ast/ast.cpp index 2d6f6f9b2e6..c4f87e5da98 100644 --- a/src/ast/ast.cpp +++ b/src/ast/ast.cpp @@ -1714,7 +1714,7 @@ ast * ast_manager::register_node_core(ast * n) { n->m_id = is_decl(n) ? m_decl_id_gen.mk() : m_expr_id_gen.mk(); -// track_id(*this, n, 77); + // track_id(*this, n, 9213); // TRACE("ast", tout << (s_count++) << " Object " << n->m_id << " was created.\n";); TRACE("mk_var_bug", tout << "mk_ast: " << n->m_id << "\n";); diff --git a/src/ast/bv_decl_plugin.cpp b/src/ast/bv_decl_plugin.cpp index 8c9a5652fc1..ded11804752 100644 --- a/src/ast/bv_decl_plugin.cpp +++ b/src/ast/bv_decl_plugin.cpp @@ -932,13 +932,13 @@ unsigned bv_util::get_int2bv_size(parameter const& p) { return static_cast(sz); } -app * bv_util::mk_bv2int(expr* e) { +app * bv_util::mk_bv2int(expr* e) const { sort* s = m_manager.mk_sort(m_manager.mk_family_id("arith"), INT_SORT); parameter p(s); return m_manager.mk_app(get_fid(), OP_BV2INT, 1, &p, 1, &e); } -app* bv_util::mk_int2bv(unsigned sz, expr* e) { +app* bv_util::mk_int2bv(unsigned sz, expr* e) const { parameter p(sz); return m_manager.mk_app(get_fid(), OP_INT2BV, 1, &p, 1, &e); } diff --git a/src/ast/bv_decl_plugin.h b/src/ast/bv_decl_plugin.h index 58445afda47..b8dde9361de 100644 --- a/src/ast/bv_decl_plugin.h +++ b/src/ast/bv_decl_plugin.h @@ -549,8 +549,8 @@ class bv_util : public bv_recognizers { app * mk_bv_ashr(expr* arg1, expr* arg2) { return m_manager.mk_app(get_fid(), OP_BASHR, arg1, arg2); } app * mk_bv_lshr(expr* arg1, expr* arg2) { return m_manager.mk_app(get_fid(), OP_BLSHR, arg1, arg2); } - app * mk_bv2int(expr* e); - app * mk_int2bv(unsigned sz, expr* e); + app * mk_bv2int(expr* e) const; + app * mk_int2bv(unsigned sz, expr* e) const; app* mk_bv_rotate_left(expr* arg1, expr* arg2) { return m_manager.mk_app(get_fid(), OP_EXT_ROTATE_LEFT, arg1, arg2); } app* mk_bv_rotate_right(expr* arg1, expr* arg2) { return m_manager.mk_app(get_fid(), OP_EXT_ROTATE_RIGHT, arg1, arg2); } diff --git a/src/ast/datatype_decl_plugin.h b/src/ast/datatype_decl_plugin.h index dcca7897060..48033d129ae 100644 --- a/src/ast/datatype_decl_plugin.h +++ b/src/ast/datatype_decl_plugin.h @@ -341,8 +341,10 @@ namespace datatype { ast_manager & get_manager() const { return m; } // sort * mk_datatype_sort(symbol const& name, unsigned n, sort* const* params); bool is_datatype(sort const* s) const { return is_sort_of(s, fid(), DATATYPE_SORT); } + bool is_datatype(expr* e) const { return is_datatype(e->get_sort()); } bool is_enum_sort(sort* s); bool is_recursive(sort * ty); + bool is_recursive(expr* e) { return is_recursive(e->get_sort()); } bool is_recursive_nested(sort * ty); bool is_constructor(func_decl * f) const { return is_decl_of(f, fid(), OP_DT_CONSTRUCTOR); } bool is_recognizer(func_decl * f) const { return is_recognizer0(f) || is_is(f); } diff --git a/src/ast/euf/euf_egraph.cpp b/src/ast/euf/euf_egraph.cpp index eaa290bbf96..a2bd2df4f31 100644 --- a/src/ast/euf/euf_egraph.cpp +++ b/src/ast/euf/euf_egraph.cpp @@ -82,8 +82,11 @@ namespace euf { void egraph::reinsert_equality(enode* p) { SASSERT(p->is_equality()); - if (p->value() != l_true && p->get_arg(0)->get_root() == p->get_arg(1)->get_root()) + if (p->value() != l_true && p->get_arg(0)->get_root() == p->get_arg(1)->get_root()) { queue_literal(p, nullptr); + if (p->value() == l_false && !m_on_propagate_literal) + set_conflict(p->get_arg(0), p->get_arg(1), p->m_lit_justification); + } } void egraph::queue_literal(enode* p, enode* ante) { @@ -201,6 +204,18 @@ namespace euf { } } + void egraph::new_diseq(enode* n, void* reason) { + force_push(); + SASSERT(m.is_eq(n->get_expr())); + auto j = justification::external(reason); + auto a = n->get_arg(0), b = n->get_arg(1); + auto r1 = a->get_root(), r2 = b->get_root(); + if (r1 == r2) + set_conflict(a, b, j); + else + set_value(n, l_false, j); + } + void egraph::new_diseq(enode* n) { SASSERT(n->is_equality()); SASSERT(n->value() == l_false); diff --git a/src/ast/euf/euf_egraph.h b/src/ast/euf/euf_egraph.h index 8822b07e793..4280b478060 100644 --- a/src/ast/euf/euf_egraph.h +++ b/src/ast/euf/euf_egraph.h @@ -278,10 +278,11 @@ namespace euf { */ void merge(enode* n1, enode* n2, void* reason) { merge(n1, n2, justification::external(reason)); } void new_diseq(enode* n); + void new_diseq(enode* n, void* reason); /** - \brief propagate set of merges. + \brief propagate set of merges. This call may detect an inconsistency. Then inconsistent() is true. Use then explain() to extract an explanation for the conflict. diff --git a/src/ast/rewriter/CMakeLists.txt b/src/ast/rewriter/CMakeLists.txt index 7f351ecb652..7822a370c92 100644 --- a/src/ast/rewriter/CMakeLists.txt +++ b/src/ast/rewriter/CMakeLists.txt @@ -4,6 +4,7 @@ z3_add_component(rewriter array_rewriter.cpp ast_counter.cpp bit2int.cpp + bv2int_translator.cpp bool_rewriter.cpp bv_bounds.cpp bv_elim.cpp diff --git a/src/ast/rewriter/arith_rewriter.cpp b/src/ast/rewriter/arith_rewriter.cpp index f21a5c4be68..b67e873c002 100644 --- a/src/ast/rewriter/arith_rewriter.cpp +++ b/src/ast/rewriter/arith_rewriter.cpp @@ -515,6 +515,129 @@ br_status arith_rewriter::reduce_power(expr * arg1, expr * arg2, op_kind kind, e } } +bool arith_rewriter::is_mul_factor(expr* s, expr* t) { + if (m_util.is_mul(t)) + return any_of(*to_app(t), [&](expr* m) { return is_mul_factor(s, m); }); + return s == t; +} + +bool arith_rewriter::is_add_factor(expr* s, expr* t) { + if (m_util.is_add(t)) + return all_of(*to_app(t), [&](expr* f) { return is_add_factor(s, f); }); + return is_mul_factor(s, t); +} + +expr_ref arith_rewriter::remove_factor(expr* s, expr* t) { + + if (m_util.is_mul(t)) { + ptr_buffer r; + r.push_back(t); + for (unsigned i = 0; i < r.size(); ++i) { + expr* arg = r[i]; + if (m_util.is_mul(arg)) { + r.append(to_app(arg)->get_num_args(), to_app(arg)->get_args()); + r[i] = r.back(); + r.pop_back(); + --i; + continue; + } + if (s == arg) { + r[i] = r.back(); + r.pop_back(); + break; + } + } + switch (r.size()) { + case 0: + return expr_ref(m_util.mk_numeral(rational(1), m_util.is_int(t)), m); + case 1: + return expr_ref(r[0], m); + default: + return expr_ref(m_util.mk_mul(r.size(), r.data()), m); + } + } + if (m_util.is_add(t)) { + expr_ref_vector sum(m); + sum.push_back(t); + for (unsigned i = 0; i < sum.size(); ++i) { + expr* arg = sum.get(i); + if (m_util.is_add(arg)) { + sum.append(to_app(arg)->get_num_args(), to_app(arg)->get_args()); + sum[i] = sum.back(); + sum.pop_back(); + --i; + continue; + } + sum[i] = remove_factor(s, arg); + } + if (sum.size() == 1) + return expr_ref(sum.get(0), m); + else + return expr_ref(m_util.mk_add(sum.size(), sum.data()), m); + } + else { + SASSERT(s == t); + return expr_ref(m_util.mk_numeral(rational(1), m_util.is_int(t)), m); + } +} + + +void arith_rewriter::get_nl_muls(expr* t, ptr_buffer& muls) { + if (m_util.is_mul(t)) { + for (expr* arg : *to_app(t)) + get_nl_muls(arg, muls); + } + else if (!m_util.is_numeral(t)) + muls.push_back(t); +} + +expr* arith_rewriter::find_nl_factor(expr* t) { + ptr_buffer sum, muls; + sum.push_back(t); + + for (unsigned i = 0; i < sum.size(); ++i) { + expr* arg = sum[i]; + if (m_util.is_add(arg)) + sum.append(to_app(arg)->get_num_args(), to_app(arg)->get_args()); + else if (m_util.is_mul(arg)) { + muls.reset(); + get_nl_muls(arg, muls); + if (muls.size() <= 1) + continue; + for (auto m : muls) { + if (is_add_factor(m, t)) + return m; + } + return nullptr; + } + } + return nullptr; +} + +br_status arith_rewriter::factor_le_ge_eq(expr * arg1, expr * arg2, op_kind kind, expr_ref & result) { + + if (is_zero(arg2)) { + expr* f = find_nl_factor(arg1); + if (!f) + return BR_FAILED; + expr_ref f2 = remove_factor(f, arg1); + expr* z = m_util.mk_numeral(rational(0), m_util.is_int(arg1)); + result = m.mk_or(m_util.mk_eq(f, z), m_util.mk_eq(f2, z)); + switch (kind) { + case EQ: + break; + case GE: + result = m.mk_or(m.mk_iff(m_util.mk_ge(f, z), m_util.mk_ge(f2, z)), result); + break; + case LE: + result = m.mk_or(m.mk_not(m.mk_iff(m_util.mk_ge(f, z), m_util.mk_ge(f2, z))), result); + break; + } + return BR_REWRITE3; + } + return BR_FAILED; +} + br_status arith_rewriter::mk_le_ge_eq_core(expr * arg1, expr * arg2, op_kind kind, expr_ref & result) { expr *orig_arg1 = arg1, *orig_arg2 = arg2; expr_ref new_arg1(m); @@ -655,7 +778,7 @@ br_status arith_rewriter::mk_le_ge_eq_core(expr * arg1, expr * arg2, op_kind kin default: result = m.mk_eq(arg1, arg2); return BR_DONE; } } - return BR_FAILED; + return factor_le_ge_eq(arg1, arg2, kind, result); } diff --git a/src/ast/rewriter/arith_rewriter.h b/src/ast/rewriter/arith_rewriter.h index 01fea0ac7f0..a1aadfa7f10 100644 --- a/src/ast/rewriter/arith_rewriter.h +++ b/src/ast/rewriter/arith_rewriter.h @@ -73,6 +73,12 @@ class arith_rewriter : public poly_rewriter { br_status is_separated(expr * arg1, expr * arg2, op_kind kind, expr_ref & result); bool is_non_negative(expr* e); br_status mk_le_ge_eq_core(expr * arg1, expr * arg2, op_kind kind, expr_ref & result); + bool is_add_factor(expr* s, expr* t); + bool is_mul_factor(expr* s, expr* t); + expr* find_nl_factor(expr* t); + void get_nl_muls(expr* t, ptr_buffer& muls); + expr_ref remove_factor(expr* s, expr* t); + br_status factor_le_ge_eq(expr * arg1, expr * arg2, op_kind kind, expr_ref & result); bool elim_to_real_var(expr * var, expr_ref & new_var); bool elim_to_real_mon(expr * monomial, expr_ref & new_monomial); diff --git a/src/ast/rewriter/bv2int_translator.cpp b/src/ast/rewriter/bv2int_translator.cpp new file mode 100644 index 00000000000..4e49e85e4e2 --- /dev/null +++ b/src/ast/rewriter/bv2int_translator.cpp @@ -0,0 +1,693 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + bv2int_translator + +Author: + + Nikolaj Bjorner (nbjorner) 2024-10-27 + +--*/ + +#include "ast/ast.h" +#include "ast/arith_decl_plugin.h" +#include "ast/bv_decl_plugin.h" +#include "ast/rewriter/bv2int_translator.h" +#include "ast/ast_ll_pp.h" +#include "ast/ast_pp.h" + +bv2int_translator::bv2int_translator(ast_manager& m, bv2int_translator_trail& ctx) : + m(m), + ctx(ctx), + bv(m), + a(m), + m_translate(m), + m_args(m), + m_pinned(m), + m_vars(m), + m_preds(m) +{} + +void bv2int_translator::reset(bool is_plugin) { + m_vars.reset(); + m_preds.reset(); + for (unsigned i = m_translate.size(); i-- > 0; ) + m_translate[i] = nullptr; + m_is_plugin = is_plugin; +} + + +void bv2int_translator::set_translated(expr* e, expr* r) { + SASSERT(r); + SASSERT(!is_translated(e)); + m_translate.setx(e->get_id(), r); + ctx.push_idx(set_vector_idx_trail(m_translate, e->get_id())); +} + +void bv2int_translator::internalize_bv(app* e) { + ensure_translated(e); + if (m.is_bool(e)) { + m_preds.push_back(e); + ctx.push(push_back_vector(m_preds)); + } +} + +void bv2int_translator::ensure_translated(expr* e) { + if (m_translate.get(e->get_id(), nullptr)) + return; + ptr_vector todo; + ast_fast_mark1 visited; + todo.push_back(e); + visited.mark(e); + for (unsigned i = 0; i < todo.size(); ++i) { + expr* e = todo[i]; + if (!is_app(e)) + continue; + app* a = to_app(e); + if (m.is_bool(e) && a->get_family_id() != bv.get_family_id()) + continue; + for (auto arg : *a) + if (!visited.is_marked(arg) && !m_translate.get(arg->get_id(), nullptr)) { + visited.mark(arg); + todo.push_back(arg); + } + } + std::stable_sort(todo.begin(), todo.end(), [&](expr* a, expr* b) { return get_depth(a) < get_depth(b); }); + for (expr* e : todo) + translate_expr(e); +} + +rational bv2int_translator::bv_size(expr* bv_expr) { + return rational::power_of_two(bv.get_bv_size(bv_expr->get_sort())); +} + +void bv2int_translator::translate_expr(expr* e) { + if (is_quantifier(e)) + translate_quantifier(to_quantifier(e)); + else if (is_var(e)) + translate_var(to_var(e)); + else { + app* ap = to_app(e); + if (m_is_plugin && ap->get_family_id() == basic_family_id && m.is_bool(ap)) { + set_translated(e, e); + return; + } + m_args.reset(); + for (auto arg : *ap) + m_args.push_back(translated(arg)); + + if (ap->get_family_id() == basic_family_id) + translate_basic(ap); + else if (ap->get_family_id() == bv.get_family_id()) + translate_bv(ap); + else + translate_app(ap); + } +} + +void bv2int_translator::translate_quantifier(quantifier* q) { + if (m_is_plugin) { + set_translated(q, q); + return; + } + if (is_lambda(q)) + throw default_exception("lambdas are not supported in intblaster"); + expr* b = q->get_expr(); + unsigned nd = q->get_num_decls(); + ptr_vector sorts; + for (unsigned i = 0; i < nd; ++i) { + auto s = q->get_decl_sort(i); + if (bv.is_bv_sort(s)) { + NOT_IMPLEMENTED_YET(); + sorts.push_back(a.mk_int()); + } + else + sorts.push_back(s); + } + b = translated(b); + // TODO if sorts contain integer, then created bounds variables. + set_translated(q, m.update_quantifier(q, b)); +} + +void bv2int_translator::translate_var(var* v) { + if (bv.is_bv_sort(v->get_sort())) + set_translated(v, m.mk_var(v->get_idx(), a.mk_int())); + else + set_translated(v, v); +} + +// Translate functions that are not built-in or bit-vectors. +// Base method uses fresh functions. +// Other method could use bv2int, int2bv axioms and coercions. +// f(args) = bv2int(f(int2bv(args')) +// + +void bv2int_translator::translate_app(app* e) { + + if (m_is_plugin && m.is_bool(e)) { + set_translated(e, e); + return; + } + + bool has_bv_sort = bv.is_bv(e); + func_decl* f = e->get_decl(); + + for (unsigned i = 0; i < m_args.size(); ++i) + if (bv.is_bv(e->get_arg(i))) + m_args[i] = bv.mk_int2bv(bv.get_bv_size(e->get_arg(i)), m_args.get(i)); + + if (has_bv_sort) + m_vars.push_back(e); + if (m_is_plugin) { + expr* r = m.mk_app(f, m_args); + if (has_bv_sort) { + ctx.push(push_back_vector(m_vars)); + r = bv.mk_bv2int(r); + } + set_translated(e, r); + return; + } + else if (has_bv_sort) { + if (f->get_family_id() != null_family_id) + throw default_exception("conversion for interpreted functions is not supported by intblast solver"); + func_decl* g = nullptr; + if (!m_new_funs.find(f, g)) { + g = m.mk_fresh_func_decl(e->get_decl()->get_name(), symbol("bv"), f->get_arity(), f->get_domain(), a.mk_int()); + m_new_funs.insert(f, g); + } + f = g; + m_pinned.push_back(f); + } + set_translated(e, m.mk_app(f, m_args)); +} + +expr_ref bv2int_translator::mk_le(expr* x, expr* y) { + if (a.is_numeral(y)) + return expr_ref(a.mk_le(x, y), m); + if (a.is_numeral(x)) + return expr_ref(a.mk_ge(y, x), m); + return expr_ref(a.mk_le(a.mk_sub(x, y), a.mk_numeral(rational(0), x->get_sort())), m); +} + +expr_ref bv2int_translator::mk_lt(expr* x, expr* y) { + return expr_ref(m.mk_not(mk_le(y, x)), m); +} + + + +void bv2int_translator::translate_bv(app* e) { + + auto bnot = [&](expr* e) { + return a.mk_sub(a.mk_int(-1), e); + }; + + auto band = [&](expr_ref_vector const& args) { + expr* r = arg(0); + for (unsigned i = 1; i < args.size(); ++i) + r = a.mk_band(bv.get_bv_size(e), r, arg(i)); + return r; + }; + + auto rotate_left = [&](unsigned n) { + auto sz = bv.get_bv_size(e); + n = n % sz; + expr* r = arg(0); + if (n != 0 && sz != 1) { + // r[sz - n - 1 : 0] ++ r[sz - 1 : sz - n] + // r * 2^(sz - n) + (r div 2^n) mod 2^(sz - n)??? + // r * A + (r div B) mod A + auto N = bv_size(e); + auto A = rational::power_of_two(sz - n); + auto B = rational::power_of_two(n); + auto hi = mul(r, a.mk_int(A)); + auto lo = amod(e, a.mk_idiv(umod(e, 0), a.mk_int(B)), A); + r = add(hi, lo); + } + return r; + }; + + expr* bv_expr = e; + expr_ref r(m); + auto const& args = m_args; + switch (e->get_decl_kind()) { + case OP_BADD: + r = a.mk_add(args); + break; + case OP_BSUB: + r = a.mk_sub(args.size(), args.data()); + break; + case OP_BMUL: + r = a.mk_mul(args); + break; + case OP_ULEQ: + bv_expr = e->get_arg(0); + r = mk_le(umod(bv_expr, 0), umod(bv_expr, 1)); + break; + case OP_UGEQ: + bv_expr = e->get_arg(0); + r = mk_ge(umod(bv_expr, 0), umod(bv_expr, 1)); + break; + case OP_ULT: + bv_expr = e->get_arg(0); + r = mk_lt(umod(bv_expr, 0), umod(bv_expr, 1)); + break; + case OP_UGT: + bv_expr = e->get_arg(0); + r = mk_gt(umod(bv_expr, 0), umod(bv_expr, 1)); + break; + case OP_SLEQ: + bv_expr = e->get_arg(0); + r = mk_le(smod(bv_expr, 0), smod(bv_expr, 1)); + break; + case OP_SGEQ: + bv_expr = e->get_arg(0); + r = mk_ge(smod(bv_expr, 0), smod(bv_expr, 1)); + break; + case OP_SLT: + bv_expr = e->get_arg(0); + r = mk_lt(smod(bv_expr, 0), smod(bv_expr, 1)); + break; + case OP_SGT: + bv_expr = e->get_arg(0); + r = mk_gt(smod(bv_expr, 0), smod(bv_expr, 1)); + break; + case OP_BNEG: + r = a.mk_uminus(arg(0)); + break; + case OP_CONCAT: { + unsigned sz = 0; + expr_ref new_arg(m); + for (unsigned i = args.size(); i-- > 0;) { + expr* old_arg = e->get_arg(i); + new_arg = umod(old_arg, i); + if (sz > 0) { + new_arg = mul(new_arg, a.mk_int(rational::power_of_two(sz))); + r = add(r, new_arg); + } + else + r = new_arg; + sz += bv.get_bv_size(old_arg->get_sort()); + } + break; + } + case OP_EXTRACT: { + unsigned lo, hi; + expr* old_arg; + VERIFY(bv.is_extract(e, lo, hi, old_arg)); + r = arg(0); + if (lo > 0) + r = a.mk_idiv(r, a.mk_int(rational::power_of_two(lo))); + break; + } + case OP_BV_NUM: { + rational val; + unsigned sz; + VERIFY(bv.is_numeral(e, val, sz)); + r = a.mk_int(val); + break; + } + case OP_BUREM: + case OP_BUREM_I: { + expr* x = umod(e, 0), * y = umod(e, 1); + r = if_eq(y, 0, x, a.mk_mod(x, y)); + break; + } + case OP_BUDIV: + case OP_BUDIV_I: { + expr* x = umod(e, 0), * y = umod(e, 1); + r = if_eq(y, 0, a.mk_int(-1), a.mk_idiv(x, y)); + break; + } + case OP_BUMUL_NO_OVFL: { + bv_expr = e->get_arg(0); + r = mk_lt(mul(umod(bv_expr, 0), umod(bv_expr, 1)), a.mk_int(bv_size(bv_expr))); + break; + } + case OP_BSHL: { + if (!a.is_numeral(arg(0)) && !a.is_numeral(arg(1))) + r = a.mk_shl(bv.get_bv_size(e), arg(0), arg(1)); + else { + expr* x = arg(0), * y = umod(e, 1); + r = a.mk_int(0); + for (unsigned i = 0; i < bv.get_bv_size(e); ++i) + r = if_eq(y, i, mul(x, a.mk_int(rational::power_of_two(i))), r); + } + break; + } + case OP_BNOT: + r = bnot(arg(0)); + break; + case OP_BLSHR: + if (!a.is_numeral(arg(0)) && !a.is_numeral(arg(1))) + r = a.mk_lshr(bv.get_bv_size(e), arg(0), arg(1)); + else { + expr* x = arg(0), * y = umod(e, 1); + r = a.mk_int(0); + IF_VERBOSE(4, verbose_stream() << "lshr " << mk_bounded_pp(e, m) << " " << bv.get_bv_size(e) << "\n"); + for (unsigned i = 0; i < bv.get_bv_size(e); ++i) + r = if_eq(y, i, a.mk_idiv(x, a.mk_int(rational::power_of_two(i))), r); + } + break; + case OP_BASHR: + if (!a.is_numeral(arg(1))) + r = a.mk_ashr(bv.get_bv_size(e), arg(0), arg(1)); + else { + + // + // ashr(x, y) + // if y = k & x >= 0 -> x / 2^k + // if y = k & x < 0 -> (x / 2^k) - 2^{N-k} + // + unsigned sz = bv.get_bv_size(e); + rational N = bv_size(e); + expr* x = umod(e, 0), * y = umod(e, 1); + expr* signx = a.mk_ge(x, a.mk_int(N / 2)); + r = m.mk_ite(signx, a.mk_int(-1), a.mk_int(0)); + IF_VERBOSE(4, verbose_stream() << "ashr " << mk_bounded_pp(e, m) << " " << bv.get_bv_size(e) << "\n"); + for (unsigned i = 0; i < sz; ++i) { + expr* d = a.mk_idiv(x, a.mk_int(rational::power_of_two(i))); + r = if_eq(y, i, + m.mk_ite(signx, add(d, a.mk_int(-rational::power_of_two(sz - i))), d), + r); + } + } + break; + case OP_BOR: + // p | q := (p + q) - band(p, q) + IF_VERBOSE(4, verbose_stream() << "bor " << mk_bounded_pp(e, m) << " " << bv.get_bv_size(e) << "\n"); + r = arg(0); + for (unsigned i = 1; i < args.size(); ++i) + r = a.mk_sub(add(r, arg(i)), a.mk_band(bv.get_bv_size(e), r, arg(i))); + break; + case OP_BNAND: + r = bnot(band(args)); + break; + case OP_BAND: + IF_VERBOSE(4, verbose_stream() << "band " << mk_bounded_pp(e, m) << " " << bv.get_bv_size(e) << "\n"); + r = band(args); + break; + case OP_BXNOR: + case OP_BXOR: { + // p ^ q := (p + q) - 2*band(p, q); + unsigned sz = bv.get_bv_size(e); + IF_VERBOSE(4, verbose_stream() << "bxor " << bv.get_bv_size(e) << "\n"); + r = arg(0); + for (unsigned i = 1; i < args.size(); ++i) { + expr* q = arg(i); + r = a.mk_sub(add(r, q), mul(a.mk_int(2), a.mk_band(sz, r, q))); + } + if (e->get_decl_kind() == OP_BXNOR) + r = bnot(r); + break; + } + case OP_ZERO_EXT: + bv_expr = e->get_arg(0); + r = umod(bv_expr, 0); + SASSERT(bv.get_bv_size(e) >= bv.get_bv_size(bv_expr)); + break; + case OP_SIGN_EXT: { + bv_expr = e->get_arg(0); + r = umod(bv_expr, 0); + SASSERT(bv.get_bv_size(e) >= bv.get_bv_size(bv_expr)); + unsigned arg_sz = bv.get_bv_size(bv_expr); + //unsigned sz = bv.get_bv_size(e); + // rational N = rational::power_of_two(sz); + rational M = rational::power_of_two(arg_sz); + expr* signbit = a.mk_ge(r, a.mk_int(M / 2)); + r = m.mk_ite(signbit, a.mk_sub(r, a.mk_int(M)), r); + break; + } + case OP_INT2BV: + m_int2bv.push_back(e); + ctx.push(push_back_vector(m_int2bv)); + r = arg(0); + break; + case OP_BV2INT: + m_bv2int.push_back(e); + ctx.push(push_back_vector(m_bv2int)); + r = umod(e->get_arg(0), 0); + break; + case OP_BCOMP: + bv_expr = e->get_arg(0); + r = m.mk_ite(m.mk_eq(umod(bv_expr, 0), umod(bv_expr, 1)), a.mk_int(1), a.mk_int(0)); + break; + case OP_BSMOD_I: + case OP_BSMOD: { + expr* x = umod(e, 0), * y = umod(e, 1); + rational N = bv_size(e); + expr* signx = a.mk_ge(x, a.mk_int(N / 2)); + expr* signy = a.mk_ge(y, a.mk_int(N / 2)); + expr* u = a.mk_mod(x, y); + // u = 0 -> 0 + // y = 0 -> x + // x < 0, y < 0 -> -u + // x < 0, y >= 0 -> y - u + // x >= 0, y < 0 -> y + u + // x >= 0, y >= 0 -> u + r = a.mk_uminus(u); + r = m.mk_ite(m.mk_and(m.mk_not(signx), signy), add(u, y), r); + r = m.mk_ite(m.mk_and(signx, m.mk_not(signy)), a.mk_sub(y, u), r); + r = m.mk_ite(m.mk_and(m.mk_not(signx), m.mk_not(signy)), u, r); + r = if_eq(u, 0, a.mk_int(0), r); + r = if_eq(y, 0, x, r); + break; + } + case OP_BSDIV_I: + case OP_BSDIV: { + // d = udiv(abs(x), abs(y)) + // y = 0, x > 0 -> 1 + // y = 0, x <= 0 -> -1 + // x = 0, y != 0 -> 0 + // x > 0, y < 0 -> -d + // x < 0, y > 0 -> -d + // x > 0, y > 0 -> d + // x < 0, y < 0 -> d + expr* x = umod(e, 0), * y = umod(e, 1); + rational N = bv_size(e); + expr* signx = a.mk_ge(x, a.mk_int(N / 2)); + expr* signy = a.mk_ge(y, a.mk_int(N / 2)); + x = m.mk_ite(signx, a.mk_sub(a.mk_int(N), x), x); + y = m.mk_ite(signy, a.mk_sub(a.mk_int(N), y), y); + expr* d = a.mk_idiv(x, y); + r = m.mk_ite(m.mk_iff(signx, signy), d, a.mk_uminus(d)); + r = if_eq(y, 0, m.mk_ite(signx, a.mk_int(1), a.mk_int(-1)), r); + break; + } + case OP_BSREM_I: + case OP_BSREM: { + // y = 0 -> x + // else x - sdiv(x, y) * y + expr* x = umod(e, 0), * y = umod(e, 1); + rational N = bv_size(e); + expr* signx = a.mk_ge(x, a.mk_int(N / 2)); + expr* signy = a.mk_ge(y, a.mk_int(N / 2)); + expr* absx = m.mk_ite(signx, a.mk_sub(a.mk_int(N), x), x); + expr* absy = m.mk_ite(signy, a.mk_sub(a.mk_int(N), y), y); + expr* d = a.mk_idiv(absx, absy); + d = m.mk_ite(m.mk_iff(signx, signy), d, a.mk_uminus(d)); + r = a.mk_sub(x, mul(d, y)); + r = if_eq(y, 0, x, r); + break; + } + case OP_ROTATE_LEFT: { + auto n = e->get_parameter(0).get_int(); + r = rotate_left(n); + break; + } + case OP_ROTATE_RIGHT: { + unsigned sz = bv.get_bv_size(e); + auto n = e->get_parameter(0).get_int(); + r = rotate_left(sz - n); + break; + } + case OP_EXT_ROTATE_LEFT: { + unsigned sz = bv.get_bv_size(e); + expr* y = umod(e, 1); + r = a.mk_int(0); + for (unsigned i = 0; i < sz; ++i) + r = if_eq(y, i, rotate_left(i), r); + break; + } + case OP_EXT_ROTATE_RIGHT: { + unsigned sz = bv.get_bv_size(e); + expr* y = umod(e, 1); + r = a.mk_int(0); + for (unsigned i = 0; i < sz; ++i) + r = if_eq(y, i, rotate_left(sz - i), r); + break; + } + case OP_REPEAT: { + unsigned n = e->get_parameter(0).get_int(); + expr* x = umod(e->get_arg(0), 0); + r = x; + rational N = bv_size(e->get_arg(0)); + rational N0 = N; + for (unsigned i = 1; i < n; ++i) + r = add(mul(a.mk_int(N), x), r), N *= N0; + break; + } + case OP_BREDOR: { + r = umod(e->get_arg(0), 0); + r = m.mk_not(m.mk_eq(r, a.mk_int(0))); + break; + } + case OP_BREDAND: { + rational N = bv_size(e->get_arg(0)); + r = umod(e->get_arg(0), 0); + r = m.mk_not(m.mk_eq(r, a.mk_int(N - 1))); + break; + } + default: + verbose_stream() << mk_pp(e, m) << "\n"; + NOT_IMPLEMENTED_YET(); + } + set_translated(e, r); +} + +expr_ref bv2int_translator::if_eq(expr* n, unsigned k, expr* th, expr* el) { + rational r; + expr_ref _th(th, m), _el(el, m); + if (bv.is_numeral(n, r)) { + if (r == k) + return expr_ref(th, m); + else + return expr_ref(el, m); + } + return expr_ref(m.mk_ite(m.mk_eq(n, a.mk_int(k)), th, el), m); +} + +void bv2int_translator::translate_basic(app* e) { + if (m.is_eq(e)) { + bool has_bv_arg = any_of(*e, [&](expr* arg) { return bv.is_bv(arg); }); + if (has_bv_arg) { + expr* bv_expr = e->get_arg(0); + rational N = rational::power_of_two(bv.get_bv_size(bv_expr)); + if (a.is_numeral(arg(0)) || a.is_numeral(arg(1)) || + is_bounded(arg(0), N) || is_bounded(arg(1), N)) { + set_translated(e, m.mk_eq(umod(bv_expr, 0), umod(bv_expr, 1))); + } + else { + m_args[0] = a.mk_sub(arg(0), arg(1)); + set_translated(e, m.mk_eq(umod(bv_expr, 0), a.mk_int(0))); + } + } + else + set_translated(e, m.mk_eq(arg(0), arg(1))); + } + else if (m.is_ite(e)) + set_translated(e, m.mk_ite(arg(0), arg(1), arg(2))); + else if (m_is_plugin) + set_translated(e, e); + else + set_translated(e, m.mk_app(e->get_decl(), m_args)); +} + +bool bv2int_translator::is_bounded(expr* x, rational const& N) { + return any_of(m_vars, [&](expr* v) { + return is_translated(v) && translated(v) == x && bv_size(v) <= N; + }); +} + +bool bv2int_translator::is_non_negative(expr* bv_expr, expr* e) { + auto N = rational::power_of_two(bv.get_bv_size(bv_expr)); + rational r; + if (a.is_numeral(e, r)) + return r >= 0; + if (is_bounded(e, N)) + return true; + expr* x = nullptr, * y = nullptr; + if (a.is_mul(e, x, y)) + return is_non_negative(bv_expr, x) && is_non_negative(bv_expr, y); + if (a.is_add(e, x, y)) + return is_non_negative(bv_expr, x) && is_non_negative(bv_expr, y); + return false; +} + +expr* bv2int_translator::umod(expr* bv_expr, unsigned i) { + expr* x = arg(i); + rational N = bv_size(bv_expr); + return amod(bv_expr, x, N); +} + +expr* bv2int_translator::smod(expr* bv_expr, unsigned i) { + expr* x = arg(i); + auto N = bv_size(bv_expr); + auto shift = N / 2; + rational r; + if (a.is_numeral(x, r)) + return a.mk_int(mod(r + shift, N)); + return amod(bv_expr, add(x, a.mk_int(shift)), N); +} + +expr_ref bv2int_translator::mul(expr* x, expr* y) { + expr_ref _x(x, m), _y(y, m); + if (a.is_zero(x)) + return _x; + if (a.is_zero(y)) + return _y; + if (a.is_one(x)) + return _y; + if (a.is_one(y)) + return _x; + rational v1, v2; + if (a.is_numeral(x, v1) && a.is_numeral(y, v2)) + return expr_ref(a.mk_int(v1 * v2), m); + _x = a.mk_mul(x, y); + return _x; +} + +expr_ref bv2int_translator::add(expr* x, expr* y) { + expr_ref _x(x, m), _y(y, m); + if (a.is_zero(x)) + return _y; + if (a.is_zero(y)) + return _x; + rational v1, v2; + if (a.is_numeral(x, v1) && a.is_numeral(y, v2)) + return expr_ref(a.mk_int(v1 + v2), m); + _x = a.mk_add(x, y); + return _x; +} + +/* +* Perform simplifications that are claimed sound when the bit-vector interpretations of +* mod/div always guard the mod and dividend to be non-zero. +* Potentially shady area is for arithmetic expressions created by int2bv. +* They will be guarded by a modulus which does not disappear. +*/ +expr* bv2int_translator::amod(expr* bv_expr, expr* x, rational const& N) { + rational v; + expr* r = nullptr, * c = nullptr, * t = nullptr, * e = nullptr; + if (m.is_ite(x, c, t, e)) + r = m.mk_ite(c, amod(bv_expr, t, N), amod(bv_expr, e, N)); + else if (a.is_idiv(x, t, e) && a.is_numeral(t, v) && 0 <= v && v < N && is_non_negative(bv_expr, e)) + r = x; + else if (a.is_mod(x, t, e) && a.is_numeral(t, v) && 0 <= v && v < N) + r = x; + else if (a.is_numeral(x, v)) + r = a.mk_int(mod(v, N)); + else if (is_bounded(x, N)) + r = x; + else + r = a.mk_mod(x, a.mk_int(N)); + return r; +} + +void bv2int_translator::translate_eq(expr* e) { + expr* x = nullptr, * y = nullptr; + VERIFY(m.is_eq(e, x, y)); + SASSERT(bv.is_bv(x)); + if (!is_translated(e)) { + ensure_translated(x); + ensure_translated(y); + m_args.reset(); + m_args.push_back(a.mk_sub(translated(x), translated(y))); + set_translated(e, m.mk_eq(umod(x, 0), a.mk_int(0))); + } + m_preds.push_back(e); + TRACE("bv", tout << mk_pp(e, m) << " " << mk_pp(translated(e), m) << "\n"); + ctx.push(push_back_vector(m_preds)); + +} diff --git a/src/ast/rewriter/bv2int_translator.h b/src/ast/rewriter/bv2int_translator.h new file mode 100644 index 00000000000..97b8b76b873 --- /dev/null +++ b/src/ast/rewriter/bv2int_translator.h @@ -0,0 +1,84 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + bv2int_translator + Utilities for translating bit-vector constraints into arithmetic. + +Author: + + Nikolaj Bjorner (nbjorner) 2024-10-27 + + +--*/ +#pragma once + +#include "util/trail.h" + +class bv2int_translator_trail { +public: + virtual ~bv2int_translator_trail() {} + virtual void push(push_back_vector const& c) = 0; + virtual void push(push_back_vector> const& c) = 0; + virtual void push_idx(set_vector_idx_trail const& c) = 0; +}; + +class bv2int_translator { + ast_manager& m; + bv2int_translator_trail& ctx; + bv_util bv; + arith_util a; + obj_map m_new_funs; + expr_ref_vector m_translate, m_args; + ast_ref_vector m_pinned; + ptr_vector m_bv2int, m_int2bv; + expr_ref_vector m_vars, m_preds; + bool m_is_plugin = true; + + void set_translated(expr* e, expr* r); + expr* arg(unsigned i) { return m_args.get(i); } + + expr* umod(expr* bv_expr, unsigned i); + expr* smod(expr* bv_expr, unsigned i); + bool is_bounded(expr* v, rational const& N); + bool is_non_negative(expr* bv_expr, expr* e); + expr_ref mul(expr* x, expr* y); + expr_ref add(expr* x, expr* y); + expr_ref if_eq(expr* n, unsigned k, expr* th, expr* el); + expr* amod(expr* bv_expr, expr* x, rational const& N); + rational bv_size(expr* bv_expr); + expr_ref mk_le(expr* a, expr* b); + expr_ref mk_lt(expr* a, expr* b); + expr_ref mk_ge(expr* a, expr* b) { return mk_le(b, a); } + expr_ref mk_gt(expr* a, expr* b) { return mk_lt(b, a); } + + + void translate_bv(app* e); + void translate_basic(app* e); + void translate_app(app* e); + void translate_quantifier(quantifier* q); + void translate_var(var* v); + + +public: + bv2int_translator(ast_manager& m, bv2int_translator_trail& ctx); + + void ensure_translated(expr* e); + + void translate_eq(expr* e); + + bool is_translated(expr* e) const { return !!m_translate.get(e->get_id(), nullptr); } + expr* translated(expr* e) const { expr* r = m_translate.get(e->get_id(), nullptr); SASSERT(r); return r; } + + void internalize_bv(app* e); + void translate_expr(expr* e); + + expr_ref_vector const& vars() const { return m_vars; } + expr_ref_vector const& preds() const { return m_preds; } + ptr_vector const& bv2int() const { return m_bv2int; } + ptr_vector const& int2bv() const { return m_int2bv; } + + void reset(bool is_plugin); + +}; diff --git a/src/ast/rewriter/bv_rewriter.h b/src/ast/rewriter/bv_rewriter.h index 73710f5c6e0..bcf797353ca 100644 --- a/src/ast/rewriter/bv_rewriter.h +++ b/src/ast/rewriter/bv_rewriter.h @@ -223,7 +223,7 @@ class bv_rewriter : public poly_rewriter { #define MK_BV_BINARY(OP) \ expr_ref OP(expr* a, expr* b) { \ - expr_ref result(m); \ + expr_ref result(m), _a(a, m), _b(b, m); \ if (BR_FAILED == OP(a, b, result)) \ result = m_util.OP(a, b); \ return result; \ @@ -238,6 +238,7 @@ class bv_rewriter : public poly_rewriter { MK_BV_BINARY(mk_bv_urem); MK_BV_BINARY(mk_ule); + MK_BV_BINARY(mk_sle); MK_BV_BINARY(mk_bv_add); MK_BV_BINARY(mk_bv_mul); MK_BV_BINARY(mk_bv_sub); @@ -250,6 +251,13 @@ class bv_rewriter : public poly_rewriter { return result; } + expr_ref mk_bv_neg(expr* a) { + expr_ref result(a, m); + if (BR_FAILED == mk_uminus(a, result)) + result = m_util.mk_bv_neg(a); + return result; + } + }; diff --git a/src/ast/simplifiers/elim_unconstrained.cpp b/src/ast/simplifiers/elim_unconstrained.cpp index a355f823acc..478eb00283c 100644 --- a/src/ast/simplifiers/elim_unconstrained.cpp +++ b/src/ast/simplifiers/elim_unconstrained.cpp @@ -42,6 +42,73 @@ proof production is work in progress. reconstruct_term should assign proof objects with nodes by applying monotonicity or reflexivity rules. +Maintain term nodes. +Each term node has a root. The root of the root is itself. +The root of a term node can be updated. +The parents of terms with same roots are combined. +The depth of a parent is always greater than the depth of a child. +The effective term of a node is reconstructed by taking the root and canonizing the children based on roots. +The reference count of a term is the number of parents it has. + +node: term -> node +dirty: node -> bool +root: node -> node +top: node -> bool +term: node -> term + +invariant: +- root(root(n)) = root(n) +- term(node(t)) = t + +parents: node -> node* +parents(root(node)) = union of parents of n : root(n) = root(node). +is_child(n, p) = term(root(n)) in args(term(root(p))) + +set_root: node -> node -> void +set_root(n, r) = + r = root(r) + n = root(n) + if r = n then return + parents(r) = parents(r) union parents(n) + root(n) := r, + top(r) := top(r) or top(n) + set all parents of class(r) to dirty, recursively + +reconstruct_term: node -> term +reconstruct_term(n) = + r = root(n) + if dirty(r) then + args = [reconstruct_term(c) | c in args(term(r))] + term(r) := term(r).f(args) + dirty(r) := false + return term(r) + +live : term -> bool +live(t) = + n = node(t) + is_root(n) & (top(n) or p in parents(n) : live(p)) + +Only live nodes require updates. + +eliminate: + while heap is not empty: + v = heap.erase_min() + n = node(v) + if n.parents.size() > 1 then + return + if !is_root(n) or !live(n) or n.parents.size() != 1 then + continue + p = n.parents[0] + if !is_child(n, p) or !is_root(p) then + continue + t = p.term + args = [reconstruct_term(node(arg)) | arg in t] + r = inverter(t.f, args) + if r then + set_root(n, r) + + + --*/ @@ -54,15 +121,17 @@ monotonicity or reflexivity rules. elim_unconstrained::elim_unconstrained(ast_manager& m, dependent_expr_state& fmls) : dependent_expr_simplifier(m, fmls), m_inverter(m), m_lt(*this), m_heap(1024, m_lt), m_trail(m), m_args(m) { std::function is_var = [&](expr* e) { - return is_uninterp_const(e) && !m_fmls.frozen(e) && is_node(e) && get_node(e).m_refcount <= 1; + return is_uninterp_const(e) && !m_fmls.frozen(e) && get_node(e).is_root() && get_node(e).num_parents() <= 1; }; m_inverter.set_is_var(is_var); } -bool elim_unconstrained::is_var_lt(int v1, int v2) const { - node const& n1 = get_node(v1); - node const& n2 = get_node(v2); - return n1.m_refcount < n2.m_refcount; +elim_unconstrained::~elim_unconstrained() { + reset_nodes(); +} + +bool elim_unconstrained::is_var_lt(int v1, int v2) const { + return get_node(v1).num_parents() < get_node(v2).num_parents(); } void elim_unconstrained::eliminate() { @@ -70,30 +139,29 @@ void elim_unconstrained::eliminate() { expr_ref r(m); int v = m_heap.erase_min(); node& n = get_node(v); - if (n.m_refcount == 0) + if (!n.is_root() || n.is_top()) + continue; + unsigned num_parents = n.num_parents(); + if (num_parents == 0) continue; - if (n.m_refcount > 1) + if (num_parents > 1) return; - if (n.m_parents.empty()) { - n.m_refcount = 0; + node& p = n.parent(); + if (!is_child(n, p) || !p.is_root()) continue; - } - expr* e = get_parent(v); - TRACE("elim_unconstrained", for (expr* p : n.m_parents) tout << "parent " << mk_bounded_pp(p, m) << " @ " << get_node(p).m_refcount << "\n";); - if (!e || !is_app(e) || !is_ground(e)) { - n.m_refcount = 0; - continue; - } - if (m_heap.contains(root(e))) { - TRACE("elim_unconstrained", tout << "already in heap " << mk_bounded_pp(e, m) << "\n"); + expr* e = p.term(); + if (!e || !is_app(e) || !is_ground(e)) continue; - } + + SASSERT(!m_heap.contains(p.term()->get_id())); + app* t = to_app(e); - TRACE("elim_unconstrained", tout << "eliminating " << mk_pp(t, m) << "\n";); + TRACE("elim_unconstrained", tout << "eliminating " << mk_bounded_pp(t, m) << "\n";); unsigned sz = m_args.size(); for (expr* arg : *to_app(t)) - m_args.push_back(reconstruct_term(get_node(arg))); + m_args.push_back(reconstruct_term(root(arg))); + expr_ref rr(m.mk_app(t->get_decl(), t->get_num_args(), m_args.data() + sz), m); bool inverted = m_inverter(t->get_decl(), t->get_num_args(), m_args.data() + sz, r); proof_ref pr(m); if (inverted && m_enable_proofs) { @@ -103,67 +171,91 @@ void elim_unconstrained::eliminate() { proof * pr = m.mk_apply_def(s, r, pr1); m_trail.push_back(pr); } - expr_ref rr(m.mk_app(t->get_decl(), t->get_num_args(), m_args.data() + sz), m); - n.m_refcount = 0; m_args.shrink(sz); - if (!inverted) { - TRACE("elim_unconstrained", tout << "not inverted " << mk_bounded_pp(e, m) << "\n"); + if (!inverted) continue; - } - IF_VERBOSE(11, verbose_stream() << "replace " << mk_pp(t, m) << " / " << rr << " -> " << r << "\n"); + IF_VERBOSE(4, verbose_stream() << "replace " << mk_bounded_pp(t, m) << " / " << mk_bounded_pp(rr, m) << " -> " << mk_bounded_pp(r, m) << "\n"); + - TRACE("elim_unconstrained", tout << mk_pp(t, m) << " / " << rr << " -> " << r << "\n"); + TRACE("elim_unconstrained", tout << mk_bounded_pp(t, m) << " / " << mk_bounded_pp(rr, m) << " -> " << mk_bounded_pp(r, m) << "\n"); SASSERT(r->get_sort() == t->get_sort()); m_stats.m_num_eliminated++; - m_trail.push_back(r); - SASSERT(r); - gc(e); - invalidate_parents(e); - freeze_rec(r); - - m_root.setx(r->get_id(), e->get_id(), UINT_MAX); - get_node(e).m_term = r; - get_node(e).m_proof = pr; - get_node(e).m_refcount++; - get_node(e).m_dirty = false; - IF_VERBOSE(11, verbose_stream() << "set " << &get_node(e) << " " << root(e) << " " << mk_bounded_pp(e, m) << " := " << mk_bounded_pp(r, m) << "\n"); - SASSERT(!m_heap.contains(root(e))); - if (is_uninterp_const(r)) - m_heap.insert(root(e)); + node& rn = root(r); + set_root(p, rn); + expr* rt = rn.term(); + SASSERT(!m_heap.contains(rt->get_id())); + if (is_uninterp_const(rt)) + m_heap.insert(rt->get_id()); else m_created_compound = true; - - IF_VERBOSE(11, verbose_stream() << mk_bounded_pp(get_node(v).m_orig, m) << " " << mk_bounded_pp(t, m) << " -> " << r << " " << get_node(e).m_refcount << "\n";); - } } -expr* elim_unconstrained::get_parent(unsigned n) const { - for (expr* p : get_node(n).m_parents) - if (get_node(p).m_refcount > 0 && get_node(p).m_term == get_node(p).m_orig) - return p; - return nullptr; +void elim_unconstrained::set_root(node& n, node& r) { + SASSERT(n.is_root()); + SASSERT(r.is_root()); + if (&n == &r) + return; + r.add_parents(n.parents()); + n.set_root(r); + for (auto p : n.parents()) + invalidate_parents(*p); } -void elim_unconstrained::invalidate_parents(expr* e) { - ptr_buffer todo; +void elim_unconstrained::invalidate_parents(node& n) { + ptr_buffer todo; + node* np = &n; do { - node& n = get_node(e); - if (!n.m_dirty && e == n.m_term) { - n.m_dirty = true; - for (expr* e : n.m_parents) - todo.push_back(e); + node& n = *np; + if (!n.is_dirty()) { + n.set_dirty(); + for (auto* p : n.parents()) + todo.push_back(p); } - e = nullptr; + np = nullptr; if (!todo.empty()) { - e = todo.back(); + np = todo.back(); todo.pop_back(); } } - while (e); + while (np); +} + +bool elim_unconstrained::is_child(node const& ch, node const& p) { + SASSERT(ch.is_root()); + return is_app(p.term()) && any_of(*to_app(p.term()), [&](expr* arg) { return &root(arg) == &ch; }); +} + +elim_unconstrained::node& elim_unconstrained::get_node(expr* t) { + unsigned id = t->get_id(); + if (m_nodes.size() <= id) + m_nodes.resize(id + 1, nullptr); + node* n = m_nodes[id]; + if (!n) { + n = alloc(node, m, t); + m_nodes[id] = n; + if (is_app(t)) { + for (auto arg : *to_app(t)) { + node& ch = get_node(arg); + SASSERT(ch.is_root()); + ch.add_parent(*n); + } + } + else if (is_quantifier(t)) { + node& ch = get_node(to_quantifier(t)->get_expr()); + SASSERT(ch.is_root()); + ch.add_parent(*n); + } + } + return *n; } +void elim_unconstrained::reset_nodes() { + for (node* n : m_nodes) + dealloc(n); + m_nodes.reset(); +} /** * initialize node structure @@ -182,201 +274,95 @@ void elim_unconstrained::init_nodes() { m_enable_proofs = true; } - m_trail.append(terms); m_heap.reset(); - m_root.reset(); - m_nodes.reset(); + reset_nodes(); // initialize nodes for terms in the original goal - init_terms(terms); - - // top-level terms have reference count > 0 - for (expr* e : terms) - inc_ref(e); - - m_inverter.set_produce_proofs(m_enable_proofs); - -} - -/** -* Create nodes for all terms in the goal -*/ -void elim_unconstrained::init_terms(expr_ref_vector const& terms) { unsigned max_id = 0; for (expr* e : subterms::all(terms)) max_id = std::max(max_id, e->get_id()); m_nodes.reserve(max_id + 1); m_heap.reserve(max_id + 1); - m_root.reserve(max_id + 1, UINT_MAX); for (expr* e : subterms_postorder::all(terms)) { - m_root.setx(e->get_id(), e->get_id(), UINT_MAX); node& n = get_node(e); - if (n.m_term) - continue; - n.m_orig = e; - n.m_term = e; - n.m_refcount = 0; - + SASSERT(n.is_root()); if (is_uninterp_const(e)) - m_heap.insert(root(e)); - if (is_quantifier(e)) { - expr* body = to_quantifier(e)->get_expr(); - get_node(body).m_parents.push_back(e); - inc_ref(body); - } - else if (is_app(e)) { - for (expr* arg : *to_app(e)) { - get_node(arg).m_parents.push_back(e); - inc_ref(arg); - } - } + m_heap.insert(e->get_id()); } -} - -void elim_unconstrained::freeze_rec(expr* r) { - expr_ref_vector children(m); - if (is_quantifier(r)) - children.push_back(to_quantifier(r)->get_expr()); - else if (is_app(r)) - children.append(to_app(r)->get_num_args(), to_app(r)->get_args()); - else - return; - if (children.empty()) - return; - for (expr* t : subterms::all(children)) - freeze(t); -} -void elim_unconstrained::freeze(expr* t) { - if (!is_uninterp_const(t)) - return; - if (m_nodes.size() <= t->get_id()) - return; - if (m_nodes.size() <= root(t)) - return; - node& n = get_node(t); - if (!n.m_term) - return; - if (m_heap.contains(root(t))) { - n.m_refcount = UINT_MAX / 2; - m_heap.increased(root(t)); - } -} + // mark top level terms + for (expr* e : terms) + get_node(e).set_top(); -void elim_unconstrained::gc(expr* t) { - ptr_vector todo; - todo.push_back(t); - while (!todo.empty()) { - t = todo.back(); - todo.pop_back(); + m_inverter.set_produce_proofs(m_enable_proofs); - node& n = get_node(t); - if (n.m_refcount == 0) - continue; - if (n.m_term && !is_node(n.m_term)) - continue; - - dec_ref(t); - if (n.m_refcount != 0) - continue; - if (n.m_term) - t = n.m_term; - if (is_app(t)) { - for (expr* arg : *to_app(t)) - todo.push_back(arg); - } - else if (is_quantifier(t)) - todo.push_back(to_quantifier(t)->get_expr()); - } } - -expr_ref elim_unconstrained::reconstruct_term(node& n0) { - expr* t = n0.m_term; - if (!n0.m_dirty) - return expr_ref(t, m); - if (!is_node(t)) - return expr_ref(t, m); - ptr_buffer todo; - todo.push_back(t); +expr* elim_unconstrained::reconstruct_term(node& n) { + SASSERT(n.is_root()); + if (!n.is_dirty()) + return n.term(); + ptr_buffer todo; + todo.push_back(&n); + expr_ref new_t(m); while (!todo.empty()) { - t = todo.back(); - if (!is_node(t)) { - UNREACHABLE(); + node* np = todo.back(); + if (!np->is_dirty()) { + todo.pop_back(); + continue; } - node& n = get_node(t); + SASSERT(np->is_root()); + auto t = np->term(); unsigned sz0 = todo.size(); - if (is_app(t)) { - if (n.m_term != t) { - n.m_dirty = false; - todo.pop_back(); - continue; + if (is_app(t)) { + for (expr* arg : *to_app(t)) { + node& r = root(arg); + if (r.is_dirty()) + todo.push_back(&r); } - for (expr* arg : *to_app(t)) - if (get_node(arg).m_dirty || !get_node(arg).m_term) - todo.push_back(arg); if (todo.size() != sz0) continue; unsigned sz = m_args.size(); - for (expr* arg : *to_app(t)) - m_args.push_back(get_node(arg).m_term); - n.m_term = m.mk_app(to_app(t)->get_decl(), to_app(t)->get_num_args(), m_args.data() + sz); + for (expr* arg : *to_app(t)) + m_args.push_back(root(arg).term()); + new_t = m.mk_app(to_app(t)->get_decl(), to_app(t)->get_num_args(), m_args.data() + sz); m_args.shrink(sz); } else if (is_quantifier(t)) { expr* body = to_quantifier(t)->get_expr(); - node& n2 = get_node(body); - if (n2.m_dirty || !n2.m_term) { - todo.push_back(body); + node& n2 = root(body); + if (n2.is_dirty()) { + todo.push_back(&n2); continue; } - n.m_term = m.update_quantifier(to_quantifier(t), n2.m_term); + new_t = m.update_quantifier(to_quantifier(t), n2.term()); } - m_trail.push_back(n.m_term); - m_root.setx(n.m_term->get_id(), n.m_term->get_id(), UINT_MAX); + else + new_t = t; + node& new_n = get_node(new_t); + set_root(*np, new_n); + np->set_clean(); todo.pop_back(); - n.m_dirty = false; } - return expr_ref(n0.m_term, m); + return n.root().term(); } /** * walk nodes starting from lowest depth and reconstruct their normalized forms. */ void elim_unconstrained::reconstruct_terms() { - expr_ref_vector terms(m); - for (unsigned i : indices()) - terms.push_back(m_fmls[i].fml()); + ptr_vector nodes; + for (node* n : m_nodes) + if (n && n->is_root()) + nodes.push_back(n); - for (expr* e : subterms_postorder::all(terms)) { - node& n = get_node(e); - expr* t = n.m_term; - if (t != n.m_orig) - continue; - if (is_app(t)) { - bool change = false; - m_args.reset(); - for (expr* arg : *to_app(t)) { - node& n2 = get_node(arg); - m_args.push_back(n2.m_term); - change |= n2.m_term != n2.m_orig; - } - if (change) { - n.m_term = m.mk_app(to_app(t)->get_decl(), m_args); - m_trail.push_back(n.m_term); - } - } - else if (is_quantifier(t)) { - node& n2 = get_node(to_quantifier(t)->get_expr()); - if (n2.m_term != n2.m_orig) { - n.m_term = m.update_quantifier(to_quantifier(t), n2.m_term); - m_trail.push_back(n.m_term); - } - } - } + std::stable_sort(nodes.begin(), nodes.end(), [&](node* a, node* b) { return get_depth(a->term()) < get_depth(b->term()); }); + + for (node* n : nodes) + reconstruct_term(*n); } @@ -384,12 +370,11 @@ void elim_unconstrained::assert_normalized(vector& old_fmls) { for (unsigned i : indices()) { auto [f, p, d] = m_fmls[i](); - node& n = get_node(f); - expr* g = n.m_term; + node& n = root(f); + expr* g = n.term(); if (f == g) continue; old_fmls.push_back(m_fmls[i]); - IF_VERBOSE(11, verbose_stream() << mk_bounded_pp(f, m, 3) << " -> " << mk_bounded_pp(g, m, 3) << "\n"); TRACE("elim_unconstrained", tout << mk_bounded_pp(f, m) << " -> " << mk_bounded_pp(g, m) << "\n"); m_fmls.update(i, dependent_expr(m, g, nullptr, d)); } @@ -441,6 +426,6 @@ void elim_unconstrained::reduce() { vector old_fmls; assert_normalized(old_fmls); update_model_trail(*mc, old_fmls); - mc->reset(); + mc->reset(); } } diff --git a/src/ast/simplifiers/elim_unconstrained.h b/src/ast/simplifiers/elim_unconstrained.h index 5dced90d04c..27f929453f2 100644 --- a/src/ast/simplifiers/elim_unconstrained.h +++ b/src/ast/simplifiers/elim_unconstrained.h @@ -24,14 +24,49 @@ class elim_unconstrained : public dependent_expr_simplifier { friend class seq_simplifier; - struct node { - unsigned m_refcount = 0; - expr* m_term = nullptr; - expr* m_orig = nullptr; - proof* m_proof = nullptr; + class node { + expr_ref m_term; + proof_ref m_proof; bool m_dirty = false; - ptr_vector m_parents; + ptr_vector m_parents; + node* m_root = nullptr; + bool m_top = false; + public: + + node(ast_manager& m, expr* t) : + m_term(t, m), + m_proof(m), + m_root(this) { + } + + void set_top() { m_top = true; } + bool is_top() const { return m_top; } + + void set_dirty() { m_dirty = true; } + void set_clean() { m_dirty = false; } + bool is_dirty() const { return m_dirty; } + + unsigned num_parents() const { return m_parents.size(); } + ptr_vector const& parents() const { return m_parents; } + void add_parent(node& p) { m_parents.push_back(&p); } + void add_parents(ptr_vector const& ps) { m_parents.append(ps); } + node& parent() const { SASSERT(num_parents() == 1); return *m_parents[0]; } + + bool is_root() const { return m_root == this; } + node& root() { node* r = m_root; while (!r->is_root()) r = r->m_root; return *r; } + node const& root() const { node* r = m_root; while (!r->is_root()) r = r->m_root; return *r; } + void set_root(node& r) { + SASSERT(r.is_root()); + m_root = &r; + SASSERT(term()->get_sort() == r.term()->get_sort()); + } + + void set_proof(proof* p) { m_proof = p; } + proof* get_proof() const { return m_proof; } + + expr* term() const { return m_term; } }; + struct var_lt { elim_unconstrained& s; var_lt(elim_unconstrained& s) : s(s) {} @@ -39,50 +74,44 @@ class elim_unconstrained : public dependent_expr_simplifier { return s.is_var_lt(v1, v2); } }; + struct stats { unsigned m_num_eliminated = 0; void reset() { m_num_eliminated = 0; } }; expr_inverter m_inverter; - vector m_nodes; + ptr_vector m_nodes; var_lt m_lt; heap m_heap; expr_ref_vector m_trail; expr_ref_vector m_args; stats m_stats; - unsigned_vector m_root; bool m_created_compound = false; bool m_enable_proofs = false; bool is_var_lt(int v1, int v2) const; - bool is_node(unsigned n) const { return m_nodes.size() > n; } - bool is_node(expr* t) const { return is_node(t->get_id()); } - node& get_node(unsigned n) { return m_nodes[n]; } - node const& get_node(unsigned n) const { return m_nodes[n]; } - node& get_node(expr* t) { return m_nodes[root(t)]; } - unsigned root(expr* t) const { return m_root[t->get_id()]; } - node const& get_node(expr* t) const { return m_nodes[root(t)]; } - unsigned get_refcount(expr* t) const { return get_node(t).m_refcount; } - void inc_ref(expr* t) { ++get_node(t).m_refcount; if (is_uninterp_const(t)) m_heap.increased(root(t)); } - void dec_ref(expr* t) { --get_node(t).m_refcount; if (is_uninterp_const(t)) m_heap.decreased(root(t)); } - void freeze(expr* t); - void freeze_rec(expr* r); - void gc(expr* t); - expr* get_parent(unsigned n) const; - void init_terms(expr_ref_vector const& terms); + node& get_node(unsigned n) const { return *m_nodes[n]; } + node& get_node(expr* t); + node& root(expr* t) { return get_node(t).root(); } + void set_root(node& n, node& r); + void invalidate_parents(node& n); + bool is_child(node const& ch, node const& p); + void init_nodes(); + void reset_nodes(); void eliminate(); void reconstruct_terms(); - expr_ref reconstruct_term(node& n); + expr* reconstruct_term(node& n); void assert_normalized(vector& old_fmls); void update_model_trail(generic_model_converter& mc, vector const& old_fmls); - void invalidate_parents(expr* e); - - + + public: elim_unconstrained(ast_manager& m, dependent_expr_state& fmls); + ~elim_unconstrained() override; + char const* name() const override { return "elim-unconstrained"; } void reduce() override; diff --git a/src/ast/sls/CMakeLists.txt b/src/ast/sls/CMakeLists.txt index 24eaec4dcac..a63fc099463 100644 --- a/src/ast/sls/CMakeLists.txt +++ b/src/ast/sls/CMakeLists.txt @@ -1,14 +1,25 @@ z3_add_component(ast_sls SOURCES bvsls_opt_engine.cpp - bv_sls.cpp - bv_sls_eval.cpp - bv_sls_fixed.cpp - bv_sls_terms.cpp - sls_engine.cpp - sls_valuation.cpp + sat_ddfw.cpp + sls_arith_base.cpp + sls_arith_plugin.cpp + sls_array_plugin.cpp + sls_basic_plugin.cpp + sls_bv_engine.cpp + sls_bv_eval.cpp + sls_bv_fixed.cpp + sls_bv_plugin.cpp + sls_bv_terms.cpp + sls_bv_valuation.cpp + sls_context.cpp + sls_datatype_plugin.cpp + sls_euf_plugin.cpp + sls_smt_plugin.cpp + sls_smt_solver.cpp COMPONENT_DEPENDENCIES ast + euf converters normal_forms ) diff --git a/src/ast/sls/bv_sls.cpp b/src/ast/sls/bv_sls.cpp deleted file mode 100644 index f1b2a9f4f27..00000000000 --- a/src/ast/sls/bv_sls.cpp +++ /dev/null @@ -1,364 +0,0 @@ -/*++ -Copyright (c) 2024 Microsoft Corporation - -Module Name: - - bv_sls.cpp - -Abstract: - - A Stochastic Local Search (SLS) engine - Uses invertibility conditions, - interval annotations - don't care annotations - -Author: - - Nikolaj Bjorner (nbjorner) 2024-02-07 - ---*/ - -#include "ast/sls/bv_sls.h" -#include "ast/ast_pp.h" -#include "ast/ast_ll_pp.h" -#include "params/sls_params.hpp" - -namespace bv { - - sls::sls(ast_manager& m, params_ref const& p): - m(m), - bv(m), - m_terms(m), - m_eval(m), - m_engine(m, p) - { - updt_params(p); - } - - void sls::init() { - m_terms.init(); - } - - void sls::init_eval(std::function& eval) { - m_eval.init_eval(m_terms.assertions(), eval); - m_eval.tighten_range(m_terms.assertions()); - init_repair(); - } - - void sls::init_repair() { - m_repair_down = UINT_MAX; - m_repair_up.reset(); - m_repair_roots.reset(); - for (auto* e : m_terms.assertions()) { - if (!m_eval.bval0(e)) { - m_eval.set(e, true); - m_repair_roots.insert(e->get_id()); - } - } - for (auto* t : m_terms.terms()) { - if (t && !m_eval.re_eval_is_correct(t)) - m_repair_roots.insert(t->get_id()); - } - } - - - void sls::set_model() { - if (!m_set_model) - return; - if (m_repair_roots.size() >= m_min_repair_size) - return; - m_min_repair_size = m_repair_roots.size(); - IF_VERBOSE(2, verbose_stream() << "(sls-update-model :num-unsat " << m_min_repair_size << ")\n"); - m_set_model(*get_model()); - } - - void sls::init_repair_goal(app* t) { - m_eval.init_eval(t); - } - - void sls::init_repair_candidates() { - m_to_repair.reset(); - ptr_vector todo; - expr_fast_mark1 mark; - for (auto index : m_repair_roots) - todo.push_back(m_terms.term(index)); - for (unsigned i = 0; i < todo.size(); ++i) { - expr* e = todo[i]; - if (mark.is_marked(e)) - continue; - mark.mark(e); - if (!is_app(e)) - continue; - for (expr* arg : *to_app(e)) - todo.push_back(arg); - - if (is_uninterp_const(e)) - m_to_repair.insert(e->get_id()); - - } - } - - - void sls::reinit_eval() { - init_repair_candidates(); - - if (m_to_repair.empty()) - return; - - // refresh the best model so far to a callback - set_model(); - - // add fresh units, if any - bool new_assertion = false; - while (m_get_unit) { - auto e = m_get_unit(); - if (!e) - break; - new_assertion = true; - assert_expr(e); - } - if (new_assertion) - init(); - - std::function eval = [&](expr* e, unsigned i) { - unsigned id = e->get_id(); - bool keep = !m_to_repair.contains(id); - if (m.is_bool(e)) { - if (m_eval.is_fixed0(e) || keep) - return m_eval.bval0(e); - if (m_engine_init) { - auto const& z = m_engine.get_value(e); - return rational(z).get_bit(0); - } - } - else if (bv.is_bv(e)) { - auto& w = m_eval.wval(e); - if (w.fixed.get(i) || keep) - return w.get_bit(i); - if (m_engine_init) { - auto const& z = m_engine.get_value(e); - return rational(z).get_bit(i); - } - } - - return m_rand() % 2 == 0; - }; - m_eval.init_eval(m_terms.assertions(), eval); - init_repair(); - // m_engine_init = false; - } - - std::pair sls::next_to_repair() { - app* e = nullptr; - if (m_repair_down != UINT_MAX) { - e = m_terms.term(m_repair_down); - m_repair_down = UINT_MAX; - return { true, e }; - } - - if (!m_repair_up.empty()) { - unsigned index = m_repair_up.elem_at(m_rand(m_repair_up.size())); - m_repair_up.remove(index); - e = m_terms.term(index); - return { false, e }; - } - - while (!m_repair_roots.empty()) { - unsigned index = m_repair_roots.elem_at(m_rand(m_repair_roots.size())); - e = m_terms.term(index); - if (m_terms.is_assertion(e) && !m_eval.bval1(e)) { - SASSERT(m_eval.bval0(e)); - return { true, e }; - } - if (!m_eval.re_eval_is_correct(e)) { - init_repair_goal(e); - return { true, e }; - } - m_repair_roots.remove(index); - } - - return { false, nullptr }; - } - - lbool sls::search1() { - // init and init_eval were invoked - unsigned n = 0; - for (; n < m_config.m_max_repairs && m.inc(); ++n) { - auto [down, e] = next_to_repair(); - if (!e) - return l_true; - - IF_VERBOSE(20, trace_repair(down, e)); - - ++m_stats.m_moves; - if (down) - try_repair_down(e); - else - try_repair_up(e); - } - return l_undef; - } - - lbool sls::search2() { - lbool res = l_undef; - if (m_stats.m_restarts == 0) - res = m_engine(), - m_engine_init = true; - else if (m_stats.m_restarts % 1000 == 0) - res = m_engine.search_loop(), - m_engine_init = true; - if (res != l_undef) - m_engine_model = true; - return res; - } - - - lbool sls::operator()() { - lbool res = l_undef; - m_stats.reset(); - m_stats.m_restarts = 0; - m_engine_model = false; - m_engine_init = false; - - do { - res = search1(); - if (res != l_undef) - break; - trace(); - //res = search2(); - if (res != l_undef) - break; - reinit_eval(); - } - while (m.inc() && m_stats.m_restarts++ < m_config.m_max_restarts); - - return res; - } - - void sls::try_repair_down(app* e) { - unsigned n = e->get_num_args(); - if (n == 0) { - m_eval.commit_eval(e); - for (auto p : m_terms.parents(e)) - m_repair_up.insert(p->get_id()); - - return; - } - - if (n == 2) { - auto d1 = get_depth(e->get_arg(0)); - auto d2 = get_depth(e->get_arg(1)); - unsigned s = m_rand(d1 + d2 + 2); - if (s <= d1 && m_eval.try_repair(e, 0)) { - set_repair_down(e->get_arg(0)); - return; - } - if (m_eval.try_repair(e, 1)) { - set_repair_down(e->get_arg(1)); - return; - } - if (m_eval.try_repair(e, 0)) { - set_repair_down(e->get_arg(0)); - return; - } - } - else { - unsigned s = m_rand(n); - for (unsigned i = 0; i < n; ++i) { - auto j = (i + s) % n; - if (m_eval.try_repair(e, j)) { - set_repair_down(e->get_arg(j)); - return; - } - } - } - IF_VERBOSE(3, verbose_stream() << "init-repair " << mk_bounded_pp(e, m) << "\n"); - // repair was not successful, so reset the state to find a different way to repair - init_repair(); - } - - void sls::try_repair_up(app* e) { - - if (m_terms.is_assertion(e)) - m_repair_roots.insert(e->get_id()); - else if (!m_eval.repair_up(e)) { - IF_VERBOSE(2, verbose_stream() << "repair-up "; trace_repair(true, e)); - if (m_rand(10) != 0) { - m_eval.set_random(e); - m_repair_roots.insert(e->get_id()); - } - else - init_repair(); - } - else { - if (!m_eval.eval_is_correct(e)) { - verbose_stream() << "incorrect eval #" << e->get_id() << " " << mk_bounded_pp(e, m) << "\n"; - } - SASSERT(m_eval.eval_is_correct(e)); - for (auto p : m_terms.parents(e)) - m_repair_up.insert(p->get_id()); - } - } - - - model_ref sls::get_model() { - if (m_engine_model) - return m_engine.get_model(); - - model_ref mdl = alloc(model, m); - auto& terms = m_eval.sort_assertions(m_terms.assertions()); - for (expr* e : terms) { - if (!is_uninterp_const(e)) - continue; - auto f = to_app(e)->get_decl(); - auto v = m_eval.get_value(to_app(e)); - if (v) - mdl->register_decl(f, v); - } - terms.reset(); - return mdl; - } - - std::ostream& sls::display(std::ostream& out) { - auto& terms = m_eval.sort_assertions(m_terms.assertions()); - for (expr* e : terms) { - out << e->get_id() << ": " << mk_bounded_pp(e, m, 1) << " "; - if (m_eval.is_fixed0(e)) - out << "f "; - if (m_repair_up.contains(e->get_id())) - out << "u "; - if (m_repair_roots.contains(e->get_id())) - out << "r "; - m_eval.display_value(out, e); - out << "\n"; - } - terms.reset(); - return out; - } - - void sls::updt_params(params_ref const& _p) { - sls_params p(_p); - m_config.m_max_restarts = p.max_restarts(); - m_config.m_max_repairs = p.max_repairs(); - m_rand.set_seed(p.random_seed()); - m_terms.updt_params(_p); - params_ref q = _p; - q.set_uint("max_restarts", 10); - m_engine.updt_params(q); - } - - std::ostream& sls::trace_repair(bool down, expr* e) { - verbose_stream() << (down ? "d #" : "u #") - << e->get_id() << ": " - << mk_bounded_pp(e, m, 1) << " "; - m_eval.display_value(verbose_stream(), e) << "\n"; - return verbose_stream(); - } - - void sls::trace() { - IF_VERBOSE(2, verbose_stream() - << "(bvsls :restarts " << m_stats.m_restarts - << " :repair-up " << m_repair_up.size() - << " :repair-roots " << m_repair_roots.size() << ")\n"); - } -} diff --git a/src/ast/sls/bv_sls.h b/src/ast/sls/bv_sls.h deleted file mode 100644 index 987cebcdbfc..00000000000 --- a/src/ast/sls/bv_sls.h +++ /dev/null @@ -1,129 +0,0 @@ -/*++ -Copyright (c) 2024 Microsoft Corporation - -Module Name: - - bv_sls.h - -Abstract: - - A Stochastic Local Search (SLS) engine - -Author: - - Nikolaj Bjorner (nbjorner) 2024-02-07 - ---*/ -#pragma once - -#include "util/lbool.h" -#include "util/params.h" -#include "util/scoped_ptr_vector.h" -#include "util/uint_set.h" -#include "ast/ast.h" -#include "ast/sls/sls_stats.h" -#include "ast/sls/sls_powers.h" -#include "ast/sls/sls_valuation.h" -#include "ast/sls/bv_sls_terms.h" -#include "ast/sls/bv_sls_eval.h" -#include "ast/sls/sls_engine.h" -#include "ast/bv_decl_plugin.h" -#include "model/model.h" - -namespace bv { - - - class sls { - - struct config { - unsigned m_max_restarts = 1000; - unsigned m_max_repairs = 1000; - }; - - ast_manager& m; - bv_util bv; - sls_terms m_terms; - sls_eval m_eval; - sls_stats m_stats; - indexed_uint_set m_repair_up, m_repair_roots; - unsigned m_repair_down = UINT_MAX; - ptr_vector m_todo; - random_gen m_rand; - config m_config; - sls_engine m_engine; - bool m_engine_model = false; - bool m_engine_init = false; - std::function m_get_unit; - std::function m_set_model; - unsigned m_min_repair_size = UINT_MAX; - - std::pair next_to_repair(); - - void init_repair_goal(app* e); - void set_model(); - void try_repair_down(app* e); - void try_repair_up(app* e); - void set_repair_down(expr* e) { m_repair_down = e->get_id(); } - - lbool search1(); - lbool search2(); - void reinit_eval(); - void init_repair(); - void trace(); - std::ostream& trace_repair(bool down, expr* e); - - indexed_uint_set m_to_repair; - void init_repair_candidates(); - - public: - sls(ast_manager& m, params_ref const& p); - - /** - * Add constraints - */ - void assert_expr(expr* e) { m_terms.assert_expr(e); m_engine.assert_expr(e); } - - /* - * Invoke init after all expressions are asserted. - */ - void init(); - - /** - * Invoke init_eval to initialize, or re-initialize, values of - * uninterpreted constants. - */ - void init_eval(std::function& eval); - - /** - * add callback to retrieve new units - */ - void init_unit(std::function get_unit) { m_get_unit = get_unit; } - - /** - * Add callback to set model - */ - void set_model(std::function f) { m_set_model = f; } - - /** - * Run (bounded) local search to find feasible assignments. - */ - lbool operator()(); - - void updt_params(params_ref const& p); - void collect_statistics(statistics& st) const { m_stats.collect_statistics(st); m_engine.collect_statistics(st); } - void reset_statistics() { m_stats.reset(); m_engine.reset_statistics(); } - - unsigned get_num_moves() { return m_stats.m_moves + m_engine.get_stats().m_moves; } - - std::ostream& display(std::ostream& out); - - /** - * Retrieve valuation - */ - sls_valuation const& wval(expr* e) const { return m_eval.wval(e); } - - model_ref get_model(); - - void cancel() { m.limit().cancel(); } - }; -} diff --git a/src/ast/sls/bv_sls_terms.cpp b/src/ast/sls/bv_sls_terms.cpp deleted file mode 100644 index ed1bf2396bc..00000000000 --- a/src/ast/sls/bv_sls_terms.cpp +++ /dev/null @@ -1,229 +0,0 @@ -/*++ -Copyright (c) 2024 Microsoft Corporation - -Module Name: - - bv_sls.cpp - -Abstract: - - A Stochastic Local Search (SLS) engine - Uses invertibility conditions, - interval annotations - don't care annotations - -Author: - - Nikolaj Bjorner (nbjorner) 2024-02-07 - ---*/ - -#include "ast/ast_ll_pp.h" -#include "ast/sls/bv_sls.h" -#include "ast/rewriter/th_rewriter.h" - -namespace bv { - - sls_terms::sls_terms(ast_manager& m): - m(m), - bv(m), - m_rewriter(m), - m_assertions(m), - m_pinned(m), - m_translated(m), - m_terms(m){} - - - void sls_terms::assert_expr(expr* e) { - m_assertions.push_back(ensure_binary(e)); - } - - expr* sls_terms::ensure_binary(expr* e) { - expr* top = e; - m_pinned.push_back(e); - m_todo.push_back(e); - { - expr_fast_mark1 mark; - for (unsigned i = 0; i < m_todo.size(); ++i) { - expr* e = m_todo[i]; - if (!is_app(e)) - continue; - if (m_translated.get(e->get_id(), nullptr)) - continue; - if (mark.is_marked(e)) - continue; - mark.mark(e); - for (auto arg : *to_app(e)) - m_todo.push_back(arg); - } - } - std::stable_sort(m_todo.begin(), m_todo.end(), [&](expr* a, expr* b) { return get_depth(a) < get_depth(b); }); - for (expr* e : m_todo) - ensure_binary_core(e); - m_todo.reset(); - return m_translated.get(top->get_id()); - } - - void sls_terms::ensure_binary_core(expr* e) { - if (m_translated.get(e->get_id(), nullptr)) - return; - - app* a = to_app(e); - auto arg = [&](unsigned i) { - return m_translated.get(a->get_arg(i)->get_id()); - }; - unsigned num_args = a->get_num_args(); - expr_ref r(m); -#define FOLD_OP(oper) \ - r = arg(0); \ - for (unsigned i = 1; i < num_args; ++i)\ - r = oper(r, arg(i)); \ - - if (m.is_and(e)) { - FOLD_OP(m.mk_and); - } - else if (m.is_or(e)) { - FOLD_OP(m.mk_or); - } - else if (m.is_xor(e)) { - FOLD_OP(m.mk_xor); - } - else if (bv.is_bv_and(e)) { - FOLD_OP(bv.mk_bv_and); - } - else if (bv.is_bv_or(e)) { - FOLD_OP(bv.mk_bv_or); - } - else if (bv.is_bv_xor(e)) { - FOLD_OP(bv.mk_bv_xor); - } - else if (bv.is_bv_add(e)) { - FOLD_OP(bv.mk_bv_add); - } - else if (bv.is_bv_mul(e)) { - FOLD_OP(bv.mk_bv_mul); - } - else if (bv.is_concat(e)) { - FOLD_OP(bv.mk_concat); - } - else if (m.is_distinct(e)) { - expr_ref_vector es(m); - for (unsigned i = 0; i < num_args; ++i) - for (unsigned j = i + 1; j < num_args; ++j) - es.push_back(m.mk_not(m.mk_eq(arg(i), arg(j)))); - r = m.mk_and(es); - } - else if (bv.is_bv_sdiv(e) || bv.is_bv_sdiv0(e) || bv.is_bv_sdivi(e)) { - r = mk_sdiv(arg(0), arg(1)); - } - else if (bv.is_bv_smod(e) || bv.is_bv_smod0(e) || bv.is_bv_smodi(e)) { - r = mk_smod(arg(0), arg(1)); - } - else if (bv.is_bv_srem(e) || bv.is_bv_srem0(e) || bv.is_bv_sremi(e)) { - r = mk_srem(arg(0), arg(1)); - } - else { - for (unsigned i = 0; i < num_args; ++i) - m_args.push_back(arg(i)); - r = m.mk_app(a->get_decl(), num_args, m_args.data()); - m_args.reset(); - } - m_translated.setx(e->get_id(), r); - } - - expr_ref sls_terms::mk_sdiv(expr* x, expr* y) { - // d = udiv(abs(x), abs(y)) - // y = 0, x >= 0 -> -1 - // y = 0, x < 0 -> 1 - // x = 0, y != 0 -> 0 - // x > 0, y < 0 -> -d - // x < 0, y > 0 -> -d - // x > 0, y > 0 -> d - // x < 0, y < 0 -> d - unsigned sz = bv.get_bv_size(x); - rational N = rational::power_of_two(sz); - expr_ref z(bv.mk_zero(sz), m); - expr* signx = bv.mk_ule(bv.mk_numeral(N / 2, sz), x); - expr* signy = bv.mk_ule(bv.mk_numeral(N / 2, sz), y); - expr* absx = m.mk_ite(signx, bv.mk_bv_neg(x), x); - expr* absy = m.mk_ite(signy, bv.mk_bv_neg(y), y); - expr* d = bv.mk_bv_udiv(absx, absy); - expr_ref r(m.mk_ite(m.mk_eq(signx, signy), d, bv.mk_bv_neg(d)), m); - r = m.mk_ite(m.mk_eq(z, y), - m.mk_ite(signx, bv.mk_one(sz), bv.mk_numeral(N - 1, sz)), - m.mk_ite(m.mk_eq(x, z), z, r)); - m_rewriter(r); - return r; - } - - expr_ref sls_terms::mk_smod(expr* x, expr* y) { - // u := umod(abs(x), abs(y)) - // u = 0 -> 0 - // y = 0 -> x - // x < 0, y < 0 -> -u - // x < 0, y >= 0 -> y - u - // x >= 0, y < 0 -> y + u - // x >= 0, y >= 0 -> u - unsigned sz = bv.get_bv_size(x); - expr_ref z(bv.mk_zero(sz), m); - expr_ref abs_x(m.mk_ite(bv.mk_sle(z, x), x, bv.mk_bv_neg(x)), m); - expr_ref abs_y(m.mk_ite(bv.mk_sle(z, y), y, bv.mk_bv_neg(y)), m); - expr_ref u(bv.mk_bv_urem(abs_x, abs_y), m); - expr_ref r(m); - r = m.mk_ite(m.mk_eq(u, z), z, - m.mk_ite(m.mk_eq(y, z), x, - m.mk_ite(m.mk_and(bv.mk_sle(z, x), bv.mk_sle(z, x)), u, - m.mk_ite(bv.mk_sle(z, x), bv.mk_bv_add(y, u), - m.mk_ite(bv.mk_sle(z, y), bv.mk_bv_sub(y, u), bv.mk_bv_neg(u)))))); - m_rewriter(r); - return r; - } - - expr_ref sls_terms::mk_srem(expr* x, expr* y) { - // y = 0 -> x - // else x - sdiv(x, y) * y - expr_ref r(m); - r = m.mk_ite(m.mk_eq(y, bv.mk_zero(bv.get_bv_size(x))), - x, bv.mk_bv_sub(x, bv.mk_bv_mul(y, mk_sdiv(x, y)))); - m_rewriter(r); - return r; - } - - - void sls_terms::init() { - // populate terms - expr_fast_mark1 mark; - for (expr* e : m_assertions) - m_todo.push_back(e); - while (!m_todo.empty()) { - expr* e = m_todo.back(); - m_todo.pop_back(); - if (mark.is_marked(e) || !is_app(e)) - continue; - mark.mark(e); - m_terms.setx(e->get_id(), to_app(e)); - for (expr* arg : *to_app(e)) - m_todo.push_back(arg); - } - // populate parents - m_parents.reset(); - m_parents.reserve(m_terms.size()); - for (expr* e : m_terms) { - if (!e || !is_app(e)) - continue; - for (expr* arg : *to_app(e)) - m_parents[arg->get_id()].push_back(e); - } - m_assertion_set.reset(); - for (auto a : m_assertions) - m_assertion_set.insert(a->get_id()); - } - - void sls_terms::updt_params(params_ref const& p) { - params_ref q = p; - q.set_bool("flat", false); - m_rewriter.updt_params(q); - } - - -} diff --git a/src/ast/sls/bv_sls_terms.h b/src/ast/sls/bv_sls_terms.h deleted file mode 100644 index a6294aa9c1e..00000000000 --- a/src/ast/sls/bv_sls_terms.h +++ /dev/null @@ -1,79 +0,0 @@ -/*++ -Copyright (c) 2024 Microsoft Corporation - -Module Name: - - bv_sls_terms.h - -Abstract: - - A Stochastic Local Search (SLS) engine - -Author: - - Nikolaj Bjorner (nbjorner) 2024-02-07 - ---*/ -#pragma once - -#include "util/lbool.h" -#include "util/params.h" -#include "util/scoped_ptr_vector.h" -#include "util/uint_set.h" -#include "ast/ast.h" -#include "ast/rewriter/th_rewriter.h" -#include "ast/sls/sls_stats.h" -#include "ast/sls/sls_powers.h" -#include "ast/sls/sls_valuation.h" -#include "ast/bv_decl_plugin.h" - -namespace bv { - - class sls_terms { - ast_manager& m; - bv_util bv; - th_rewriter m_rewriter; - ptr_vector m_todo, m_args; - expr_ref_vector m_assertions, m_pinned, m_translated; - app_ref_vector m_terms; - vector> m_parents; - tracked_uint_set m_assertion_set; - - expr* ensure_binary(expr* e); - void ensure_binary_core(expr* e); - - expr_ref mk_sdiv(expr* x, expr* y); - expr_ref mk_smod(expr* x, expr* y); - expr_ref mk_srem(expr* x, expr* y); - - public: - sls_terms(ast_manager& m); - - void updt_params(params_ref const& p); - - /** - * Add constraints - */ - void assert_expr(expr* e); - - /** - * Initialize structures: assertions, parents, terms - */ - void init(); - - /** - * Accessors. - */ - - ptr_vector const& parents(expr* e) const { return m_parents[e->get_id()]; } - - expr_ref_vector const& assertions() const { return m_assertions; } - - app* term(unsigned id) const { return m_terms.get(id); } - - app_ref_vector const& terms() const { return m_terms; } - - bool is_assertion(expr* e) const { return m_assertion_set.contains(e->get_id()); } - - }; -} diff --git a/src/ast/sls/bvsls_opt_engine.h b/src/ast/sls/bvsls_opt_engine.h index 25182c15624..1ad36259848 100644 --- a/src/ast/sls/bvsls_opt_engine.h +++ b/src/ast/sls/bvsls_opt_engine.h @@ -18,7 +18,7 @@ Module Name: --*/ #pragma once -#include "ast/sls/sls_engine.h" +#include "ast/sls/sls_bv_engine.h" class bvsls_opt_engine : public sls_engine { sls_tracker & m_hard_tracker; diff --git a/src/sat/sat_ddfw.cpp b/src/ast/sls/sat_ddfw.cpp similarity index 71% rename from src/sat/sat_ddfw.cpp rename to src/ast/sls/sat_ddfw.cpp index 73e8afe00b6..5473271ffcb 100644 --- a/src/sat/sat_ddfw.cpp +++ b/src/ast/sls/sat_ddfw.cpp @@ -11,8 +11,7 @@ Author: - Nikolaj Bjorner, Marijn Heule 2019-4-23 - + Nikolaj Bjorner, Marijn Heule 2019-4-23 Notes: @@ -26,20 +25,18 @@ --*/ #include "util/luby.h" -#include "sat/sat_ddfw.h" -#include "sat/sat_solver.h" -#include "sat/sat_params.hpp" +#include "util/trace.h" +#include "ast/sls/sat_ddfw.h" +#include "params/sat_params.hpp" + namespace sat { ddfw::~ddfw() { - for (auto& ci : m_clauses) - m_alloc.del_clause(ci.m_clause); } - lbool ddfw::check(unsigned sz, literal const* assumptions, parallel* p) { - init(sz, assumptions); - flet _p(m_par, p); + lbool ddfw::check(unsigned sz, literal const* assumptions) { + init(sz, assumptions); if (m_plugin) check_with_plugin(); else @@ -52,36 +49,42 @@ namespace sat { void ddfw::check_without_plugin() { while (m_limit.inc() && m_min_sz > 0) { if (should_reinit_weights()) do_reinit_weights(); - else if (do_flip()); + else if (do_flip()); else if (should_restart()) do_restart(); - else if (should_parallel_sync()) do_parallel_sync(); + else if (m_parallel_sync && m_parallel_sync()); else shift_weights(); } } void ddfw::check_with_plugin() { m_plugin->init_search(); - m_steps_since_progress = 0; unsigned steps = 0; - while (m_min_sz > 0 && m_steps_since_progress++ <= 1500000) { - if (should_reinit_weights()) do_reinit_weights(); - else if (steps % 5000 == 0) shift_weights(), m_plugin->on_rescale(); - else if (should_restart()) do_restart(), m_plugin->on_restart(); - else if (do_flip()); - else if (do_literal_flip()); - else if (should_parallel_sync()) do_parallel_sync(); - else shift_weights(), m_plugin->on_rescale(); - ++steps; + if (m_min_sz <= m_unsat.size()) + save_best_values(); + + try { + while (m_min_sz > 0 && m_limit.inc()) { + if (should_reinit_weights()) do_reinit_weights(); + else if (steps % 5000 == 0) shift_weights(), m_plugin->on_rescale(); + else if (should_restart()) do_restart(), m_plugin->on_restart(); + else if (do_flip()); + else shift_weights(), m_plugin->on_rescale(); + //verbose_stream() << "steps: " << steps << " min_sz: " << m_min_sz << " unsat: " << m_unsat.size() << "\n"; + ++steps; + } + } + catch (z3_exception& ex) { + IF_VERBOSE(0, verbose_stream() << "Exception: " << ex.msg() << "\n"); + throw; } m_plugin->finish_search(); } void ddfw::log() { double sec = m_stopwatch.get_current_seconds(); - double kflips_per_sec = (m_flips - m_last_flips) / (1000.0 * sec); + double kflips_per_sec = sec > 0 ? (m_flips - m_last_flips) / (1000.0 * sec) : 0.0; if (m_last_flips == 0) { IF_VERBOSE(1, verbose_stream() << "(sat.ddfw :unsat :models :kflips/sec :flips :restarts :reinits :unsat_vars :shifts"; - if (m_par) verbose_stream() << " :par"; verbose_stream() << ")\n"); } IF_VERBOSE(1, verbose_stream() << "(sat.ddfw " @@ -93,43 +96,37 @@ namespace sat { << std::setw(11) << m_reinit_count << std::setw(13) << m_unsat_vars.size() << std::setw(9) << m_shifts; - if (m_par) verbose_stream() << std::setw(10) << m_parsync_count; verbose_stream() << ")\n"); m_stopwatch.start(); m_last_flips = m_flips; } - template bool ddfw::do_flip() { double reward = 0; - bool_var v = pick_var(reward); - return apply_flip(v, reward); + bool_var v = pick_var(reward); + //verbose_stream() << "flip " << v << " " << reward << "\n"; + return apply_flip(v, reward); } - template bool ddfw::apply_flip(bool_var v, double reward) { if (v == null_bool_var) return false; if (reward > 0 || (reward == 0 && m_rand(100) <= m_config.m_use_reward_zero_pct)) { - if (uses_plugin && is_external(v)) - m_plugin->flip(v); - else - flip(v); - if (m_unsat.size() <= m_min_sz) + flip(v); + if (m_unsat.size() <= m_min_sz) save_best_values(); return true; } return false; } - template bool_var ddfw::pick_var(double& r) { double sum_pos = 0; unsigned n = 1; bool_var v0 = null_bool_var; for (bool_var v : m_unsat_vars) { - r = uses_plugin ? plugin_reward(v) : reward(v); + r = reward(v); if (r > 0.0) sum_pos += score(r); else if (r == 0.0 && sum_pos == 0 && (m_rand() % (n++)) == 0) @@ -138,7 +135,7 @@ namespace sat { if (sum_pos > 0) { double lim_pos = ((double) m_rand() / (1.0 + m_rand.max_value())) * sum_pos; for (bool_var v : m_unsat_vars) { - r = uses_plugin && is_external(v) ? m_vars[v].m_last_reward : reward(v); + r = reward(v); if (r > 0) { lim_pos -= score(r); if (lim_pos <= 0) @@ -154,96 +151,41 @@ namespace sat { return m_unsat_vars.elem_at(m_rand(m_unsat_vars.size())); } - template - bool ddfw::do_literal_flip() { - double reward = 1; - return apply_flip(pick_literal_var(), reward); - } - - /* - * Pick a random false literal from a satisfied clause such that - * the literal has zero break count and positive reward. - */ - template - bool_var ddfw::pick_literal_var() { -#if false - unsigned sz = m_clauses.size(); - unsigned start = rand(); - for (unsigned i = 0; i < 100; ++i) { - unsigned cl = (i + start) % sz; - if (m_unsat.contains(cl)) - continue; - for (auto lit : *m_clauses[cl].m_clause) { - if (is_true(lit)) - continue; - double r = uses_plugin ? plugin_reward(lit.var()) : reward(lit.var()); - if (r < 0) - continue; - //verbose_stream() << "false " << r << " " << lit << "\n"; - return lit.var(); - } - } -#endif - return null_bool_var; - } - - void ddfw::add(unsigned n, literal const* c) { - clause* cls = m_alloc.mk_clause(n, c, false); + void ddfw::add(unsigned n, literal const* c) { unsigned idx = m_clauses.size(); - - m_clauses.push_back(clause_info(cls, m_config.m_init_clause_weight)); - for (literal lit : *cls) { + m_clauses.push_back(clause_info(n, c, m_config.m_init_clause_weight)); + if (n > 2) + ++m_num_non_binary_clauses; + for (literal lit : m_clauses.back().m_clause) { m_use_list.reserve(2*(lit.var()+1)); m_vars.reserve(lit.var()+1); m_use_list[lit.index()].push_back(idx); } } + sat::bool_var ddfw::add_var() { + auto v = m_vars.size(); + m_vars.reserve(v + 1); + return v; + } + + void ddfw::reserve_vars(unsigned n) { + m_vars.reserve(n); + } + + /** * Remove the last clause that was added */ void ddfw::del() { auto& info = m_clauses.back(); - for (literal lit : *info.m_clause) + for (literal lit : info.m_clause) m_use_list[lit.index()].pop_back(); - m_alloc.del_clause(info.m_clause); m_clauses.pop_back(); if (m_unsat.contains(m_clauses.size())) m_unsat.remove(m_clauses.size()); } - void ddfw::add(solver const& s) { - set_seed(s.get_config().m_random_seed); - for (auto& ci : m_clauses) - m_alloc.del_clause(ci.m_clause); - m_clauses.reset(); - m_use_list.reset(); - m_num_non_binary_clauses = 0; - - unsigned trail_sz = s.init_trail_size(); - for (unsigned i = 0; i < trail_sz; ++i) { - add(1, s.m_trail.data() + i); - } - unsigned sz = s.m_watches.size(); - for (unsigned l_idx = 0; l_idx < sz; ++l_idx) { - literal l1 = ~to_literal(l_idx); - watch_list const & wlist = s.m_watches[l_idx]; - for (watched const& w : wlist) { - if (!w.is_binary_non_learned_clause()) - continue; - literal l2 = w.get_literal(); - if (l1.index() > l2.index()) - continue; - literal ls[2] = { l1, l2 }; - add(2, ls); - } - } - for (clause* c : s.m_clauses) { - add(c->size(), c->begin()); - } - m_num_non_binary_clauses = s.m_clauses.size(); - } - void ddfw::add_assumptions() { for (unsigned i = 0; i < m_assumptions.size(); ++i) add(1, m_assumptions.data() + i); @@ -264,8 +206,9 @@ namespace sat { for (unsigned v = 0; v < num_vars(); ++v) { value(v) = (m_rand() % 2) == 0; // m_use_list[lit.index()].size() >= m_use_list[nlit.index()].size(); } - init_clause_data(); - flatten_use_list(); + + if (!flatten_use_list()) + init_clause_data(); m_reinit_count = 0; m_reinit_next = m_config.m_reinit_base; @@ -273,29 +216,23 @@ namespace sat { m_restart_count = 0; m_restart_next = m_config.m_restart_base*2; - m_parsync_count = 0; - m_parsync_next = m_config.m_parsync_base; - - m_min_sz = m_unsat.size(); + m_min_sz = m_clauses.size(); m_flips = 0; m_last_flips = 0; m_shifts = 0; m_stopwatch.start(); } - void ddfw::reinit(solver& s, bool_vector const& phase) { - add(s); + void ddfw::reinit() { add_assumptions(); - for (unsigned v = 0; v < phase.size(); ++v) { - value(v) = phase[v]; - reward(v) = 0; - make_count(v) = 0; - } - init_clause_data(); flatten_use_list(); } - void ddfw::flatten_use_list() { + bool ddfw::flatten_use_list() { + if (num_vars() == m_use_list_vars && m_clauses.size() == m_use_list_clauses) + return false; + m_use_list_vars = num_vars(); + m_use_list_clauses = m_clauses.size(); m_use_list_index.reset(); m_flat_use_list.reset(); for (auto const& ul : m_use_list) { @@ -303,6 +240,8 @@ namespace sat { m_flat_use_list.append(ul); } m_use_list_index.push_back(m_flat_use_list.size()); + init_clause_data(); + return true; } void ddfw::flip(bool_var v) { @@ -310,15 +249,19 @@ namespace sat { literal lit = literal(v, !value(v)); literal nlit = ~lit; SASSERT(is_true(lit)); - for (unsigned cls_idx : use_list(*this, lit)) { - clause_info& ci = m_clauses[cls_idx]; + for (unsigned cls_idx : use_list(lit)) { + clause_info& ci = m_clauses[cls_idx]; ci.del(lit); double w = ci.m_weight; // cls becomes false: flip any variable in clause to receive reward w switch (ci.m_num_trues) { case 0: { +#if 0 + if (ci.m_clause.size() == 1) + verbose_stream() << "flipping unit clause " << ci << "\n"; +#endif m_unsat.insert_fresh(cls_idx); - clause const& c = get_clause(cls_idx); + auto const& c = get_clause(cls_idx); for (literal l : c) { inc_reward(l, w); inc_make(l); @@ -333,7 +276,7 @@ namespace sat { break; } } - for (unsigned cls_idx : use_list(*this, nlit)) { + for (unsigned cls_idx : use_list(nlit)) { clause_info& ci = m_clauses[cls_idx]; double w = ci.m_weight; // the clause used to have a single true (pivot) literal, now it has two. @@ -341,7 +284,7 @@ namespace sat { switch (ci.m_num_trues) { case 0: { m_unsat.remove(cls_idx); - clause const& c = get_clause(cls_idx); + auto const& c = get_clause(cls_idx); for (literal l : c) { dec_reward(l, w); dec_make(l); @@ -388,13 +331,13 @@ namespace sat { for (unsigned v = 0; v < num_vars(); ++v) { make_count(v) = 0; reward(v) = 0; - } + } m_unsat_vars.reset(); m_unsat.reset(); unsigned sz = m_clauses.size(); for (unsigned i = 0; i < sz; ++i) { auto& ci = m_clauses[i]; - clause const& c = get_clause(i); + auto const& c = get_clause(i); ci.m_trues = 0; ci.m_num_trues = 0; for (literal lit : c) @@ -415,13 +358,15 @@ namespace sat { break; } } + if (m_unsat.size() < m_min_sz) + save_best_values(); } bool ddfw::should_restart() { return m_flips >= m_restart_next; } - void ddfw::do_restart() { + void ddfw::do_restart() { reinit_values(); init_clause_data(); m_restart_next += m_config.m_restart_base*get_luby(++m_restart_count); @@ -445,52 +390,39 @@ namespace sat { } } - bool ddfw::should_parallel_sync() { - return m_par != nullptr && m_flips >= m_parsync_next; - } - void ddfw::save_priorities() { m_probs.reset(); for (unsigned v = 0; v < num_vars(); ++v) m_probs.push_back(-m_vars[v].m_reward_avg); } - void ddfw::do_parallel_sync() { - if (m_par->from_solver(*this)) - m_par->to_solver(*this); - - ++m_parsync_count; - m_parsync_next *= 3; - m_parsync_next /= 2; - } - void ddfw::save_model() { m_model.reserve(num_vars()); for (unsigned i = 0; i < num_vars(); ++i) m_model[i] = to_lbool(value(i)); save_priorities(); if (m_plugin) - m_plugin->on_save_model(); + m_plugin->on_save_model(); } - void ddfw::save_best_values() { - if (m_unsat.size() < m_min_sz) { - m_steps_since_progress = 0; - if (m_unsat.size() < 50 || m_min_sz * 10 > m_unsat.size() * 11) - save_model(); - } + if (m_save_best_values) + return; + if (m_plugin && !m_unsat.empty()) + return; + flet _save_best_values(m_save_best_values, true); + + bool do_save_model = ((m_unsat.size() < m_min_sz || m_unsat.empty()) && + ((m_unsat.size() < 50 || m_min_sz * 10 > m_unsat.size() * 11))); + + if (do_save_model) + save_model(); + if (m_unsat.size() < m_min_sz) { m_models.reset(); - // skip saving the first model. - for (unsigned v = 0; v < num_vars(); ++v) { - int& b = bias(v); - if (abs(b) > 3) { - b = b > 0 ? 3 : -3; - } - } + m_min_sz = m_unsat.size(); } - + unsigned h = value_hash(); unsigned occs = 0; bool contains = m_models.find(h, occs); @@ -504,8 +436,7 @@ namespace sat { if (occs > 100) { m_restart_next = m_flips; m_models.erase(h); - } - m_min_sz = m_unsat.size(); + } } unsigned ddfw::value_hash() const { @@ -538,11 +469,11 @@ namespace sat { unsigned ddfw::select_max_same_sign(unsigned cf_idx) { auto& ci = m_clauses[cf_idx]; unsigned cl = UINT_MAX; // clause pointer to same sign, max weight satisfied clause. - clause const& c = *ci.m_clause; + auto const& c = ci.m_clause; double max_weight = m_init_weight; unsigned n = 1; for (literal lit : c) { - for (unsigned cn_idx : use_list(*this, lit)) { + for (unsigned cn_idx : use_list(lit)) { auto& cn = m_clauses[cn_idx]; if (select_clause(max_weight, cn, n)) { cl = cn_idx; @@ -568,15 +499,20 @@ namespace sat { } unsigned ddfw::select_random_true_clause() { - unsigned num_clauses = m_clauses.size(); - unsigned rounds = 100 * num_clauses; - for (unsigned i = 0; i < rounds; ++i) { + unsigned num_clauses = m_clauses.size(); + for (unsigned i = 0; i < num_clauses; ++i) { unsigned idx = (m_rand() * m_rand()) % num_clauses; auto & cn = m_clauses[idx]; if (cn.is_true() && cn.m_weight >= m_init_weight) return idx; } - return UINT_MAX; + unsigned n = 0, idx = UINT_MAX; + for (unsigned i = 0; i < num_clauses; ++i) { + auto& cn = m_clauses[i]; + if (cn.is_true() && cn.m_weight >= m_init_weight && (m_rand() % (++n)) == 0) + idx = i; + } + return idx; } // 1% chance to disregard neighbor @@ -590,6 +526,7 @@ namespace sat { void ddfw::shift_weights() { ++m_shifts; + bool shifted = false; for (unsigned to_idx : m_unsat) { SASSERT(!m_clauses[to_idx].is_true()); unsigned from_idx = select_max_same_sign(to_idx); @@ -597,28 +534,75 @@ namespace sat { from_idx = select_random_true_clause(); if (from_idx == UINT_MAX) continue; + shifted = true; auto & cn = m_clauses[from_idx]; SASSERT(cn.is_true()); double w = calculate_transfer_weight(cn.m_weight); transfer_weight(from_idx, to_idx, w); } + //verbose_stream() << m_shifts << " " << m_flips << " " << shifted << "\n"; + if (!shifted && m_restart_next > m_flips) + m_restart_next = m_flips + (m_restart_next - m_flips) / 2; // DEBUG_CODE(invariant();); } + // apply unit propagation. + void ddfw::simplify() { + verbose_stream() << "simplify\n"; + sat::literal_vector units; + uint_set unit_set; + for (unsigned i = 0; i < m_clauses.size(); ++i) { + auto& ci = m_clauses[i]; + if (ci.m_clause.size() != 1) + continue; + auto lit = ci.m_clause[0]; + units.push_back(lit); + unit_set.insert(lit.index()); + m_use_list[lit.index()].reset(); + m_use_list[lit.index()].push_back(i); + } + auto is_unit = [&](sat::literal lit) { + return unit_set.contains(lit.index()); + }; + + sat::literal_vector new_clause; + for (unsigned i = 0; i < units.size(); ++i) { + auto lit = units[i]; + for (auto cidx : m_use_list[(~lit).index()]) { + auto& ci = m_clauses[cidx]; + if (ci.m_clause.size() == 1) + continue; + new_clause.reset(); + for (auto l : ci) { + if (!is_unit(~l)) + new_clause.push_back(l); + } + if (new_clause.size() == 1) { + verbose_stream() << "new unit " << lit << " " << ci << " -> " << new_clause << "\n"; + } + m_clauses[cidx] = sat::clause_info(new_clause.size(), new_clause.data(), m_config.m_init_clause_weight); + if (new_clause.size() == 1) { + units.push_back(new_clause[0]); + unit_set.insert(new_clause[0].index()); + } + } + } + for (auto unit : units) + m_use_list[(~unit).index()].reset(); + } + std::ostream& ddfw::display(std::ostream& out) const { unsigned num_cls = m_clauses.size(); for (unsigned i = 0; i < num_cls; ++i) { - out << get_clause(i) << " "; + out << get_clause(i) << " nt: "; auto const& ci = m_clauses[i]; - out << ci.m_num_trues << " " << ci.m_weight << "\n"; - } - for (unsigned v = 0; v < num_vars(); ++v) { - out << v << ": rw " << reward(v) << "\n"; + out << ci.m_num_trues << " w: " << ci.m_weight << "\n"; } + for (unsigned v = 0; v < num_vars(); ++v) + out << (is_true(literal(v, false)) ? "" : "-") << v << " rw: " << get_reward(v) << "\n"; out << "unsat vars: "; - for (bool_var v : m_unsat_vars) { - out << v << " "; - } + for (bool_var v : m_unsat_vars) + out << v << " "; out << "\n"; return out; } @@ -681,6 +665,20 @@ namespace sat { m_config.m_reinit_base = p.ddfw_reinit_base(); m_config.m_restart_base = p.ddfw_restart_base(); } + + void ddfw::collect_statistics(statistics& st) const { + st.update("sls-ddfw-flips", (double)m_flips); + st.update("sls-ddfw-restarts", m_restart_count); + st.update("sls-ddfw-reinits", m_reinit_count); + st.update("sls-ddfw-shifts", (double)m_shifts); + } + void ddfw::reset_statistics() { + m_flips = 0; + m_restart_count = 0; + m_reinit_count = 0; + m_shifts = 0; + } + } diff --git a/src/sat/sat_ddfw.h b/src/ast/sls/sat_ddfw.h similarity index 60% rename from src/sat/sat_ddfw.h rename to src/ast/sls/sat_ddfw.h index 3454d47dade..a00a196c9ab 100644 --- a/src/sat/sat_ddfw.h +++ b/src/ast/sls/sat_ddfw.h @@ -24,54 +24,27 @@ #include "util/rlimit.h" #include "util/params.h" #include "util/ema.h" -#include "sat/sat_clause.h" -#include "sat/sat_types.h" +#include "util/sat_sls.h" +#include "util/map.h" +#include "util/sat_literal.h" +#include "util/statistics.h" +#include "util/stopwatch.h" -namespace arith { - class sls; -} namespace sat { - class solver; - class parallel; class local_search_plugin { public: virtual ~local_search_plugin() {} virtual void init_search() = 0; virtual void finish_search() = 0; - virtual void flip(bool_var v) = 0; - virtual double reward(bool_var v) = 0; virtual void on_rescale() = 0; virtual void on_save_model() = 0; virtual void on_restart() = 0; }; - - class ddfw : public i_local_search { - friend class arith::sls; - public: - struct clause_info { - clause_info(clause* cl, double init_weight): m_weight(init_weight), m_clause(cl) {} - double m_weight; // weight of clause - unsigned m_trues = 0; // set of literals that are true - unsigned m_num_trues = 0; // size of true set - clause* m_clause; - bool is_true() const { return m_num_trues > 0; } - void add(literal lit) { ++m_num_trues; m_trues += lit.index(); } - void del(literal lit) { SASSERT(m_num_trues > 0); --m_num_trues; m_trues -= lit.index(); } - }; - - class use_list { - ddfw& p; - unsigned i; - public: - use_list(ddfw& p, literal lit) : - p(p), i(lit.index()) {} - unsigned const* begin() { return p.m_flat_use_list.data() + p.m_use_list_index[i]; } - unsigned const* end() { return p.m_flat_use_list.data() + p.m_use_list_index[i + 1]; } - unsigned size() const { return p.m_use_list_index[i + 1] - p.m_use_list_index[i]; } - }; - + + class ddfw { + friend class ddfw_wrapper; protected: struct config { @@ -95,45 +68,47 @@ namespace sat { }; struct var_info { + var_info() {} bool m_value = false; double m_reward = 0; double m_last_reward = 0; unsigned m_make_count = 0; int m_bias = 0; - bool m_external = false; ema m_reward_avg = 1e-5; }; config m_config; reslimit m_limit; - clause_allocator m_alloc; - svector m_clauses; + vector m_clauses; literal_vector m_assumptions; svector m_vars; // var -> info svector m_probs; // var -> probability of flipping svector m_scores; // reward -> score - model m_model; // var -> best assignment + svector m_model; // var -> best assignment unsigned m_init_weight = 2; - vector m_use_list; unsigned_vector m_flat_use_list; unsigned_vector m_use_list_index; + unsigned m_use_list_vars = 0, m_use_list_clauses = 0; indexed_uint_set m_unsat; indexed_uint_set m_unsat_vars; // set of variables that are in unsat clauses random_gen m_rand; + uint64_t m_last_flips_for_shift = 0; unsigned m_num_non_binary_clauses = 0; - unsigned m_restart_count = 0, m_reinit_count = 0, m_parsync_count = 0; - uint64_t m_restart_next = 0, m_reinit_next = 0, m_parsync_next = 0; + unsigned m_restart_count = 0, m_reinit_count = 0; + uint64_t m_restart_next = 0, m_reinit_next = 0; uint64_t m_flips = 0, m_last_flips = 0, m_shifts = 0; - unsigned m_min_sz = 0, m_steps_since_progress = 0; + unsigned m_min_sz = UINT_MAX; u_map m_models; stopwatch m_stopwatch; + unsigned_vector m_num_models; + bool m_save_best_values = false; - parallel* m_par; - local_search_plugin* m_plugin = nullptr; + scoped_ptr m_plugin = nullptr; + std::function m_parallel_sync; - void flatten_use_list(); + bool flatten_use_list(); /** * TBD: map reward value to a score, possibly through an exponential function, such as @@ -141,31 +116,20 @@ namespace sat { */ inline double score(double r) { return r; } - inline unsigned num_vars() const { return m_vars.size(); } - inline unsigned& make_count(bool_var v) { return m_vars[v].m_make_count; } inline bool& value(bool_var v) { return m_vars[v].m_value; } inline bool value(bool_var v) const { return m_vars[v].m_value; } - inline double& reward(bool_var v) { return m_vars[v].m_reward; } + inline double& reward(bool_var v) { return m_vars[v].m_reward; } - inline double reward(bool_var v) const { return m_vars[v].m_reward; } - - inline double plugin_reward(bool_var v) { return is_external(v) ? (m_vars[v].m_last_reward = m_plugin->reward(v)) : reward(v); } - - void set_external(bool_var v) { m_vars[v].m_external = true; } - - inline bool is_external(bool_var v) const { return m_vars[v].m_external; } - - inline int& bias(bool_var v) { return m_vars[v].m_bias; } unsigned value_hash() const; inline bool is_true(literal lit) const { return value(lit.var()) != lit.sign(); } - inline clause const& get_clause(unsigned idx) const { return *m_clauses[idx].m_clause; } + inline sat::literal_vector const& get_clause(unsigned idx) const { return m_clauses[idx].m_clause; } inline double get_weight(unsigned idx) const { return m_clauses[idx].m_weight; } @@ -193,20 +157,12 @@ namespace sat { void check_without_plugin(); // flip activity - template bool do_flip(); - template bool_var pick_var(double& reward); - template bool apply_flip(bool_var v, double reward); - template - bool do_literal_flip(); - - template - bool_var pick_literal_var(); void save_best_values(); void save_model(); @@ -226,11 +182,7 @@ namespace sat { void do_restart(); void reinit_values(); - unsigned select_random_true_clause(); - - // parallel integration - bool should_parallel_sync(); - void do_parallel_sync(); + unsigned select_random_true_clause(); void log(); @@ -240,8 +192,6 @@ namespace sat { void invariant(); - void add(unsigned sz, literal const* c); - void del(); void add_assumptions(); @@ -252,48 +202,79 @@ namespace sat { public: - ddfw(): m_par(nullptr) {} + ddfw() {} - ~ddfw() override; + ~ddfw(); - void set(local_search_plugin* p) { m_plugin = p; } + void set_plugin(local_search_plugin* p) { m_plugin = p; } - lbool check(unsigned sz, literal const* assumptions, parallel* p) override; + lbool check(unsigned sz, literal const* assumptions); - void updt_params(params_ref const& p) override; + void updt_params(params_ref const& p); - model const& get_model() const override { return m_model; } + svector const& get_model() const { return m_model; } - reslimit& rlimit() override { return m_limit; } + reslimit& rlimit() { return m_limit; } - void set_seed(unsigned n) override { m_rand.set_seed(n); } + void set_seed(unsigned n) { m_rand.set_seed(n); } - void add(solver const& s) override; - bool get_value(bool_var v) const override { return value(v); } + bool get_value(bool_var v) const { return value(v); } std::ostream& display(std::ostream& out) const; // for parallel integration - unsigned num_non_binary_clauses() const override { return m_num_non_binary_clauses; } - void reinit(solver& s, bool_vector const& phase) override; + unsigned num_non_binary_clauses() const { return m_num_non_binary_clauses; } - void collect_statistics(statistics& st) const override {} + void collect_statistics(statistics& st) const; - double get_priority(bool_var v) const override { return m_probs[v]; } + void reset_statistics(); + + double get_priority(bool_var v) const { return m_probs[v]; } // access clause information and state of Boolean search indexed_uint_set& unsat_set() { return m_unsat; } - unsigned num_clauses() const { return m_clauses.size(); } + indexed_uint_set const& unsat_set() const { return m_unsat; } + + vector const& clauses() const { return m_clauses; } clause_info& get_clause_info(unsigned idx) { return m_clauses[idx]; } + clause_info const& get_clause_info(unsigned idx) const { return m_clauses[idx]; } + void remove_assumptions(); void flip(bool_var v); - use_list get_use_list(literal lit) { return use_list(*this, lit); } + inline double get_reward(bool_var v) const { return m_vars[v].m_reward; } + + double get_reward_avg(bool_var v) const { return m_vars[v].m_reward_avg; } + + inline int& bias(bool_var v) { return m_vars[v].m_bias; } + + void reserve_vars(unsigned n); + + void add(unsigned sz, literal const* c); + + sat::bool_var add_var(); + + void reinit(); + + void force_restart() { m_restart_next = m_flips; } + + inline unsigned num_vars() const { return m_vars.size(); } + + void simplify(); + + + ptr_iterator use_list(literal lit) { + flatten_use_list(); + unsigned i = lit.index(); + auto const* b = m_flat_use_list.data() + m_use_list_index[i]; + auto const* e = m_flat_use_list.data() + m_use_list_index[i + 1]; + return { b, e }; + } }; } diff --git a/src/ast/sls/sls_arith_base.cpp b/src/ast/sls/sls_arith_base.cpp new file mode 100644 index 00000000000..ecc48fdcad4 --- /dev/null +++ b/src/ast/sls/sls_arith_base.cpp @@ -0,0 +1,2326 @@ +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + sls_arith_base.cpp + +Abstract: + + Local search dispatch for arithmetic + +Author: + + Nikolaj Bjorner (nbjorner) 2023-02-07 + + Uses quadratic solver method from nia_ls in hybrid-smt + (with a bug fix for when order of roots are swapped) + Other features from nia_ls are also used as a starting point, + such as tabu and fallbacks. + +Todo: + +- add fairness for which variable to flip and direction (by age fifo). + - maintain age per variable, per sign + +- include more general tabu measure + - + +- random walk when there is no applicable update + - repair_down can fail repeatedely. Then allow a mode to reset arguments similar to + repair of literals. + +- avoid overflow for nested products + +Done: +- add tabu for flipping variable back to the same value. + - remember last variable/delta and block -delta = last_delta && last_variable = current_variable +- include measures for bounded updates + - per variable maintain increasing range + +--*/ + +#include "ast/sls/sls_arith_base.h" +#include "ast/ast_ll_pp.h" +#include "ast/ast_pp.h" +#include + +namespace sls { + + template + bool arith_base::ineq::is_true() const { + switch (m_op) { + case ineq_kind::LE: + return m_args_value <= 0; + case ineq_kind::EQ: + return m_args_value == 0; + default: + return m_args_value < 0; + } + } + + + + template + std::ostream& arith_base::ineq::display(std::ostream& out) const { + bool first = true; + unsigned j = 0; + for (auto const& [c, v] : this->m_args) { + out << (first ? (c > 0 ? "" : "-") : (c > 0 ? " + " : " - ")); + bool first2 = abs(c) == 1; + if (abs(c) != 1) + out << abs(c); + auto const& m = this->m_monomials[j]; + + for (auto [w, p] : m) { + out << (first2 ? "" : " * ") << "v" << w; + if (p > 1) + out << "^" << p; + first2 = false; + } + first = false; + ++j; + } + if (this->m_coeff != 0) + out << " + " << this->m_coeff; + switch (m_op) { + case ineq_kind::LE: + out << " <= " << 0 << "(" << m_args_value << ")"; + break; + case ineq_kind::EQ: + out << " == " << 0 << "(" << m_args_value << ")"; + break; + default: + out << " < " << 0 << "(" << m_args_value << ")"; + break; + } +#if 0 + for (auto const& [x, nl] : this->m_nonlinear) { + if (nl.size() == 1 && nl[0].v == x) + continue; + for (auto const& [v, c, p] : nl) { + out << " v" << x; + if (p > 1) out << "^" << p; + out << " in v" << v; + } + } +#endif + return out; + } + + template + arith_base::arith_base(context& ctx) : + plugin(ctx), + a(m) { + m_fid = a.get_family_id(); + } + + template + void arith_base::save_best_values() { + for (auto& v : m_vars) + v.m_best_value = v.m_value; + check_ineqs(); + } + + // distance to true + template + num_t arith_base::dtt(bool sign, num_t const& args, ineq const& ineq) const { + num_t zero{ 0 }; + switch (ineq.m_op) { + case ineq_kind::LE: + if (sign) { + if (args + ineq.m_coeff <= 0) + return -ineq.m_coeff - args + 1; + return zero; + } + if (args + ineq.m_coeff <= 0) + return zero; + return args + ineq.m_coeff; + case ineq_kind::EQ: + if (sign) { + if (args + ineq.m_coeff == 0) + return num_t(1); + return zero; + } + if (args + ineq.m_coeff == 0) + return zero; + return num_t(1); + case ineq_kind::LT: + if (sign) { + if (args + ineq.m_coeff < 0) + return -ineq.m_coeff - args; + return zero; + } + if (args + ineq.m_coeff < 0) + return zero; + return args + ineq.m_coeff + 1; + default: + UNREACHABLE(); + return zero; + } + } + + // + // dtt is high overhead. It walks ineq.m_args + // m_vars[w].m_value can be computed outside and shared among calls + // different data-structures for storing coefficients + // + template + num_t arith_base::dtt(bool sign, ineq const& ineq, var_t v, num_t const& new_value) const { + for (auto const& [coeff, w] : ineq.m_args) + if (w == v) + return dtt(sign, ineq.m_args_value + coeff * (new_value - m_vars[v].m_value), ineq); + return num_t(1); + } + + template + num_t arith_base::dtt(bool sign, ineq const& ineq, num_t const& coeff, num_t const& delta) const { + return dtt(sign, ineq.m_args_value + coeff * delta, ineq); + } + + template + num_t arith_base::divide(var_t v, num_t const& delta, num_t const& coeff) { + if (is_int(v)) + return div(delta + abs(coeff) - 1, coeff); + else + return delta / coeff; + } + + template + num_t arith_base::divide_floor(var_t v, num_t const& a, num_t const& b) { + if (!is_int(v)) + return a / b; + if (b > 0 && a >= 0) + return div(a, b); + else if (b > 0) + return -div(-a + b - 1, b); + else if (a > 0) + return -div(a - b - 1, -b); + else + return div(-a, -b); + } + + template + num_t arith_base::divide_ceil(var_t v, num_t const& a, num_t const& b) { + if (!is_int(v)) + return a / b; + if (b > 0 && a >= 0) + return div(a + b - 1, b); + else if (b > 0) + return -div(-a, b); + else if (a > 0) + return -div(a, -b); + else + return div(-a - b - 1, -b); + } + + // + // i = 1, 3, 5, 7, 9, ... + // d, d - 1, d - 4, d - 9, d - 16, + // + template + static num_t sqrt(num_t d) { + if (d <= 1) + return d; + auto sq = 2*sqrt(div(d, num_t(4))) + 1; + if (sq * sq <= d) + return sq; + return sq - 1; + } + + // + // a*x^2 + b*x + c = sum + // + template + void arith_base::find_quadratic_moves(ineq const& ineq, var_t x, num_t const& a, num_t const& b, num_t const& sum) { + num_t c, d; + try { + c = sum - a * value(x) * value(x) - b * value(x); + d = b * b - 4 * a * c; + } + catch (overflow_exception const&) { + return; + } + if (d < 0) + return; + num_t root = sqrt(d); + bool is_square = root * root == d; + num_t ll = divide_floor(x, -b - root, 2 * a); + num_t lh = divide_ceil(x, -b - root, 2 * a); + num_t rl = divide_floor(x, -b + root, 2 * a); + num_t rh = divide_ceil(x, -b + root, 2 * a); + if (lh > rl) { + std::swap(ll, rl); + std::swap(lh, rh); + } + num_t eps(1); + if (!is_int(x) && abs(rh - lh) <= eps) + eps = abs(rh - lh) / num_t(2); +// verbose_stream() << a << " " << b << " " << c << "\n"; +// verbose_stream() << (-b - root) << " " << (2 * a) << " " << ll << " " << lh << "\n"; +// verbose_stream() << (-b + root) << " " << (2 * a) << " " << rl << " " << rh << "\n"; + SASSERT(ll <= lh && ll + 1 >= lh); + SASSERT(rl <= rh && rl + 1 >= rh); + SASSERT(!is_square || ll != lh || a * ll * ll + b * ll + c == 0); + SASSERT(!is_square || rl != rh || a * rl * rl + b * rl + c == 0); + if (d > 0 && lh == rh) + return; + if (d == 0 && ll != lh) + return; + + if (ineq.is_true()) { + switch (ineq.m_op) { + case ineq_kind::LE: + SASSERT(sum <= 0); + if (d == 0) + break; + if (a < 0) { + if (a * lh * lh + b * lh + c <= 0) + lh += eps; + if (a * rl * rl + b * rl + c <= 0) + rl -= eps; + SASSERT(!is_square || a * lh * lh + b * lh + c > 0); + SASSERT(!is_square || a * rl * rl + b * rl + c > 0); + add_update(x, lh - value(x)); + add_update(x, rl - value(x)); + } + else { + if (a * ll * ll + b * ll + c <= 0) + ll -= eps; + if (a * rh * rh + b * rh + c <= 0) + rh += eps; + SASSERT(!is_square || a * ll * ll + b * ll + c > 0); + SASSERT(!is_square || a * rh * rh + b * rh + c > 0); + add_update(x, ll - value(x)); + add_update(x, rh - value(x)); + } + break; + case ineq_kind::LT: + SASSERT(sum < 0); + SASSERT(!is_int(x)); + SASSERT(ll == lh); + SASSERT(rl == rh); + if (d == 0) + break; + + if (a > 0) { + SASSERT(!is_square || a * (ll + eps) * (ll + eps) + b * (ll + eps) + c >= 0); + SASSERT(!is_square || a * (rl - eps) * (rl - eps) + b * (rl - eps) + c >= 0); + add_update(x, lh - value(x) + eps); + if (ll != rl) + add_update(x, rh - value(x) - eps); + } + else { + SASSERT(!is_square || a * (ll - eps) * (ll - eps) + b * (ll - eps) + c >= 0); + SASSERT(!is_square || a * (rl + eps) * (rl + eps) + b * (rl + eps) + c >= 0); + add_update(x, ll - value(x) - eps); + if (ll != rl) + add_update(x, rl - value(x) + eps); + } + break; + case ineq_kind::EQ: + SASSERT(sum == 0); + SASSERT(!is_square || a * (value(x) + 1) * (value(x) + 1) + b * (value(x) + 1) + c != 0); + SASSERT(!is_square || a * (value(x) - 1) * (value(x) - 1) + b * (value(x) - 1) + c != 0); + add_update(x, num_t(1) - value(x)); + add_update(x, num_t(-1) - value(x)); + break; + } + } + else { + switch (ineq.m_op) { + case ineq_kind::LE: + SASSERT(sum > 0); + if (d == 0) { + SASSERT(!is_square || !is_int(x) || a <= 0 || ll != lh || a * ll * ll + b * ll + c <= 0); + if (a > 0 && ll == lh) + add_update(x, ll - value(x)); + break; + } + SASSERT(d > 0); + if (a > 0) { + if (a * lh * lh + b * lh + c > 0) + lh += eps; + if (a * rl * rl + b * rl + c > 0) + rl -= eps; + SASSERT(!is_square || a * lh * lh + b * lh + c <= 0); + SASSERT(!is_square || a * rl * rl + b * rl + c <= 0); + add_update(x, lh - value(x)); + add_update(x, rl - value(x)); + } + else { + if (a * ll * ll + b * ll + c > 0) + ll += eps; + if (a * rh * rh + b * rh + c > 0) + rh -= eps; + SASSERT(!is_square || a * ll * ll + b * ll + c <= 0); + SASSERT(!is_square || a * rh * rh + b * rh + c <= 0); + add_update(x, ll - value(x)); + add_update(x, rh - value(x)); + } + break; + case ineq_kind::LT: + SASSERT(sum >= 0); + SASSERT(!is_int(x)); + if (d == 0) + break; + SASSERT(d > 0); + if (a > 0) { + SASSERT(!is_square || a * (ll - eps) * (ll - eps) + b * (ll - eps) + c < 0); + SASSERT(!is_square || a * (rl + eps) * (rl + eps) + b * (rl + eps) + c < 0); + add_update(x, lh - value(x) - eps); + if (ll != rl) + add_update(x, rh - value(x) + eps); + } + else { + SASSERT(!is_square || a* (ll + eps)* (ll + eps) + b * (ll + eps) + c < 0); + SASSERT(!is_square || a* (rl - eps)* (rl - eps) + b * (rl - eps) + c < 0); + add_update(x, ll - value(x) + eps); + if (ll != rl) + add_update(x, rl - value(x) - eps); + } + break; + case ineq_kind::EQ: + SASSERT(sum != 0); + if (!is_square) + break; + if (ll == lh) + add_update(x, ll - value(x)); + if (rl == rh && lh != rh) + add_update(x, rl - value(x)); + break; + } + } + } + + template + void arith_base::find_linear_moves(ineq const& ineq, var_t v, num_t const& coeff, num_t const& sum) { + if (ineq.is_true()) { + switch (ineq.m_op) { + case ineq_kind::LE: + SASSERT(sum <= 0); + add_update(v, divide(v, -sum + 1, coeff)); + break; + case ineq_kind::LT: + SASSERT(sum < 0); + add_update(v, divide(v, -sum, coeff)); + break; + case ineq_kind::EQ: { + SASSERT(sum == 0); + add_update(v, num_t(1)); + add_update(v, num_t(- 1)); + break; + } + default: + UNREACHABLE(); + break; + } + } + else { + switch (ineq.m_op) { + case ineq_kind::LE: + SASSERT(sum > 0); + add_update(v, - divide(v, sum, coeff)); + break; + case ineq_kind::LT: + SASSERT(sum >= 0); + add_update(v, - divide(v, sum + 1, coeff)); + break; + case ineq_kind::EQ: { + num_t delta = sum; + SASSERT(sum != 0); + delta = sum < 0 ? divide(v, abs(sum), coeff) : -divide(v, sum, coeff); + if (sum + coeff * delta == 0) + add_update(v, delta); + break; + } + default: + UNREACHABLE(); + break; + } + } + } + + template + bool arith_base::is_permitted_update(var_t v, num_t const& delta, num_t & delta_out) { + auto& vi = m_vars[v]; + + delta_out = delta; + + if (m_last_var == v && m_last_delta == -delta) + return false; + + if (m_use_tabu && vi.is_tabu(m_stats.m_num_steps, delta)) + return false; + + auto old_value = value(v); + auto new_value = old_value + delta; + if (!vi.in_range(new_value)) + return false; + + if (m_use_tabu && !in_bounds(v, new_value) && in_bounds(v, old_value)) { + auto const& lo = m_vars[v].m_lo; + auto const& hi = m_vars[v].m_hi; + if (lo && (lo->is_strict ? lo->value >= new_value : lo->value > new_value)) { + if (lo->is_strict && delta_out < 0 && lo->value <= old_value) { + num_t eps(1); + if (hi && hi->value - lo->value <= eps) + eps = (hi->value - lo->value) / num_t(2); + delta_out = lo->value - old_value + eps; + } + else if (!lo->is_strict && delta_out < 0 && lo->value < old_value) + delta_out = lo->value - old_value; + else + return false; + } + if (hi && (hi->is_strict ? hi->value <= new_value : hi->value < new_value)) { + if (hi->is_strict && delta_out >= 0 && hi->value >= old_value) { + num_t eps(1); + if (lo && hi->value - lo->value <= eps) + eps = (hi->value - lo->value) / num_t(2); + delta_out = hi->value - old_value - eps; + } + else if (!hi->is_strict && delta_out > 0 && hi->value > old_value) + delta_out = hi->value - old_value; + else + return false; + } + } + return delta_out != 0; + } + + template + void arith_base::add_update(var_t v, num_t delta) { + num_t delta_out; + if (!is_permitted_update(v, delta, delta_out)) + return; + + + m_updates.push_back({ v, delta_out, 0 }); + } + + // flip on the first positive score + // it could be changed to flip on maximal positive score + // or flip on maximal non-negative score + // or flip on first non-negative score + + // prefer maximal score + // prefer v/delta with oldest occurrence with same direction + // + + template + bool arith_base::apply_update() { + + while (m_updates.size() > m_updates_max_size) { + auto idx = ctx.rand(m_updates.size()); + m_updates[idx] = m_updates.back(); + m_updates.pop_back(); + } + + for (auto & [v, delta, score] : m_updates) + score = compute_score(v, delta); + + double sum_score = 0; + + for (auto const& [v, delta, score] : m_updates) + sum_score += score; + + while (!m_updates.empty()) { + + unsigned i = m_updates.size(); + double lim = sum_score * ((double)ctx.rand() / random_gen().max_value()); + do { + lim -= m_updates[--i].m_score; + } while (lim >= 0 && i > 0); + + auto [v, delta, score] = m_updates[i]; + + num_t new_value = value(v) + delta; + + + if (update(v, new_value)) { + m_last_delta = delta; + m_stats.m_num_steps++; + m_vars[v].set_step(m_stats.m_num_steps, m_stats.m_num_steps + 3 + ctx.rand(10), delta); + return true; + } + sum_score -= score; + m_updates[i] = m_updates.back(); + m_updates.pop_back(); + } + return false; + } + + template + bool arith_base::find_lin_moves(sat::literal lit) { + m_updates.reset(); + auto* ineq = atom(lit.var()); + num_t a, b; + if (!ineq) + return false; + if (!ineq->m_is_linear) { + for (auto const& [coeff, x] : ineq->m_args) { + if (is_fixed(x)) + continue; + find_linear_moves(*ineq, x, coeff, ineq->m_args_value); + } + } + return apply_update(); + } + + template + bool arith_base::repair(sat::literal lit) { + + //verbose_stream() << "repair " << lit << " " << (ctx.is_unit(lit)?"unit":"") << " " << mk_bounded_pp(ctx.atom(lit.var()), m) << "\n"; + //verbose_stream() << *atom(lit.var()) << "\n"; + m_last_literal = lit; + if (find_nl_moves(lit)) + return true; + + flet _tabu(m_use_tabu, false); + if (false && find_nl_moves(lit)) + return true; + if (false && find_lin_moves(lit)) + return true; + return find_reset_moves(lit); + } + + template + num_t arith_base::compute_dts(unsigned cl) const { + num_t d(1), d2; + bool first = true; + for (auto a : ctx.get_clause(cl)) { + auto const* ineq = atom(a.var()); + if (!ineq) + continue; + d2 = dtt(a.sign(), *ineq); + if (first) + d = d2, first = false; + else + d = std::min(d, d2); + if (d == 0) + break; + } + return d; + } + + template + num_t arith_base::dts(unsigned cl, var_t v, num_t const& new_value) const { + num_t d(1), d2; + bool first = true; + for (auto lit : ctx.get_clause(cl)) { + auto const* ineq = atom(lit.var()); + if (!ineq) + continue; + d2 = dtt(lit.sign(), *ineq, v, new_value); + if (first) + d = d2, first = false; + else + d = std::min(d, d2); + if (d == 0) + break; + } + return d; + } + + + template + bool arith_base::in_bounds(var_t v, num_t const& value) { + auto const& vi = m_vars[v]; + auto const& lo = vi.m_lo; + auto const& hi = vi.m_hi; + if (lo && value < lo->value) + return false; + if (lo && lo->is_strict && value <= lo->value) + return false; + if (hi && value > hi->value) + return false; + if (hi && hi->is_strict && value >= hi->value) + return false; + return true; + } + + template + bool arith_base::is_fixed(var_t v) { + auto const& vi = m_vars[v]; + auto const& lo = vi.m_lo; + auto const& hi = vi.m_hi; + return lo && hi && lo->value == hi->value && lo->value == value(v); + } + + template + bool arith_base::update(var_t v, num_t const& new_value) { + auto& vi = m_vars[v]; + expr* e = vi.m_expr; + auto old_value = vi.m_value; + if (old_value == new_value) + return true; + if (!vi.in_range(new_value)) + return false; + if (!in_bounds(v, new_value) && in_bounds(v, old_value)) + return false; + + // check for overflow + try { + for (auto idx : vi.m_muls) { + auto const& [w, monomial] = m_muls[idx]; + num_t prod(1); + for (auto [w, p] : monomial) + prod *= power_of(v == w ? new_value : value(w), p); + } + } + catch (overflow_exception const&) { + return false; + } + + // IF_VERBOSE(0, display(verbose_stream(), v) << " := " << new_value << "\n"); + + + +#if 0 + if (!check_update(v, new_value)) + return false; + apply_checked_update(); +#else + + for (auto const& [coeff, bv] : vi.m_bool_vars) { + auto& ineq = *atom(bv); + bool old_sign = sign(bv); + sat::literal lit(bv, old_sign); + SASSERT(ctx.is_true(lit)); + ineq.m_args_value += coeff * (new_value - old_value); + num_t dtt_new = dtt(old_sign, ineq); + // verbose_stream() << "dtt " << lit << " " << ineq << " " << dtt_new << "\n"; + if (dtt_new != 0) + ctx.flip(bv); + SASSERT(dtt(sign(bv), ineq) == 0); + } + IF_VERBOSE(5, verbose_stream() << "repair: v" << v << " := " << old_value << " -> " << new_value << "\n"); + vi.m_value = new_value; + ctx.new_value_eh(e); + m_last_var = v; + + IF_VERBOSE(10, verbose_stream() << "new value eh " << mk_bounded_pp(e, m) << "\n"); + + for (auto idx : vi.m_muls) + ctx.new_value_eh(m_vars[m_muls[idx].m_var].m_expr); + for (auto idx : vi.m_adds) + ctx.new_value_eh(m_vars[m_adds[idx].m_var].m_expr); + + for (auto idx : vi.m_muls) { + auto const& [w, monomial] = m_muls[idx]; + num_t prod(1); + try { + for (auto [w, p] : monomial) + prod *= power_of(value(w), p); + } + catch (overflow_exception const&) { + return false; + } + if (value(w) != prod && !update(w, prod)) + return false; + + } + + for (auto idx : vi.m_adds) { + auto const& ad = m_adds[idx]; + auto w = ad.m_var; + num_t sum(ad.m_coeff); + for (auto const& [coeff, w] : ad.m_args) + sum += coeff * value(w); + if (!update(ad.m_var, sum)) + return false; + } +#endif + + return true; + } + + template + bool arith_base::check_update(var_t v, num_t new_value) { + + ++m_update_timestamp; + if (m_update_timestamp == 0) { + for (auto& vi : m_vars) + vi.set_update_value(num_t(0), 0); + ++m_update_timestamp; + } + auto& vi = m_vars[v]; + m_update_trail.reset(); + m_update_trail.push_back(v); + vi.set_update_value(new_value, m_update_timestamp); + + num_t delta; + for (unsigned i = 0; i < m_update_trail.size(); ++i) { + auto v = m_update_trail[i]; + auto& vi = m_vars[v]; + for (auto idx : vi.m_muls) { + auto const& [w, monomial] = m_muls[idx]; + num_t prod(1); + try { + for (auto [w, p] : monomial) + prod *= power_of(get_update_value(w), p); + } + catch (overflow_exception const&) { + return false; + } + if (get_update_value(w) != prod && (!is_permitted_update(w, prod - value(w), delta) || prod - value(w) != delta)) + return false; + m_update_trail.push_back(w); + m_vars[w].set_update_value(prod, m_update_timestamp); + } + + for (auto idx : vi.m_adds) { + auto const& ad = m_adds[idx]; + auto w = ad.m_var; + num_t sum(ad.m_coeff); + for (auto const& [coeff, w] : ad.m_args) + sum += coeff * get_update_value(w); + if (get_update_value(v) != sum && !(is_permitted_update(w, sum - value(w), delta) || sum - value(w) != delta)) + return false; + m_update_trail.push_back(w); + m_vars[w].set_update_value(sum, m_update_timestamp); + } + } + return true; + } + + template + void arith_base::apply_checked_update() { + for (auto v : m_update_trail) { + auto & vi = m_vars[v]; + auto old_value = vi.m_value; + vi.m_value = vi.get_update_value(m_update_timestamp); + auto new_value = vi.m_value; + ctx.new_value_eh(vi.m_expr); + for (auto const& [coeff, bv] : vi.m_bool_vars) { + auto& ineq = *atom(bv); + bool old_sign = sign(bv); + sat::literal lit(bv, old_sign); + SASSERT(ctx.is_true(lit)); + ineq.m_args_value += coeff * (new_value - old_value); + num_t dtt_new = dtt(old_sign, ineq); + // verbose_stream() << "dtt " << lit << " " << ineq << " " << dtt_new << "\n"; + if (dtt_new != 0) + ctx.flip(bv); + SASSERT(dtt(sign(bv), ineq) == 0); + } + } + } + + template + typename arith_base::ineq& arith_base::new_ineq(ineq_kind op, num_t const& coeff) { + auto* i = alloc(ineq); + i->m_coeff = coeff; + i->m_op = op; + return *i; + } + + template + void arith_base::add_arg(linear_term& ineq, num_t const& c, var_t v) { + if (c != 0) + ineq.m_args.push_back({ c, v }); + } + + template<> + bool arith_base>::is_num(expr* e, checked_int64& i) { + rational r; + if (a.is_extended_numeral(e, r)) { + if (!r.is_int64()) + throw overflow_exception(); + i = r.get_int64(); + return true; + } + return false; + } + + template<> + bool arith_base::is_num(expr* e, rational& i) { + return a.is_extended_numeral(e, i); + } + + template + bool arith_base::is_num(expr* e, num_t& i) { + UNREACHABLE(); + return false; + } + + template<> + expr_ref arith_base::from_num(sort* s, rational const& n) { + return expr_ref(a.mk_numeral(n, s), m); + } + + template<> + expr_ref arith_base>::from_num(sort* s, checked_int64 const& n) { + return expr_ref(a.mk_numeral(rational(n.get_int64(), rational::i64()), s), m); + } + + template + expr_ref arith_base::from_num(sort* s, num_t const& n) { + UNREACHABLE(); + return expr_ref(m); + } + + template + void arith_base::add_args(linear_term& term, expr* e, num_t const& coeff) { + auto v = m_expr2var.get(e->get_id(), UINT_MAX); + expr* x, * y; + num_t i; + if (v != UINT_MAX) + add_arg(term, coeff, v); + else if (is_num(e, i)) + term.m_coeff += coeff * i; + else if (a.is_add(e)) { + for (expr* arg : *to_app(e)) + add_args(term, arg, coeff); + } + else if (a.is_sub(e, x, y)) { + add_args(term, x, coeff); + add_args(term, y, -coeff); + } + else if (a.is_mul(e, x, y) && is_num(x, i)) { + add_args(term, y, i * coeff); + } + else if (a.is_mul(e)) { + unsigned_vector ms; + for (expr* arg : *to_app(e)) + ms.push_back(mk_term(arg)); + + switch (ms.size()) { + case 0: + term.m_coeff += coeff; + break; + case 1: + add_arg(term, coeff, ms[0]); + break; + default: { + v = mk_var(e); + unsigned idx = m_muls.size(); + std::stable_sort(ms.begin(), ms.end(), [&](unsigned a, unsigned b) { return a < b; }); + svector> mp; + for (unsigned i = 0; i < ms.size(); ++i) { + auto w = ms[i]; + auto p = 1; + while (i + 1 < ms.size() && ms[i + 1] == w) + ++p, ++i; + mp.push_back({ w, p }); + } + m_muls.push_back({ v, mp }); + num_t prod(1); + for (auto [w, p] : mp) + m_vars[w].m_muls.push_back(idx), prod *= power_of(value(w), p); + m_vars[v].m_def_idx = idx; + m_vars[v].m_op = arith_op_kind::OP_MUL; + m_vars[v].m_value = prod; + add_arg(term, coeff, v); + break; + } + } + } + else if (a.is_uminus(e, x)) + add_args(term, x, -coeff); + else if (a.is_mod(e, x, y) || a.is_mod0(e, x, y)) + add_arg(term, coeff, mk_op(arith_op_kind::OP_MOD, e, x, y)); + else if (a.is_idiv(e, x, y) || a.is_idiv0(e, x, y)) + add_arg(term, coeff, mk_op(arith_op_kind::OP_IDIV, e, x, y)); + else if (a.is_div(e, x, y) || a.is_div0(e, x, y)) + add_arg(term, coeff, mk_op(arith_op_kind::OP_DIV, e, x, y)); + else if (a.is_rem(e, x, y)) + add_arg(term, coeff, mk_op(arith_op_kind::OP_REM, e, x, y)); + else if (a.is_power(e, x, y) || a.is_power0(e, x, y)) + add_arg(term, coeff, mk_op(arith_op_kind::OP_POWER, e, x, y)); + else if (a.is_abs(e, x)) + add_arg(term, coeff, mk_op(arith_op_kind::OP_ABS, e, x, x)); + else if (a.is_to_int(e, x)) + add_arg(term, coeff, mk_op(arith_op_kind::OP_TO_INT, e, x, x)); + else if (a.is_to_real(e, x)) + add_arg(term, coeff, mk_op(arith_op_kind::OP_TO_REAL, e, x, x)); + else if (a.is_arith_expr(e)) { + NOT_IMPLEMENTED_YET(); + } + else + add_arg(term, coeff, mk_var(e)); + } + + template + typename arith_base::var_t arith_base::mk_op(arith_op_kind k, expr* e, expr* x, expr* y) { + auto v = mk_var(e); + auto w = mk_term(x); + unsigned idx = m_ops.size(); + num_t val; + switch (k) { + case arith_op_kind::OP_MOD: + val = value(v) == 0 ? num_t(0) : mod(value(w), value(v)); + break; + case arith_op_kind::OP_REM: + if (value(v) == 0) + val = 0; + else { + val = value(w); + val %= value(v); + } + break; + case arith_op_kind::OP_IDIV: + val = value(v) == 0 ? num_t(0): div(value(w), value(v)); + break; + case arith_op_kind::OP_DIV: + val = value(v) == 0? num_t(0) : value(w) / value(v); + break; + case arith_op_kind::OP_ABS: + val = abs(value(w)); + break; + default: + NOT_IMPLEMENTED_YET(); + break; + } + verbose_stream() << "mk-op " << mk_bounded_pp(e, m) << "\n"; + m_ops.push_back({v, k, v, w}); + m_vars[v].m_def_idx = idx; + m_vars[v].m_op = k; + m_vars[v].m_value = val; + return v; + } + + template + typename arith_base::var_t arith_base::mk_term(expr* e) { + auto v = m_expr2var.get(e->get_id(), UINT_MAX); + if (v != UINT_MAX) + return v; + linear_term t; + add_args(t, e, num_t(1)); + if (t.m_coeff == 0 && t.m_args.size() == 1 && t.m_args[0].first == 1) + return t.m_args[0].second; + v = mk_var(e); + auto idx = m_adds.size(); + num_t sum(t.m_coeff); + m_adds.push_back({ { t.m_args, t.m_coeff }, v }); + for (auto const& [c, w] : t.m_args) + m_vars[w].m_adds.push_back(idx), sum += c * value(w); + m_vars[v].m_def_idx = idx; + m_vars[v].m_op = arith_op_kind::OP_ADD; + m_vars[v].m_value = sum; + return v; + } + + template + typename arith_base::var_t arith_base::mk_var(expr* e) { + var_t v = m_expr2var.get(e->get_id(), UINT_MAX); + if (v == UINT_MAX) { + // verbose_stream() << "mk-var " << mk_bounded_pp(e, m) << "\n"; + v = m_vars.size(); + m_expr2var.setx(e->get_id(), v, UINT_MAX); + m_vars.push_back(var_info(e, a.is_int(e) ? var_sort::INT : var_sort::REAL)); + } + return v; + } + + template + void arith_base::init_bool_var(sat::bool_var bv) { + expr* e = ctx.atom(bv); + if (m_bool_vars.get(bv, nullptr)) + return; + if (!e) + return; + expr* x, * y; + m_bool_vars.reserve(bv + 1); + if (a.is_le(e, x, y) || a.is_ge(e, y, x)) { + auto& ineq = new_ineq(ineq_kind::LE, num_t(0)); + add_args(ineq, x, num_t(1)); + add_args(ineq, y, num_t(-1)); + init_ineq(bv, ineq); + } + else if ((a.is_lt(e, x, y) || a.is_gt(e, y, x)) && a.is_int(x)) { + auto& ineq = new_ineq(ineq_kind::LE, num_t(1)); + add_args(ineq, x, num_t(1)); + add_args(ineq, y, num_t(-1)); + init_ineq(bv, ineq); + } + else if ((a.is_lt(e, x, y) || a.is_gt(e, y, x)) && a.is_real(x)) { + auto& ineq = new_ineq(ineq_kind::LT, num_t(0)); + add_args(ineq, x, num_t(1)); + add_args(ineq, y, num_t(-1)); + init_ineq(bv, ineq); + } + else if (m.is_eq(e, x, y) && a.is_int_real(x)) { + auto& ineq = new_ineq(ineq_kind::EQ, num_t(0)); + add_args(ineq, x, num_t(1)); + add_args(ineq, y, num_t(-1)); + init_ineq(bv, ineq); + } + else if (m.is_distinct(e) && a.is_int_real(to_app(e)->get_arg(0))) { + NOT_IMPLEMENTED_YET(); + } + else if (a.is_is_int(e, x)) + { + NOT_IMPLEMENTED_YET(); + } +#if 0 + else if (a.is_idivides(e, x, y)) + NOT_IMPLEMENTED_YET(); +#endif + else { + SASSERT(!a.is_arith_expr(e)); + } + } + + template + void arith_base::init_ineq(sat::bool_var bv, ineq& i) { + + // ensure that variables are unique in the linear term: + std::stable_sort(i.m_args.begin(), i.m_args.end(), [&](auto const& a, auto const& b) { return a.second < b.second; }); + unsigned k = 0; + for (unsigned j = 0; j < i.m_args.size(); ++j) { + if (j > k && i.m_args[k].second == i.m_args[j].second) + i.m_args[k].first += i.m_args[j].first; + else + i.m_args[k++] = i.m_args[j]; + } + i.m_args.shrink(k); + i.m_monomials.reserve(k); + for (unsigned j = 0; j < i.m_args.size(); ++j) { + auto const& [c, v] = i.m_args[j]; + if (is_mul(v)) + i.m_monomials[j].append(get_mul(v).m_monomial); + else + i.m_monomials[j].push_back({ v, 1 }); + } + // compute the value of the linear term, and accumulate non-linear sub-terms + i.m_args_value = i.m_coeff; + for (auto const& [coeff, v] : i.m_args) { + m_vars[v].m_bool_vars.push_back({ coeff, bv }); + i.m_args_value += coeff * value(v); + if (is_mul(v)) { + auto const& [w, monomial] = get_mul(v); + for (auto [w, p] : monomial) + i.m_nonlinear.push_back({ w, { {v, coeff, p} } }); + i.m_is_linear = false; + } + else + i.m_nonlinear.push_back({ v, { { v, coeff, 1 } } }); + } + std::stable_sort(i.m_nonlinear.begin(), i.m_nonlinear.end(), [&](auto const& a, auto const& b) { return a.first < b.first; }); + + // ensure that non-linear terms are have a unique summary. + k = 0; + for (unsigned j = 0; j < i.m_nonlinear.size(); ++j) { + if (j > k && i.m_nonlinear[k].first == i.m_nonlinear[j].first) + i.m_nonlinear[k].second.append(i.m_nonlinear[j].second); + else + i.m_nonlinear[k++] = i.m_nonlinear[j]; + } + i.m_nonlinear.shrink(k); + + // Ensure that non-linear term occurrences are sorted, and + // that terms with the same variable are combined. + for (auto& [x, nl] : i.m_nonlinear) { + if (nl.size() == 1) + continue; + std::stable_sort(nl.begin(), nl.end(), [&](auto const& a, auto const& b) { return a.p < b.p; }); + k = 0; + for (unsigned j = 0; j < nl.size(); ++j) { + if (j > k && nl[k].v == nl[j].v) + nl[k].coeff += nl[j].coeff; + else + nl[k++] = nl[j]; + } + nl.shrink(k); + } + + // attach i to bv + m_bool_vars.set(bv, &i); + } + + template + void arith_base::init_bool_var_assignment(sat::bool_var v) { + auto* ineq = atom(v); + if (ineq && ineq->is_true() != ctx.is_true(v)) + ctx.flip(v); + } + + template + void arith_base::propagate_literal(sat::literal lit) { + if (!ctx.is_true(lit)) + return; + auto const* ineq = atom(lit.var()); + if (!ineq) + return; + if (ineq->is_true() != lit.sign()) + return; + repair(lit); + } + + template + void arith_base::repair_literal(sat::literal lit) { + init_bool_var_assignment(lit.var()); + } + + template + bool arith_base::propagate() { + // m_last_var = UINT_MAX; // allow to change last variable. + return false; + } + + + template + num_t arith_base::value1(var_t v) { + auto const& vi = m_vars[v]; + if (vi.m_def_idx == UINT_MAX) + return value(v); + + num_t result, v1, v2; + switch (vi.m_op) { + case LAST_ARITH_OP: + break; + case OP_ADD: { + auto const& ad = m_adds[vi.m_def_idx]; + auto const& args = ad.m_args; + result = ad.m_coeff; + for (auto [c, w] : args) + result += c * value(w); + break; + } + case OP_MUL: { + auto const& [w, monomial] = m_muls[vi.m_def_idx]; + result = num_t(1); + for (auto [w, p] : monomial) + result *= power_of(value(w), p); + break; + } + case OP_MOD: + v1 = value(m_ops[vi.m_def_idx].m_arg1); + v2 = value(m_ops[vi.m_def_idx].m_arg2); + result = v2 == 0 ? num_t(0) : mod(v1, v2); + break; + case OP_DIV: + v1 = value(m_ops[vi.m_def_idx].m_arg1); + v2 = value(m_ops[vi.m_def_idx].m_arg2); + result = v2 == 0 ? num_t(0) : v1 / v2; + break; + case OP_IDIV: + v1 = value(m_ops[vi.m_def_idx].m_arg1); + v2 = value(m_ops[vi.m_def_idx].m_arg2); + result = v2 == 0 ? num_t(0) : div(v1, v2); + break; + case OP_REM: + v1 = value(m_ops[vi.m_def_idx].m_arg1); + v2 = value(m_ops[vi.m_def_idx].m_arg2); + result = v2 == 0 ? num_t(0) : v1 %= v2; + break; + case OP_ABS: + result = abs(value(m_ops[vi.m_def_idx].m_arg1)); + break; + default: + NOT_IMPLEMENTED_YET(); + } + return result; + } + + template + void arith_base::repair_up(app* e) { + if (m.is_bool(e)) { + auto v = ctx.atom2bool_var(e); + auto const* ineq = atom(v); + if (ineq && ineq->is_true() != ctx.is_true(v)) + ctx.flip(v); + return; + } + auto v = m_expr2var.get(e->get_id(), UINT_MAX); + if (v == UINT_MAX) + return; + auto const& vi = m_vars[v]; + if (vi.m_def_idx == UINT_MAX) + return; + auto new_value = value1(v); + if (!update(v, new_value)) + ctx.new_value_eh(e); + } + + template + bool arith_base::repair_down(app* e) { + auto v = m_expr2var.get(e->get_id(), UINT_MAX); + if (v == UINT_MAX) + return false; + auto const& vi = m_vars[v]; + if (vi.m_def_idx == UINT_MAX) + return false; + flet _tabu(m_use_tabu, false); + TRACE("sls", tout << "repair def " << mk_bounded_pp(vi.m_expr, m) << "\n"); + // verbose_stream() << "repair down " << mk_bounded_pp(e, m) << "\n"; + switch (vi.m_op) { + case arith_op_kind::LAST_ARITH_OP: + break; + case arith_op_kind::OP_ADD: + return repair_add(m_adds[vi.m_def_idx]); + case arith_op_kind::OP_MUL: + return repair_mul(m_muls[vi.m_def_idx]); + case arith_op_kind::OP_MOD: + return repair_mod(m_ops[vi.m_def_idx]); + case arith_op_kind::OP_REM: + return repair_rem(m_ops[vi.m_def_idx]); + case arith_op_kind::OP_POWER: + return repair_power(m_ops[vi.m_def_idx]); + case arith_op_kind::OP_IDIV: + return repair_idiv(m_ops[vi.m_def_idx]); + case arith_op_kind::OP_DIV: + return repair_div(m_ops[vi.m_def_idx]); + case arith_op_kind::OP_ABS: + return repair_abs(m_ops[vi.m_def_idx]); + case arith_op_kind::OP_TO_INT: + return repair_to_int(m_ops[vi.m_def_idx]); + case arith_op_kind::OP_TO_REAL: + return repair_to_real(m_ops[vi.m_def_idx]); + default: + NOT_IMPLEMENTED_YET(); + } + return true; + } + + template + void arith_base::initialize() { + for (auto lit : ctx.unit_literals()) + initialize_unit(lit); + for (unsigned v = 0; v < m_vars.size(); ++v) { + auto const& vi = m_vars[v]; + if (vi.m_lo || vi.m_hi) + continue; + expr* e = vi.m_expr; + if (is_add(v)) { + auto const& ad = get_add(v); + num_t lo(ad.m_coeff), hi(ad.m_coeff); + bool lo_valid = true, hi_valid = true; + bool lo_strict = false, hi_strict = false; + for (auto const& [c, w] : ad.m_args) { + if (!lo_valid && !hi_valid) + break; + auto const& wi = m_vars[w]; + if (lo_valid) { + if (c > 0 && wi.m_lo) + lo += c * wi.m_lo->value, + lo_strict |= wi.m_lo->is_strict; + else if (c < 0 && wi.m_hi) + lo += c * wi.m_hi->value, + lo_strict |= wi.m_hi->is_strict; + else + lo_valid = false; + } + if (hi_valid) { + if (c > 0 && wi.m_hi) + hi += c * wi.m_hi->value, + hi_strict |= wi.m_hi->is_strict; + else if (c < 0 && wi.m_lo) + hi += c * wi.m_lo->value, + hi_strict |= wi.m_lo->is_strict; + else + hi_valid = false; + } + } + if (lo_valid) { + if (lo_strict) + add_gt(v, lo); + else + add_ge(v, lo); + } + if (hi_valid) { + if (hi_strict) + add_lt(v, hi); + else + add_le(v, hi); + } + } + if (is_mul(v)) { + auto const& [w, monomial] = get_mul(v); + num_t lo(1), hi(1); + bool lo_valid = true, hi_valid = true; + bool lo_strict = false, hi_strict = false; + for (auto [w, p] : monomial) { + if (!lo_valid) + break; + auto const& wi = m_vars[w]; + if (wi.m_lo && !wi.m_lo->is_strict && wi.m_lo->value >= 0) + lo *= power_of(wi.m_lo->value, p); + else + lo_valid = false; + } + for (auto [w, p] : monomial) { + if (!lo_valid && !hi_valid) + break; + auto const& wi = m_vars[w]; + try { + if (wi.m_hi && !wi.m_hi->is_strict) + hi *= power_of(wi.m_hi->value, p); + else + hi_valid = false; + } + catch (overflow_exception&) { + hi_valid = false; + } + } + if (lo_valid) { + if (lo_strict) + add_gt(v, lo); + else + add_ge(v, lo); + } + if (lo_valid && hi_valid) { + if (hi_strict) + add_lt(v, hi); + else + add_le(v, hi); + } + } + expr* c, * th, * el; + if (m.is_ite(e, c, th, el)) { + auto vth = m_expr2var.get(th->get_id(), UINT_MAX); + auto vel = m_expr2var.get(el->get_id(), UINT_MAX); + if (vth == UINT_MAX || vel == UINT_MAX) + continue; + auto const& vith = m_vars[vth]; + auto const& viel = m_vars[vel]; + if (vith.m_lo && viel.m_lo && !vith.m_lo->is_strict && !viel.m_lo->is_strict) + add_ge(v, std::min(vith.m_lo->value, viel.m_lo->value)); + if (vith.m_hi && viel.m_hi && !vith.m_hi->is_strict && !viel.m_hi->is_strict) + add_le(v, std::max(vith.m_hi->value, viel.m_hi->value)); + + } + switch (vi.m_op) { + case LAST_ARITH_OP: + case OP_ADD: + case OP_MUL: + break; + case OP_MOD: { + auto v2 = m_ops[vi.m_def_idx].m_arg2; + auto const& vi2 = m_vars[v2]; + if (vi2.m_lo && vi2.m_hi && vi2.m_lo->value == vi2.m_hi->value && vi2.m_lo->value > 0) { + add_le(v, vi2.m_lo->value - 1); + add_ge(v, num_t(0)); + } + break; + } + case OP_DIV: + break; + case OP_IDIV: + break; + case OP_REM: + break; + case OP_ABS: + add_ge(v, num_t(0)); + break; + default: + NOT_IMPLEMENTED_YET(); + + } + // TBD: can also do with other operators. + } + } + + template + void arith_base::initialize_unit(sat::literal lit) { + init_bool_var(lit.var()); + auto* ineq = atom(lit.var()); + if (!ineq) + return; + + if (ineq->m_args.size() != 1) + return; + auto [c, v] = ineq->m_args[0]; + + switch (ineq->m_op) { + case ineq_kind::LE: + if (lit.sign()) { + if (c == -1) // -x + c >= 0 <=> c >= x + add_le(v, ineq->m_coeff); + else if (c == 1) // x + c >= 0 <=> x >= -c + add_ge(v, -ineq->m_coeff); + else + verbose_stream() << "INITIALIZE " << lit << " " << *ineq << "\n"; + } + else { + if (c == -1) + add_ge(v, ineq->m_coeff); + else if (c == 1) + add_le(v, -ineq->m_coeff); + else + verbose_stream() << "INITIALIZE " << lit << " " << *ineq << "\n"; + } + break; + case ineq_kind::EQ: + if (!lit.sign()) { + if (c == -1) { + add_ge(v, ineq->m_coeff); + add_le(v, ineq->m_coeff); + } + else if (c == 1) { + add_ge(v, -ineq->m_coeff); + add_le(v, -ineq->m_coeff); + } + else + verbose_stream() << "INITIALIZE " << lit << " " << *ineq << "\n"; + } + break; + case ineq_kind::LT: + + if (lit.sign()) { + if (c == -1) // -x + c >= 0 <=> c >= x + add_le(v, ineq->m_coeff); + else if (c == 1) // x + c >= 0 <=> x >= -c + add_ge(v, -ineq->m_coeff); + else + verbose_stream() << "INITIALIZE " << lit << " " << *ineq << "\n"; + } + else { + if (c == -1) + add_gt(v, ineq->m_coeff); + else if (c == 1) + add_lt(v, -ineq->m_coeff); + else + verbose_stream() << "INITIALIZE " << lit << " " << *ineq << "\n"; + } + break; + } + } + + template + void arith_base::add_le(var_t v, num_t const& n) { + if (m_vars[v].m_hi && m_vars[v].m_hi->value <= n) + return; + m_vars[v].m_hi = { false, n }; + } + + template + void arith_base::add_ge(var_t v, num_t const& n) { + if (m_vars[v].m_lo && m_vars[v].m_lo->value >= n) + return; + m_vars[v].m_lo = { false, n }; + } + + template + void arith_base::add_lt(var_t v, num_t const& n) { + if (is_int(v)) + add_le(v, n - 1); + else + m_vars[v].m_hi = { true, n }; + } + + template + void arith_base::add_gt(var_t v, num_t const& n) { + if (is_int(v)) + add_ge(v, n + 1); + else + m_vars[v].m_lo = { true, n }; + } + + template + bool arith_base::repair_add(add_def const& ad) { + auto v = ad.m_var; + auto old_value = value(v); + auto const& coeffs = ad.m_args; + num_t sum(ad.m_coeff); + + for (auto const& [c, w] : coeffs) + sum += c * value(w); + + if (old_value == sum) + return true; + + //display(verbose_stream() << "repair add v" << v << " ", ad) << " " << old_value << " sum " << sum << "\n"; + + m_updates.reset(); +// display(verbose_stream(), v) << " "; +// verbose_stream() << mk_bounded_pp(m_vars[v].m_expr, m) << " := " << old_value << " " << sum << "\n"; + + for (auto const& [coeff, w] : coeffs) { + auto delta = divide(w, sum - old_value, coeff); + if (sum == coeff*delta + old_value) + add_update(w, delta); + } + if (apply_update()) + return eval_is_correct(v); + + flet _use_tabu(m_use_tabu, false); + + m_updates.reset(); + for (auto const& [coeff, w] : coeffs) { + auto delta = divide(w, sum - old_value, coeff); + if (sum != coeff*delta + old_value) + add_update(w, delta); + } + for (auto const& [coeff, w] : coeffs) + add_reset_update(w); + + if (apply_update()) + return eval_is_correct(v); + + return update(v, sum); + } + + template + bool arith_base::repair_mul(mul_def const& md) { + auto const& [v, monomial] = md; + num_t product(1); + num_t val = value(v); + for (auto [v, p]: monomial) + product *= power_of(value(v), p); + if (product == val) + return true; + IF_VERBOSE(10, verbose_stream() << "v" << v << " repair mul " << mk_bounded_pp(m_vars[v].m_expr, m) << " : = " << val << " (product : " << product << ")\n"); + + + m_updates.reset(); + if (val == 0) { + for (auto [x, p] : monomial) + add_update(x, -value(x)); + } + else if (val == 1 || val == -1) { + for (auto [x, p] : monomial) { + add_update(x, num_t(1) - value(x)); + add_update(x, num_t(-1) - value(x)); + } + } + else { + for (auto [x, p] : monomial) { + auto mx = mul_value_without(v, x); + // val / mx = x^p + if (mx == 0) + continue; + auto valmx = divide(x, val, mx); + auto r = root_of(p, valmx); + add_update(x, r - value(x)); + if (p % 2 == 0) + add_update(x, -r - value(x)); + } + } + + // verbose_stream() << "repair product v" << v << "\n"; + + if (apply_update()) + return eval_is_correct(v); + + flet _use_tabu(m_use_tabu, false); + m_updates.reset(); + for (auto [x, p] : monomial) + add_reset_update(x); + + if (apply_update()) + return eval_is_correct(v); + + return update(v, product); + } + + template + bool arith_base::repair_rem(op_def const& od) { + auto v1 = value(od.m_arg1); + auto v2 = value(od.m_arg2); + if (v2 == 0) + return update(od.m_var, num_t(0)); + + + IF_VERBOSE(0, verbose_stream() << "todo repair rem"); + // bail + v1 %= v2; + return update(od.m_var, v1); + } + + template + bool arith_base::repair_abs(op_def const& od) { + auto val = value(od.m_var); + auto v1 = value(od.m_arg1); + if (val < 0) + return update(od.m_var, abs(v1)); + else if (ctx.rand(2) == 0) + return update(od.m_arg1, val); + else + return update(od.m_arg1, -val); + } + + template + bool arith_base::repair_to_int(op_def const& od) { + auto val = value(od.m_var); + auto v1 = value(od.m_arg1); + if (val - 1 < v1 && v1 <= val) + return true; + return update(od.m_arg1, val); + } + + template + bool arith_base::repair_to_real(op_def const& od) { + if (ctx.rand(20) == 0) + return update(od.m_var, value(od.m_arg1)); + else + return update(od.m_arg1, value(od.m_arg1)); + } + + template + bool arith_base::repair_power(op_def const& od) { + auto v1 = value(od.m_arg1); + auto v2 = value(od.m_arg2); + if (v1 == 0 && v2 == 0) { + return update(od.m_var, num_t(0)); + } + IF_VERBOSE(0, verbose_stream() << "todo repair ^"); + NOT_IMPLEMENTED_YET(); + return false; + } + + template + bool arith_base::repair_mod(op_def const& od) { + auto val = value(od.m_var); + auto v1 = value(od.m_arg1); + auto v2 = value(od.m_arg2); + // repair first argument + if (val >= 0 && val < v2) { + auto v3 = mod(v1, v2); + if (v3 == val) + return true; + // find r, such that mod(v1 + r, v2) = val + // v1 := v1 + val - v3 (+/- v2) + v1 += val - v3; + switch (ctx.rand(6)) { + case 0: + v1 += v2; + break; + case 1: + v1 -= v2; + break; + default: + break; + } + return update(od.m_arg1, v1); + } + return update(od.m_var, v2 == 0 ? num_t(0) : mod(v1, v2)); + } + + template + bool arith_base::repair_idiv(op_def const& od) { + auto v1 = value(od.m_arg1); + auto v2 = value(od.m_arg2); + IF_VERBOSE(0, verbose_stream() << "todo repair div"); + // bail + return update(od.m_var, v2 == 0 ? num_t(0) : div(v1, v2)); + } + + template + bool arith_base::repair_div(op_def const& od) { + auto v1 = value(od.m_arg1); + auto v2 = value(od.m_arg2); + IF_VERBOSE(0, verbose_stream() << "todo repair /"); + // bail + return update(od.m_var, v2 == 0 ? num_t(0) : v1 / v2); + } + + template + double arith_base::compute_score(var_t x, num_t const& delta) { + int result = 0; + int breaks = 0; + for (auto const& [coeff, bv] : m_vars[x].m_bool_vars) { + bool old_sign = sign(bv); + auto lit = sat::literal(bv, old_sign); + auto dtt_old = dtt(old_sign, *atom(bv)); + auto dtt_new = dtt(old_sign, *atom(bv), coeff, delta); +#if 1 + if (dtt_new == 0 && dtt_old != 0) + result += 1; + + if (dtt_new != 0 && dtt_old == 0) { + if (m_use_tabu && ctx.is_unit(lit)) + return 0; + result -= 1; + breaks += 1; + } +#else + if (dtt_new == dtt_old) + continue; + if (m_use_tabu && ctx.is_unit(lit) && dtt_new != 0) + return 0; + double reward = ctx.reward(bv); + result += reward; +#endif + } + + if (result < 0) + return 0.1; + else if (result == 0) + return 0.2; + for (int i = m_prob_break.size(); i <= breaks; ++i) + m_prob_break.push_back(std::pow(m_config.cb, -i)); + return m_prob_break[breaks]; + } + + template + num_t arith_base::mul_value_without(var_t m, var_t x) { + auto const& vi = m_vars[m]; + auto const& [w, monomial] = m_muls[vi.m_def_idx]; + SASSERT(m == w); + num_t r(1); + for (auto [y, p] : monomial) + if (x != y) + r *= power_of(value(y), p); + return r; + } + + template + bool arith_base::is_linear(var_t x, vector const& nl, num_t& b) { + if (nl.size() == 1 && nl[0].v == x) { + b = nl[0].coeff; + return true; + } + b = 0; + for (auto const& [v, c, p] : nl) { + if (p > 1) + return false; + if (x == v) + b += c; + else + b += c * mul_value_without(v, x); + } + return b != 0; + } + + template + bool arith_base::is_quadratic(var_t x, vector const& nl, num_t& a, num_t& b) { + a = 0; + b = 0; + for (auto const& [v, c, p] : nl) { + if (p == 1) { + if (x == v) + b += c; + else + b += c * mul_value_without(v, x); + } + else if (p == 2) { + SASSERT(v != x); + a += c * mul_value_without(v, x); + } + else + return false; + } + return a != 0 || b != 0; + } + + template + bool arith_base::find_nl_moves(sat::literal lit) { + m_updates.reset(); + auto* ineq = atom(lit.var()); + num_t a, b; + if (!ineq) + return false; + for (auto const& [x, nl] : ineq->m_nonlinear) { + if (is_fixed(x)) + continue; + if (is_linear(x, nl, b)) + find_linear_moves(*ineq, x, b, ineq->m_args_value); + else if (is_quadratic(x, nl, a, b)) + find_quadratic_moves(*ineq, x, a, b, ineq->m_args_value); + else + ; + } + return apply_update(); + } + + template + void arith_base::add_reset_update(var_t x) { + m_last_delta = 0; + if (is_fixed(x)) + return; + if (is_mul(x)) { + auto const& [w1, monomial] = get_mul(x); + for (auto [w1, p] : monomial) + add_reset_update(w1); + } + if (is_add(x)) { + auto const& ad = get_add(x); + for (auto [c, w] : ad.m_args) + add_reset_update(w); + } + auto const& vi = m_vars[x]; + auto const& lo = vi.m_lo; + auto const& hi = vi.m_hi; + auto new_value = num_t(-2 + (int)ctx.rand(5)); + if (lo && lo->value > new_value) + new_value = lo->value + num_t(ctx.rand(2)); + else if (hi && hi->value < new_value) + new_value = hi->value - num_t(ctx.rand(2)); + if (new_value != value(x)) + add_update(x, new_value - value(x) + num_t(-1 + (int)ctx.rand(3))); + else { + add_update(x, num_t(1) - value(x)); + add_update(x, -num_t(1) - value(x)); + if (value(x) != 0) { + add_update(x, num_t(1)); + add_update(x, -num_t(1)); + } + } + } + + template + bool arith_base::find_reset_moves(sat::literal lit) { + m_updates.reset(); + auto* ineq = atom(lit.var()); + num_t a, b; + if (!ineq) + return false; + for (auto const& [x, nl] : ineq->m_nonlinear) + add_reset_update(x); + + IF_VERBOSE(10, + if (m_updates.empty()) { + verbose_stream() << lit << ": " << * ineq << "\n"; + for (auto const& [x, nl] : ineq->m_nonlinear) { + display(verbose_stream(), x) << "\n"; + } + } + verbose_stream() << "RESET moves num updates: " << lit << " " << m_updates.size() << "\n"); + + return apply_update(); + } + + template + num_t arith_base::power_of(num_t x, unsigned k) { + num_t r(1); + while (k > 1) { + if (k % 2 == 1) { + r = x * r; + --k; + } + x = x * x; + k /= 2; + } + return x * r; + } + + // Newton function for integer n'th root of a + // x_{k+1} = 1/k ((k-1)*x_k + a / x_k^{n-1}) + template + num_t arith_base::root_of(unsigned k, num_t a) { + if (a <= 1) + return a; + if (k == 1) + return a; + if (a <= k) + return num_t(1); + SASSERT(k > 1); + + auto x0 = div(a, num_t(k)); + auto x1 = div((x0 * num_t(k - 1)) + div(a, power_of(x0, k - 1)), num_t(k)); + + while (x1 < x0) { + x0 = x1; + x1 = div((x0 * num_t(k - 1)) + div(a, power_of(x0, k - 1)), num_t(k)); + } + return x0; + } + + template + vector const& arith_base::factor(num_t n) { + m_factors.reset(); + if (n == 0) + return m_factors; + for (auto d : { 2, 3, 5 }) { + while (mod(n, num_t(d)) == 0) { + m_factors.push_back(num_t(d)); + n = div(n, num_t(d)); + } + } + static int increments[8] = { 4, 2, 4, 2, 4, 6, 2, 6 }; + unsigned i = 0, j = 0; + for (auto d = num_t(7); d * d <= n && j < 3; d += num_t(increments[i++]), ++j) { + while (mod(n, d) == 0) { + m_factors.push_back(d); + n = div(n, d); + } + if (i == 8) + i = 0; + } + if (n > 1) + m_factors.push_back(n); + return m_factors; + } + + // switch to dscore mode + template + void arith_base::on_rescale() { + m_dscore_mode = true; + } + + template + void arith_base::on_restart() { +#if 0 + for (var_t v = 0; v < m_vars.size(); ++v) { + auto& vi = m_vars[v]; + num_t new_value; + if (vi.m_def_idx == UINT_MAX) { + auto val = value(v); + + if (ctx.rand(10) != 0) { + new_value = num_t((int)ctx.rand(2)); + if (!in_bounds(v, new_value)) + new_value = val; + } + else + new_value = val; + //verbose_stream() << v << " " << vi.m_value << " -> " << new_value << "\n"; + vi.m_value = new_value; + } + else { + vi.m_value = value1(v); + } + ctx.new_value_eh(vi.m_expr); + } + + for (sat::bool_var v = 0; v < ctx.num_bool_vars(); ++v) { + auto* ineq = atom(v); + if (!ineq) + continue; + ineq->m_args_value = ineq->m_coeff; + for (auto const& [coeff, w] : ineq->m_args) + ineq->m_args_value += coeff * value(w); + init_bool_var(v); + } +#endif + } + + template + void arith_base::check_ineqs() { + for (unsigned bv = 0; bv < ctx.num_bool_vars(); ++bv) { + auto const* ineq = atom(bv); + if (!ineq) + continue; + num_t d = dtt(sign(bv), *ineq); + sat::literal lit(bv, sign(bv)); + if (ctx.is_true(lit) != (d == 0)) { + verbose_stream() << "invalid assignment " << bv << " " << *ineq << "\n"; + } + VERIFY(ctx.is_true(lit) == (d == 0)); + } + } + + template + void arith_base::register_term(expr* _e) { + if (!is_app(_e)) + return; + app* e = to_app(_e); + auto v = ctx.atom2bool_var(e); + if (v != sat::null_bool_var) + init_bool_var(v); + if (!a.is_arith_expr(e) && !m.is_eq(e) && !m.is_distinct(e)) + for (auto arg : *e) + if (a.is_int_real(arg)) + mk_term(arg); + } + + template + bool arith_base::set_value(expr* e, expr* v) { + if (!a.is_int_real(e)) + return false; + var_t w = m_expr2var.get(e->get_id(), UINT_MAX); + if (w == UINT_MAX) + w = mk_term(e); + + num_t n; + if (!is_num(v, n)) + return false; + // verbose_stream() << "set value " << w << " " << mk_bounded_pp(e, m) << " " << n << " " << value(w) << "\n"; + if (n == value(w)) + return true; + return update(w, n); + } + + template + expr_ref arith_base::get_value(expr* e) { + num_t n; + if (is_num(e, n)) + return expr_ref(a.mk_numeral(n.to_rational(), a.is_int(e)), m); + auto v = mk_term(e); + return expr_ref(a.mk_numeral(m_vars[v].m_value.to_rational(), a.is_int(e)), m); + } + + template + bool arith_base::is_sat() { + invariant(); + for (auto const& clause : ctx.clauses()) { + bool sat = false; + for (auto lit : clause.m_clause) { + if (!ctx.is_true(lit)) + continue; + auto ineq = atom(lit.var()); + if (!ineq) { + sat = true; + break; + } + if (ineq->is_true() != lit.sign()) { + sat = true; + break; + } + } + if (sat) + continue; + verbose_stream() << "not sat:\n"; + verbose_stream() << clause << "\n"; + for (auto lit : clause.m_clause) { + verbose_stream() << lit << " (" << ctx.is_true(lit) << ") "; + auto ineq = atom(lit.var()); + if (!ineq) + continue; + verbose_stream() << *ineq << "\n"; + for (auto const& [coeff, v] : ineq->m_args) + verbose_stream() << coeff << " " << v << " " << mk_bounded_pp(m_vars[v].m_expr, m) << " := " << value(v) << "\n"; + } + exit(0); + if (!sat) + return false; + } + return true; + } + + template + std::ostream& arith_base::display(std::ostream& out, mul_def const& md) const { + auto const& [w, monomial] = md; + bool first = true; + for (auto [v, p] : monomial) { + if (!first) + out << " * "; + out << "v" << v; + if (p > 1) + out << "^" << p; + first = false; + } + return out; + } + + template + std::ostream& arith_base::display(std::ostream& out, add_def const& ad) const { + bool first = true; + for (auto [c, w] : ad.m_args) { + if (first && c == 1) + ; + else if (first && c == -1) + out << "-"; + else if (first) + out << c << "*"; + else if (c == 1) + out << " + "; + else if (c == - 1) + out << " - "; + else if (c > 0) + out << " + " << c << "*"; + else + out << " - " << -c << "*"; + first = false; + out << "v" << w; + } + if (ad.m_args.empty()) + out << ad.m_coeff; + else if (ad.m_coeff > 0) + out << " + " << ad.m_coeff; + else if (ad.m_coeff < 0) + out << " - " << -ad.m_coeff; + return out; + } + + template + std::ostream& arith_base::display(std::ostream& out, var_t v) const { + auto const& vi = m_vars[v]; + auto const& lo = vi.m_lo; + auto const& hi = vi.m_hi; + out << "v" << v << " := " << vi.m_value << " "; + if (lo || hi) { + if (lo) + out << (lo->is_strict ? "(": "[") << lo->value; + else + out << "("; + out << " "; + if (hi) + out << hi->value << (hi->is_strict ? ")" : "]"); + else + out << ")"; + out << " "; + } + out << mk_bounded_pp(vi.m_expr, m) << " "; + if (is_add(v)) + display(out << "add: ", get_add(v)) << " "; + if (is_mul(v)) + display(out << "mul: ", get_mul(v)) << " "; + + if (!vi.m_adds.empty()) { + out << " adds: "; + for (auto v : vi.m_adds) + out << "v" << m_adds[v].m_var << " "; + out << " "; + } + + if (!vi.m_muls.empty()) { + out << " muls: "; + for (auto v : vi.m_muls) + out << "v" << m_muls[v].m_var << " "; + out << " "; + } + + if (!vi.m_bool_vars.empty()) { + out << " bool: "; + for (auto [c, bv] : vi.m_bool_vars) + out << c << "@" << bv << " "; + } + return out; + } + + template + std::ostream& arith_base::display(std::ostream& out) const { + for (unsigned v = 0; v < ctx.num_bool_vars(); ++v) { + auto ineq = atom(v); + if (ineq) + out << v << ": " << *ineq << "\n"; + } + for (unsigned v = 0; v < m_vars.size(); ++v) + display(out, v) << "\n"; + + for (auto md : m_muls) { + out << "v" << md.m_var << " := "; + for (auto [w, p] : md.m_monomial) { + out << "v" << w; + if (p > 1) + out << "^" << p; + out << " "; + } + + out << "\n"; + } + + for (auto od : m_ops) { + out << "v" << od.m_var << " := "; + out << "v" << od.m_arg1 << " op-" << od.m_op << " v" << od.m_arg2 << "\n"; + } + return out; + } + + template + bool arith_base::eval_is_correct(var_t v) { + auto const& vi = m_vars[v]; + if (vi.m_def_idx == UINT_MAX) + return true; + TRACE("sls", tout << "repair def " << mk_bounded_pp(vi.m_expr, m) << "\n"); + switch (vi.m_op) { + case arith_op_kind::LAST_ARITH_OP: + break; + case arith_op_kind::OP_ADD: { + auto ad = m_adds[vi.m_def_idx]; + num_t sum(ad.m_coeff); + for (auto [c, w] : ad.m_args) + sum += c * value(w); + return sum == value(v); + } + case arith_op_kind::OP_MUL: { + auto md = m_muls[vi.m_def_idx]; + num_t prod(1); + for (auto [w, p] : md.m_monomial) + prod *= power_of(value(w), p); + return prod == value(v); + } + case arith_op_kind::OP_MOD: { + auto od = m_ops[vi.m_def_idx]; + return value(v) == (value(od.m_arg2) == 0 ? num_t(0) : mod(value(od.m_arg1), value(od.m_arg2))); + } + case arith_op_kind::OP_REM: { + auto od = m_ops[vi.m_def_idx]; + return value(v) == (value(od.m_arg2) == 0 ? num_t(0) : mod(value(od.m_arg1), value(od.m_arg2))); + } + case arith_op_kind::OP_POWER: { + auto od = m_ops[vi.m_def_idx]; + NOT_IMPLEMENTED_YET(); + break; + } + case arith_op_kind::OP_IDIV: { + auto od = m_ops[vi.m_def_idx]; + return value(v) == (value(od.m_arg2) == 0 ? num_t(0) : div(value(od.m_arg1), value(od.m_arg2))); + } + case arith_op_kind::OP_DIV: { + auto od = m_ops[vi.m_def_idx]; + return value(v) == (value(od.m_arg2) == 0 ? num_t(0) : value(od.m_arg1) / value(od.m_arg2)); + } + case arith_op_kind::OP_ABS: { + auto od = m_ops[vi.m_def_idx]; + return value(v) == abs(value(od.m_arg1)); + } + case arith_op_kind::OP_TO_INT: { + // auto od = m_ops[vi.m_def_idx]; + NOT_IMPLEMENTED_YET(); + break; + } + case arith_op_kind::OP_TO_REAL: { + // auto od = m_ops[vi.m_def_idx]; + NOT_IMPLEMENTED_YET(); + break; + } + default: { + NOT_IMPLEMENTED_YET(); + break; + } + } + return true; + } + + template + void arith_base::invariant() { + for (unsigned v = 0; v < ctx.num_bool_vars(); ++v) { + auto ineq = atom(v); + if (ineq) + invariant(*ineq); + } + auto& out = verbose_stream(); + for (var_t v = 0; v < m_vars.size(); ++v) { + if (!eval_is_correct(v)) { + + display(out); + display(out, v) << "\n"; + out << mk_bounded_pp(m_vars[v].m_expr, m) << "\n"; + + if (is_mul(v)) { + auto const& [w, monomial] = get_mul(v); + num_t prod(1); + for (auto [v, p] : monomial) + prod *= power_of(value(v), p); + out << "product " << prod << " value " << value(w) << "\n"; + out << "v" << w << " := "; + for (auto [w, p] : monomial) { + out << "(v" << w; + if (p > 1) + out << "^" << p; + out << " := " << value(w); + out << ") "; + } + out << "\n"; + } + else if (is_add(v)) { + auto const& ad = get_add(v); + out << "v" << ad.m_var << " := "; + display(out, ad) << "\n"; + } + + UNREACHABLE(); + } + } + } + + template + void arith_base::invariant(ineq const& i) { + num_t val = i.m_coeff; + for (auto const& [c, v] : i.m_args) + val += c * value(v); + //verbose_stream() << "invariant " << i << "\n"; + if (val != i.m_args_value) + verbose_stream() << i << "\n"; + SASSERT(val == i.m_args_value); + } + + + template + void arith_base::collect_statistics(statistics& st) const { + st.update("sls-arith-flips", m_stats.m_num_steps); + } + + template + void arith_base::reset_statistics() { + m_stats.m_num_steps = 0; + } +} + +template class sls::arith_base>; +template class sls::arith_base; diff --git a/src/ast/sls/sls_arith_base.h b/src/ast/sls/sls_arith_base.h new file mode 100644 index 00000000000..fe987666035 --- /dev/null +++ b/src/ast/sls/sls_arith_base.h @@ -0,0 +1,292 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + sls_arith_base.h + +Abstract: + + Theory plugin for arithmetic local search + +Author: + + Nikolaj Bjorner (nbjorner) 2020-09-08 + +--*/ +#pragma once + +#include "util/obj_pair_set.h" +#include "util/checked_int64.h" +#include "util/optional.h" +#include "ast/ast_trail.h" +#include "ast/arith_decl_plugin.h" +#include "ast/sls/sls_context.h" + +namespace sls { + + using theory_var = int; + + // local search portion for arithmetic + template + class arith_base : public plugin { + enum class ineq_kind { EQ, LE, LT}; + enum class var_sort { INT, REAL }; + struct bound { bool is_strict = false; num_t value; }; + typedef unsigned var_t; + typedef unsigned atom_t; + + struct config { + double cb = 2.85; + unsigned L = 20; + unsigned t = 45; + unsigned max_no_improve = 500000; + double sp = 0.0003; + }; + + struct stats { + unsigned m_num_steps = 0; + }; + + public: + struct linear_term { + vector> m_args; + num_t m_coeff{ 0 }; + }; + struct nonlinear_coeff { + var_t v; // variable or multiplier containing x + num_t coeff; // coeff of v in inequality + unsigned p; // power + }; + + typedef svector> monomial_t; + + // encode args <= bound, args = bound, args < bound + struct ineq : public linear_term { + vector>> m_nonlinear; + vector m_monomials; + ineq_kind m_op = ineq_kind::LE; + num_t m_args_value; + bool m_is_linear = true; + + bool is_true() const; + std::ostream& display(std::ostream& out) const; + }; + private: + + class var_info { + num_t m_range{ 100000000 }; + num_t m_update_value{ 0 }; + unsigned m_update_timestamp = 0; + public: + var_info(expr* e, var_sort k): m_expr(e), m_sort(k) {} + expr* m_expr; + num_t m_value{ 0 }; + num_t m_best_value{ 0 }; + var_sort m_sort; + arith_op_kind m_op = arith_op_kind::LAST_ARITH_OP; + unsigned m_def_idx = UINT_MAX; + vector> m_bool_vars; + unsigned_vector m_muls; + unsigned_vector m_adds; + optional m_lo, m_hi; + + // retrieve temporary value during an update. + void set_update_value(num_t const& v, unsigned timestamp) { + m_update_value = v; + m_update_timestamp = timestamp; + } + num_t const& get_update_value(unsigned ts) const { + return ts == m_update_timestamp ? m_update_value : m_value; + } + + bool in_range(num_t const& n) const { + if (-m_range < n && n < m_range) + return true; + if (m_lo && !m_hi) + return n < m_lo->value + m_range; + if (!m_lo && m_hi) + return n > m_hi->value - m_range; + return false; + } + unsigned m_tabu_pos = 0, m_tabu_neg = 0; + unsigned m_last_pos = 0, m_last_neg = 0; + bool is_tabu(unsigned step, num_t const& delta) { + return (delta > 0 ? m_tabu_pos : m_tabu_neg) > step; + } + void set_step(unsigned step, unsigned tabu_step, num_t const& delta) { + if (delta > 0) + m_tabu_pos = tabu_step, m_last_pos = step; + else + m_tabu_neg = tabu_step, m_last_neg = step; + } + }; + + struct mul_def { + unsigned m_var; + monomial_t m_monomial; + }; + + struct add_def : public linear_term { + unsigned m_var; + }; + + struct op_def { + unsigned m_var = UINT_MAX; + arith_op_kind m_op = LAST_ARITH_OP; + unsigned m_arg1, m_arg2; + }; + + struct var_change { + unsigned m_var; + num_t m_delta; + double m_score; + }; + + stats m_stats; + config m_config; + scoped_ptr_vector m_bool_vars; + vector m_vars; + vector m_muls; + vector m_adds; + vector m_ops; + unsigned_vector m_expr2var; + svector m_probs; + bool m_dscore_mode = false; + vector m_updates; + var_t m_last_var = 0; + sat::literal m_last_literal = sat::null_literal; + num_t m_last_delta { 0 }; + bool m_use_tabu = true; + unsigned m_updates_max_size = 45; + arith_util a; + svector m_prob_break; + + void invariant(); + void invariant(ineq const& i); + + unsigned get_num_vars() const { return m_vars.size(); } + + bool eval_is_correct(var_t v); + bool repair_mul(mul_def const& md); + bool repair_add(add_def const& ad); + bool repair_mod(op_def const& od); + bool repair_idiv(op_def const& od); + bool repair_div(op_def const& od); + bool repair_rem(op_def const& od); + bool repair_power(op_def const& od); + bool repair_abs(op_def const& od); + bool repair_to_int(op_def const& od); + bool repair_to_real(op_def const& od); + bool repair(sat::literal lit); + bool in_bounds(var_t v, num_t const& value); + bool is_fixed(var_t v); + bool is_linear(var_t x, vector const& nlc, num_t& b); + bool is_quadratic(var_t x, vector const& nlc, num_t& a, num_t& b); + num_t mul_value_without(var_t m, var_t x); + + void add_update(var_t v, num_t delta); + bool is_permitted_update(var_t v, num_t const& delta, num_t& delta_out); + unsigned m_update_timestamp = 0; + svector m_update_trail; + bool check_update(var_t v, num_t new_value); + void apply_checked_update(); + + num_t value1(var_t v); + + vector m_factors; + vector const& factor(num_t n); + num_t root_of(unsigned n, num_t a); + num_t power_of(num_t a, unsigned k); + + struct monomial_elem { + num_t other_product; + var_t v; + unsigned p; // power + }; + + // double reward(sat::literal lit); + + bool sign(sat::bool_var v) const { return !ctx.is_true(sat::literal(v, false)); } + ineq* atom(sat::bool_var bv) const { return m_bool_vars.get(bv, nullptr); } + num_t dtt(bool sign, ineq const& ineq) const { return dtt(sign, ineq.m_args_value, ineq); } + num_t dtt(bool sign, num_t const& args_value, ineq const& ineq) const; + num_t dtt(bool sign, ineq const& ineq, var_t v, num_t const& new_value) const; + num_t dtt(bool sign, ineq const& ineq, num_t const& coeff, num_t const& delta) const; + num_t dts(unsigned cl, var_t v, num_t const& new_value) const; + num_t compute_dts(unsigned cl) const; + + bool is_mul(var_t v) const { return m_vars[v].m_op == arith_op_kind::OP_MUL; } + bool is_add(var_t v) const { return m_vars[v].m_op == arith_op_kind::OP_ADD; } + mul_def const& get_mul(var_t v) const { SASSERT(is_mul(v)); return m_muls[m_vars[v].m_def_idx]; } + add_def const& get_add(var_t v) const { SASSERT(is_add(v)); return m_adds[m_vars[v].m_def_idx]; } + + bool update(var_t v, num_t const& new_value); + bool apply_update(); + bool find_nl_moves(sat::literal lit); + bool find_lin_moves(sat::literal lit); + bool find_reset_moves(sat::literal lit); + void add_reset_update(var_t v); + void find_linear_moves(ineq const& i, var_t x, num_t const& coeff, num_t const& sum); + void find_quadratic_moves(ineq const& i, var_t x, num_t const& a, num_t const& b, num_t const& sum); + double compute_score(var_t x, num_t const& delta); + void save_best_values(); + + var_t mk_var(expr* e); + var_t mk_term(expr* e); + var_t mk_op(arith_op_kind k, expr* e, expr* x, expr* y); + void add_arg(linear_term& term, num_t const& c, var_t v); + void add_args(linear_term& term, expr* e, num_t const& sign); + ineq& new_ineq(ineq_kind op, num_t const& bound); + void init_ineq(sat::bool_var bv, ineq& i); + num_t divide(var_t v, num_t const& delta, num_t const& coeff); + num_t divide_floor(var_t v, num_t const& a, num_t const& b); + num_t divide_ceil(var_t v, num_t const& a, num_t const& b); + + void init_bool_var_assignment(sat::bool_var v); + + bool is_int(var_t v) const { return m_vars[v].m_sort == var_sort::INT; } + + num_t value(var_t v) const { return m_vars[v].m_value; } + num_t const& get_update_value(var_t v) const { return m_vars[v].get_update_value(m_update_timestamp); } + bool is_num(expr* e, num_t& i); + expr_ref from_num(sort* s, num_t const& n); + void check_ineqs(); + void init_bool_var(sat::bool_var bv); + void initialize_unit(sat::literal lit); + void add_le(var_t v, num_t const& n); + void add_ge(var_t v, num_t const& n); + void add_lt(var_t v, num_t const& n); + void add_gt(var_t v, num_t const& n); + std::ostream& display(std::ostream& out, var_t v) const; + std::ostream& display(std::ostream& out, add_def const& ad) const; + std::ostream& display(std::ostream& out, mul_def const& md) const; + public: + arith_base(context& ctx); + ~arith_base() override {} + void register_term(expr* e) override; + bool set_value(expr* e, expr* v) override; + expr_ref get_value(expr* e) override; + void initialize() override; + void propagate_literal(sat::literal lit) override; + bool propagate() override; + void repair_up(app* e) override; + bool repair_down(app* e) override; + void repair_literal(sat::literal lit) override; + bool is_sat() override; + void on_rescale() override; + void on_restart() override; + std::ostream& display(std::ostream& out) const override; + void collect_statistics(statistics& st) const override; + void reset_statistics() override; + }; + + + inline std::ostream& operator<<(std::ostream& out, typename arith_base>::ineq const& ineq) { + return ineq.display(out); + } + + inline std::ostream& operator<<(std::ostream& out, typename arith_base::ineq const& ineq) { + return ineq.display(out); + } +} diff --git a/src/ast/sls/sls_arith_plugin.cpp b/src/ast/sls/sls_arith_plugin.cpp new file mode 100644 index 00000000000..310d4009f41 --- /dev/null +++ b/src/ast/sls/sls_arith_plugin.cpp @@ -0,0 +1,131 @@ + +/*++ +Copyright (c) 2023 Microsoft Corporation + +Module Name: + + sls_arith_plugin.cpp + +Abstract: + + Local search dispatch for NIA + +Author: + + Nikolaj Bjorner (nbjorner) 2023-02-07 + +--*/ + +#include "ast/sls/sls_arith_plugin.h" +#include "ast/ast_ll_pp.h" + +namespace sls { + +#define WITH_FALLBACK(_fn_) \ + if (m_arith64) { \ + try {\ + return m_arith64->_fn_;\ + }\ + catch (overflow_exception&) {\ + throw;\ + init_backup();\ + }\ + }\ + return m_arith->_fn_; \ + +#define APPLY_BOTH(_fn_) \ + if (m_arith64) { \ + try {\ + m_arith64->_fn_;\ + }\ + catch (overflow_exception&) {\ + throw;\ + init_backup();\ + }\ + }\ + m_arith->_fn_; \ + + arith_plugin::arith_plugin(context& ctx) : + plugin(ctx), m_shared(ctx.get_manager()) { + m_arith64 = alloc(arith_base>, ctx); + m_arith = alloc(arith_base, ctx); + m_arith64 = nullptr; + if (m_arith) + m_fid = m_arith->fid(); + else + m_fid = m_arith64->fid(); + } + + void arith_plugin::init_backup() { + m_arith64 = nullptr; + } + + void arith_plugin::register_term(expr* e) { + APPLY_BOTH(register_term(e)); + } + + expr_ref arith_plugin::get_value(expr* e) { + WITH_FALLBACK(get_value(e)); + } + + void arith_plugin::initialize() { + APPLY_BOTH(initialize()); + } + + void arith_plugin::propagate_literal(sat::literal lit) { + WITH_FALLBACK(propagate_literal(lit)); + } + + bool arith_plugin::propagate() { + WITH_FALLBACK(propagate()); + } + + bool arith_plugin::is_sat() { + WITH_FALLBACK(is_sat()); + } + + void arith_plugin::on_rescale() { + APPLY_BOTH(on_rescale()); + } + + void arith_plugin::on_restart() { + WITH_FALLBACK(on_restart()); + } + + std::ostream& arith_plugin::display(std::ostream& out) const { + if (m_arith64) + return m_arith64->display(out); + else + return m_arith->display(out); + } + + bool arith_plugin::repair_down(app* e) { + WITH_FALLBACK(repair_down(e)); + } + + void arith_plugin::repair_up(app* e) { + WITH_FALLBACK(repair_up(e)); + } + + void arith_plugin::repair_literal(sat::literal lit) { + WITH_FALLBACK(repair_literal(lit)); + } + + bool arith_plugin::set_value(expr* e, expr* v) { + WITH_FALLBACK(set_value(e, v)); + } + + void arith_plugin::collect_statistics(statistics& st) const { + if (m_arith64) + m_arith64->collect_statistics(st); + else + m_arith->collect_statistics(st); + } + + void arith_plugin::reset_statistics() { + if (m_arith) + m_arith->reset_statistics(); + if (m_arith64) + m_arith64->reset_statistics(); + } +} diff --git a/src/ast/sls/sls_arith_plugin.h b/src/ast/sls/sls_arith_plugin.h new file mode 100644 index 00000000000..7d84915798f --- /dev/null +++ b/src/ast/sls/sls_arith_plugin.h @@ -0,0 +1,52 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + sls_arith_plugin.h + +Abstract: + + Theory plugin for arithmetic local search + +Author: + + Nikolaj Bjorner (nbjorner) 2024-07-05 + +--*/ +#pragma once + +#include "ast/sls/sls_context.h" +#include "ast/sls/sls_arith_base.h" + +namespace sls { + + class arith_plugin : public plugin { + scoped_ptr>> m_arith64; + scoped_ptr> m_arith; + expr_ref_vector m_shared; + + void init_backup(); + public: + arith_plugin(context& ctx); + ~arith_plugin() override {} + void register_term(expr* e) override; + expr_ref get_value(expr* e) override; + void initialize() override; + void propagate_literal(sat::literal lit) override; + bool propagate() override; + bool repair_down(app* e) override; + void repair_up(app* e) override; + void repair_literal(sat::literal lit) override; + bool is_sat() override; + + void on_rescale() override; + void on_restart() override; + std::ostream& display(std::ostream& out) const override; + bool set_value(expr* e, expr* v) override; + + void collect_statistics(statistics& st) const override; + void reset_statistics() override; + }; + +} diff --git a/src/ast/sls/sls_array_plugin.cpp b/src/ast/sls/sls_array_plugin.cpp new file mode 100644 index 00000000000..930e8f9d8d7 --- /dev/null +++ b/src/ast/sls/sls_array_plugin.cpp @@ -0,0 +1,277 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + sls_array_plugin.cpp + +Abstract: + + Theory plugin for arrays local search + +Author: + + Nikolaj Bjorner (nbjorner) 2024-07-06 + +--*/ + +#include "ast/sls/sls_array_plugin.h" +#include "ast/ast_ll_pp.h" +#include "ast/ast_pp.h" + + +namespace sls { + + array_plugin::array_plugin(context& ctx): + plugin(ctx), + a(m) + { + m_fid = a.get_family_id(); + } + + bool array_plugin::is_sat() { + if (!m_has_arrays) + return true; + m_g = alloc(euf::egraph, m); + m_kv = nullptr; + init_egraph(*m_g); + saturate_store(*m_g); + return true; + } + + // b ~ a[i -> v] + // ensure b[i] ~ v + // ensure b[j] ~ a[j] for j != i + + void array_plugin::saturate_store(euf::egraph& g) { + unsigned sz = 0; + while (sz < g.nodes().size()) { + sz = g.nodes().size(); + for (unsigned i = 0; i < sz; ++i) { + auto n = g.nodes()[i]; + if (!a.is_store(n->get_expr())) + continue; + + force_store_axiom1(g, n); + + for (auto p : euf::enode_parents(n->get_root())) + if (a.is_select(p->get_expr())) + force_store_axiom2_down(g, n, p); + + auto arr = n->get_arg(0); + for (auto p : euf::enode_parents(arr->get_root())) + if (a.is_select(p->get_expr())) + force_store_axiom2_up(g, n, p); + } + } + display(verbose_stream() << "saturated\n"); + } + + euf::enode* array_plugin::mk_select(euf::egraph& g, euf::enode* b, euf::enode* sel) { + auto arity = get_array_arity(b->get_sort()); + ptr_buffer args; + ptr_buffer eargs; + args.push_back(b->get_expr()); + eargs.push_back(b); + for (unsigned i = 1; i <= arity; ++i) { + auto idx = sel->get_arg(i); + eargs.push_back(idx); + args.push_back(idx->get_expr()); + } + expr_ref esel(a.mk_select(args), m); + auto n = g.find(esel); + return n ? n : g.mk(esel, 0, eargs.size(), eargs.data()); + } + + // ensure a[i->v][i] = v exists in the e-graph + void array_plugin::force_store_axiom1(euf::egraph& g, euf::enode* n) { + SASSERT(a.is_store(n->get_expr())); + auto val = n->get_arg(n->num_args() - 1); + auto nsel = mk_select(g, n, n); + if (are_distinct(nsel, val)) + add_store_axiom1(n->get_app()); + else { + g.merge(nsel, val, nullptr); + VERIFY(g.propagate()); + } + } + + // i /~ j, b ~ a[i->v], b[j] occurs -> a[j] = b[j] + void array_plugin::force_store_axiom2_down(euf::egraph& g, euf::enode* sto, euf::enode* sel) { + SASSERT(a.is_store(sto->get_expr())); + SASSERT(a.is_select(sel->get_expr())); + if (sel->get_arg(0)->get_root() != sto->get_root()) + return; + if (eq_args(sto, sel)) + return; + auto nsel = mk_select(g, sto->get_arg(0), sel); + if (are_distinct(nsel, sel)) + add_store_axiom2(sto->get_app(), sel->get_app()); + else { + g.merge(nsel, sel, nullptr); + VERIFY(g.propagate()); + } + } + + // a ~ b, i /~ j, b[j] occurs -> a[i -> v][j] = b[j] + void array_plugin::force_store_axiom2_up(euf::egraph& g, euf::enode* sto, euf::enode* sel) { + SASSERT(a.is_store(sto->get_expr())); + SASSERT(a.is_select(sel->get_expr())); + if (sel->get_arg(0)->get_root() != sto->get_arg(0)->get_root()) + return; + if (eq_args(sto, sel)) + return; + auto nsel = mk_select(g, sto, sel); + if (are_distinct(nsel, sel)) + add_store_axiom2(sto->get_app(), sel->get_app()); + else { + g.merge(nsel, sel, nullptr); + VERIFY(g.propagate()); + } + } + + bool array_plugin::are_distinct(euf::enode* a, euf::enode* b) { + a = a->get_root(); + b = b->get_root(); + return a->interpreted() && b->interpreted() && a != b; // TODO work with nested arrays? + } + + bool array_plugin::eq_args(euf::enode* sto, euf::enode* sel) { + SASSERT(a.is_store(sto->get_expr())); + SASSERT(a.is_select(sel->get_expr())); + unsigned arity = get_array_arity(sto->get_sort()); + for (unsigned i = 1; i < arity; ++i) { + if (sto->get_arg(i)->get_root() != sel->get_arg(i)->get_root()) + return false; + } + return true; + } + + void array_plugin::add_store_axiom1(app* sto) { + if (!m_add_conflicts) + return; + ptr_vector args; + args.push_back(sto); + for (unsigned i = 1; i < sto->get_num_args() - 1; ++i) + args.push_back(sto->get_arg(i)); + expr_ref sel(a.mk_select(args), m); + expr_ref eq(m.mk_eq(sel, to_app(sto)->get_arg(sto->get_num_args() - 1)), m); + verbose_stream() << "add store axiom 1 " << mk_bounded_pp(sto, m) << "\n"; + ctx.add_clause(eq); + } + + void array_plugin::add_store_axiom2(app* sto, app* sel) { + if (!m_add_conflicts) + return; + ptr_vector args1, args2; + args1.push_back(sto); + args2.push_back(sto->get_arg(0)); + for (unsigned i = 1; i < sel->get_num_args() - 1; ++i) { + args1.push_back(sel->get_arg(i)); + args2.push_back(sel->get_arg(i)); + } + expr_ref sel1(a.mk_select(args1), m); + expr_ref sel2(a.mk_select(args2), m); + expr_ref eq(m.mk_eq(sel1, sel2), m); + expr_ref_vector ors(m); + ors.push_back(eq); + for (unsigned i = 1; i < sel->get_num_args() - 1; ++i) + ors.push_back(m.mk_eq(sel->get_arg(i), sto->get_arg(i))); + verbose_stream() << "add store axiom 2 " << mk_bounded_pp(sto, m) << " " << mk_bounded_pp(sel, m) << "\n"; + ctx.add_clause(m.mk_or(ors)); + } + + void array_plugin::init_egraph(euf::egraph& g) { + ptr_vector args; + for (auto t : ctx.subterms()) { + args.reset(); + if (is_app(t)) + for (auto* arg : *to_app(t)) + args.push_back(g.find(arg)); + + euf::enode* n1, * n2; + n1 = g.find(t); + n1 = n1 ? n1 : g.mk(t, 0, args.size(), args.data()); + if (a.is_array(t)) + continue; + auto v = ctx.get_value(t); + verbose_stream() << "init " << mk_bounded_pp(t, m) << " := " << mk_bounded_pp(v, m) << "\n"; + n2 = g.find(v); + n2 = n2 ? n2: g.mk(v, 0, 0, nullptr); + g.merge(n1, n2, nullptr); + } + for (auto lit : ctx.root_literals()) { + if (!ctx.is_true(lit) || lit.sign()) + continue; + auto e = ctx.atom(lit.var()); + expr* x, * y; + if (e && m.is_eq(e, x, y)) + g.merge(g.find(x), g.find(y), nullptr); + } + + display(verbose_stream()); + + } + + void array_plugin::init_kv(euf::egraph& g, kv& kv) { + for (auto n : g.nodes()) { + if (!n->is_root() || !a.is_array(n->get_expr())) + continue; + kv.insert(n, select2value()); + for (auto p : euf::enode_parents(n)) { + if (!a.is_select(p->get_expr())) + continue; + if (p->get_arg(0)->get_root() != n->get_root()) + continue; + auto val = p->get_root(); + kv[n].insert(select_args(p), val); + } + } + display(verbose_stream()); + } + + expr_ref array_plugin::get_value(expr* e) { + SASSERT(a.is_array(e)); + if (!m_g) { + m_g = alloc(euf::egraph, m); + init_egraph(*m_g); + flet _strong(m_add_conflicts, false); + saturate_store(*m_g); + } + if (!m_kv) { + m_kv = alloc(kv); + init_kv(*m_g, *m_kv); + } + auto& kv = *m_kv; + auto n = m_g->find(e)->get_root(); + expr_ref r(n->get_expr(), m); + for (auto [k, v] : kv[n]) { + ptr_vector args; + args.push_back(r); + args.push_back(k.sel->get_arg(1)->get_expr()); + args.push_back(v->get_expr()); + r = a.mk_store(args); + } + return r; + } + + std::ostream& array_plugin::display(std::ostream& out) const { + if (m_g) + m_g->display(out); + if (m_kv) { + for (auto& [n, kvs] : *m_kv) { + out << m_g->pp(n) << " -> {"; + char const* sp = ""; + for (auto& [k, v] : kvs) { + out << sp; + for (unsigned i = 1; i < k.sel->num_args(); ++i) + out << m_g->pp(k.sel->get_arg(i)->get_root()) << " "; + out << "-> " << m_g->pp(v); + sp = " "; + } + out << "}\n"; + } + } + return out; + } +} diff --git a/src/ast/sls/sls_array_plugin.h b/src/ast/sls/sls_array_plugin.h new file mode 100644 index 00000000000..4f8f051f4ce --- /dev/null +++ b/src/ast/sls/sls_array_plugin.h @@ -0,0 +1,90 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + sls_array_plugin.h + +Abstract: + + Theory plugin for arrays local search + +Author: + + Nikolaj Bjorner (nbjorner) 2024-07-06 + +--*/ +#pragma once + +#include "ast/sls/sls_context.h" +#include "ast/array_decl_plugin.h" +#include "ast/euf/euf_egraph.h" + +namespace sls { + + class array_plugin : public plugin { + struct select_args { + euf::enode* sel = nullptr; + select_args(euf::enode* s) : sel(s) {} + select_args() {} + }; + struct select_args_hash { + unsigned operator()(select_args const& a) const { + unsigned h = 0; + for (unsigned i = 1; i < a.sel->num_args(); ++i) + h ^= a.sel->get_arg(i)->get_root()->hash(); + return h; + }; + }; + struct select_args_eq { + bool operator()(select_args const& a, select_args const& b) const { + SASSERT(a.sel->num_args() == b.sel->num_args()); + for (unsigned i = 1; i < a.sel->num_args(); ++i) + if (a.sel->get_arg(i)->get_root() != b.sel->get_arg(i)->get_root()) + return false; + return true; + } + }; + typedef map select2value; + typedef obj_map kv; + + array_util a; + scoped_ptr m_g; + scoped_ptr m_kv; + bool m_add_conflicts = true; + bool m_has_arrays = false; + + void init_egraph(euf::egraph& g); + void init_kv(euf::egraph& g, kv& kv); + void saturate_store(euf::egraph& g); + void force_store_axiom1(euf::egraph& g, euf::enode* n); + void force_store_axiom2_down(euf::egraph& g, euf::enode* sto, euf::enode* sel); + void force_store_axiom2_up(euf::egraph& g, euf::enode* sto, euf::enode* sel); + void add_store_axiom1(app* sto); + void add_store_axiom2(app* sto, app* sel); + bool are_distinct(euf::enode* a, euf::enode* b); + bool eq_args(euf::enode* sto, euf::enode* sel); + euf::enode* mk_select(euf::egraph& g, euf::enode* b, euf::enode* sel); + + public: + array_plugin(context& ctx); + ~array_plugin() override {} + void register_term(expr* e) override { if (a.is_array(e->get_sort())) m_has_arrays = true; } + expr_ref get_value(expr* e) override; + void initialize() override { m_g = nullptr; } + void propagate_literal(sat::literal lit) override { m_g = nullptr; } + bool propagate() override { return false; } + bool repair_down(app* e) override { return true; } + void repair_up(app* e) override {} + void repair_literal(sat::literal lit) override { m_g = nullptr; } + bool is_sat() override; + + void on_rescale() override {} + void on_restart() override {} + std::ostream& display(std::ostream& out) const override; + bool set_value(expr* e, expr* v) override { return false; } + void collect_statistics(statistics& st) const override {} + void reset_statistics() override {} + }; + +} diff --git a/src/ast/sls/sls_basic_plugin.cpp b/src/ast/sls/sls_basic_plugin.cpp new file mode 100644 index 00000000000..bb876c44f28 --- /dev/null +++ b/src/ast/sls/sls_basic_plugin.cpp @@ -0,0 +1,210 @@ +/*++ +Copyright (c) 2024 Microsoft Corporation + +Module Name: + + sls_basic_plugin.cpp + +Abstract: + + Local search dispatch for Boolean connectives + +Author: + + Nikolaj Bjorner (nbjorner) 2024-07-07 + +--*/ + +#include "ast/sls/sls_basic_plugin.h" +#include "ast/ast_ll_pp.h" +#include "ast/ast_pp.h" +#include "ast/ast_util.h" + +namespace sls { + + expr_ref basic_plugin::get_value(expr* e) { + return expr_ref(m.mk_bool_val(bval0(e)), m); + } + + bool basic_plugin::is_basic(expr* e) const { + if (!e || !is_app(e)) + return false; + if (m.is_ite(e) && !m.is_bool(e) && false) + return true; + if (m.is_xor(e) && to_app(e)->get_num_args() != 2) + return true; + if (m.is_distinct(e)) + return true; + return false; + } + + void basic_plugin::propagate_literal(sat::literal lit) { + } + + void basic_plugin::register_term(expr* e) { + expr* c, * th, * el; + if (m.is_ite(e, c, th, el) && !m.is_bool(e)) { + ctx.add_clause(m.mk_or(mk_not(m, c), m.mk_eq(e, th))); + ctx.add_clause(m.mk_or(c, m.mk_eq(e, el))); + } + } + + void basic_plugin::initialize() { + } + + bool basic_plugin::propagate() { + return false; + } + + bool basic_plugin::is_sat() { + return true; + } + + std::ostream& basic_plugin::display(std::ostream& out) const { + return out; + } + + bool basic_plugin::set_value(expr* e, expr* v) { + if (!m.is_bool(e)) + return false; + SASSERT(m.is_true(v) || m.is_false(v)); + return set_value(e, m.is_true(v)); + } + + expr_ref basic_plugin::eval_ite(app* e) { + expr* c, * th, * el; + VERIFY(m.is_ite(e, c, th, el)); + if (bval0(c)) + return ctx.get_value(th); + else + return ctx.get_value(el); + } + + expr_ref basic_plugin::eval_distinct(app* e) { + for (unsigned i = 0; i < e->get_num_args(); ++i) { + for (unsigned j = i + 1; j < e->get_num_args(); ++j) { + if (bval0(e->get_arg(i)) == bval0(e->get_arg(j))) + return expr_ref(m.mk_false(), m); + } + } + return expr_ref(m.mk_true(), m); + } + + expr_ref basic_plugin::eval_xor(app* e) { + bool b = false; + for (expr* arg : *e) + b ^= bval0(arg); + return expr_ref(m.mk_bool_val(b), m); + } + + bool basic_plugin::bval0(expr* e) const { + SASSERT(m.is_bool(e)); + return ctx.is_true(ctx.mk_literal(e)); + } + + bool basic_plugin::try_repair(app* e, unsigned i) { + switch (e->get_decl_kind()) { + case OP_XOR: + return try_repair_xor(e, i); + case OP_ITE: + return try_repair_ite(e, i); + case OP_DISTINCT: + return try_repair_distinct(e, i); + default: + return true; + } + } + + bool basic_plugin::try_repair_xor(app* e, unsigned i) { + auto child = e->get_arg(i); + bool bv = false; + for (unsigned j = 0; j < e->get_num_args(); ++j) + if (j != i) + bv ^= bval0(e->get_arg(j)); + bool ev = bval0(e); + return set_value(child, ev != bv); + } + + bool basic_plugin::try_repair_ite(app* e, unsigned i) { + if (m.is_bool(e)) + return true; + auto child = e->get_arg(i); + auto cond = e->get_arg(0); + bool c = bval0(cond); + + if (i == 0) { + auto eval = ctx.get_value(e); + auto eval1 = ctx.get_value(e->get_arg(1)); + auto eval2 = ctx.get_value(e->get_arg(2)); + if (eval == eval1 && eval == eval2) + return true; + if (eval == eval1) + return set_value(cond, true); + if (eval == eval2) + return set_value(cond, false); + return false; + } + if (c != (i == 1)) + return false; + if (m.is_value(child)) + return false; + bool r = ctx.set_value(child, ctx.get_value(e)); + verbose_stream() << "repair-ite-down " << mk_bounded_pp(e, m) << " @ " << mk_bounded_pp(child, m) << " := " << ctx.get_value(e) << " success " << r << "\n"; + return r; + } + + void basic_plugin::repair_up(app* e) { + expr* c, * th, * el; + expr_ref val(m); + if (!is_basic(e)) + return; + if (m.is_ite(e, c, th, el) && !m.is_bool(e)) + val = eval_ite(e); + else if (m.is_xor(e)) + val = eval_xor(e); + else if (m.is_distinct(e)) + val = eval_distinct(e); + else + return; + verbose_stream() << "repair-up " << mk_bounded_pp(e, m) << " " << val << "\n"; + if (!ctx.set_value(e, val)) + ctx.new_value_eh(e); + } + + void basic_plugin::repair_literal(sat::literal lit) { + } + + bool basic_plugin::repair_down(app* e) { + if (!is_basic(e)) + return true; + if (m.is_xor(e) && eval_xor(e) == ctx.get_value(e)) + return true; + if (m.is_ite(e) && eval_ite(e) == ctx.get_value(e)) + return true; + if (m.is_distinct(e) && eval_distinct(e) == ctx.get_value(e)) + return true; + verbose_stream() << "basic repair down " << mk_bounded_pp(e, m) << "\n"; + unsigned n = e->get_num_args(); + unsigned s = ctx.rand(n); + for (unsigned i = 0; i < n; ++i) { + auto j = (i + s) % n; + if (try_repair(e, j)) + return true; + } + return false; + } + + bool basic_plugin::try_repair_distinct(app* e, unsigned i) { + NOT_IMPLEMENTED_YET(); + return false; + } + + bool basic_plugin::set_value(expr* e, bool b) { + auto lit = ctx.mk_literal(e); + if (ctx.is_true(lit) != b) { + ctx.flip(lit.var()); + ctx.new_value_eh(e); + } + return true; + } +} diff --git a/src/ast/sls/sls_basic_plugin.h b/src/ast/sls/sls_basic_plugin.h new file mode 100644 index 00000000000..6c1936532ad --- /dev/null +++ b/src/ast/sls/sls_basic_plugin.h @@ -0,0 +1,58 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + sls_basic_plugin.h + +Author: + + Nikolaj Bjorner (nbjorner) 2024-07-05 + +--*/ +#pragma once + +#include "ast/sls/sls_context.h" + +namespace sls { + + class basic_plugin : public plugin { + expr_mark m_axiomatized; + + bool is_basic(expr* e) const; + bool bval0(expr* e) const; + bool try_repair(app* e, unsigned i); + bool try_repair_xor(app* e, unsigned i); + bool try_repair_ite(app* e, unsigned i); + bool try_repair_distinct(app* e, unsigned i); + bool set_value(expr* e, bool b); + + expr_ref eval_ite(app* e); + expr_ref eval_distinct(app* e); + expr_ref eval_xor(app* e); + + public: + basic_plugin(context& ctx) : + plugin(ctx) { + m_fid = basic_family_id; + } + ~basic_plugin() override {} + void register_term(expr* e) override; + expr_ref get_value(expr* e) override; + void initialize() override; + void propagate_literal(sat::literal lit) override; + bool propagate() override; + bool repair_down(app* e) override; + void repair_up(app* e) override; + void repair_literal(sat::literal lit) override; + bool is_sat() override; + + void on_rescale() override {} + void on_restart() override {} + std::ostream& display(std::ostream& out) const override; + bool set_value(expr* e, expr* v) override; + void collect_statistics(statistics& st) const override {} + void reset_statistics() override {} + }; + +} diff --git a/src/ast/sls/sls_engine.cpp b/src/ast/sls/sls_bv_engine.cpp similarity index 99% rename from src/ast/sls/sls_engine.cpp rename to src/ast/sls/sls_bv_engine.cpp index a2ab861c0ed..124a5ea778c 100644 --- a/src/ast/sls/sls_engine.cpp +++ b/src/ast/sls/sls_bv_engine.cpp @@ -26,7 +26,7 @@ Module Name: #include "util/luby.h" #include "params/sls_params.hpp" -#include "ast/sls/sls_engine.h" +#include "ast/sls/sls_bv_engine.h" sls_engine::sls_engine(ast_manager & m, params_ref const & p) : diff --git a/src/ast/sls/sls_engine.h b/src/ast/sls/sls_bv_engine.h similarity index 97% rename from src/ast/sls/sls_engine.h rename to src/ast/sls/sls_bv_engine.h index 3e67aa49ca0..f9d26ee702a 100644 --- a/src/ast/sls/sls_engine.h +++ b/src/ast/sls/sls_bv_engine.h @@ -23,8 +23,8 @@ Module Name: #include "ast/converters/model_converter.h" #include "ast/sls/sls_stats.h" -#include "ast/sls/sls_tracker.h" -#include "ast/sls/sls_evaluator.h" +#include "ast/sls/sls_bv_tracker.h" +#include "ast/sls/sls_bv_evaluator.h" class sls_engine { diff --git a/src/ast/sls/bv_sls_eval.cpp b/src/ast/sls/sls_bv_eval.cpp similarity index 66% rename from src/ast/sls/bv_sls_eval.cpp rename to src/ast/sls/sls_bv_eval.cpp index eecc42511ce..e6f8f8ab51b 100644 --- a/src/ast/sls/bv_sls_eval.cpp +++ b/src/ast/sls/sls_bv_eval.cpp @@ -3,7 +3,7 @@ Copyright (c) 2024 Microsoft Corporation Module Name: - bv_sls_eval.cpp + sls_bv_eval.cpp Author: @@ -13,75 +13,46 @@ Module Name: #include "ast/ast_pp.h" #include "ast/ast_ll_pp.h" -#include "ast/sls/bv_sls.h" +#include "ast/sls/sls_bv_eval.h" +#include "ast/sls/sls_bv_terms.h" +#include "ast/rewriter/th_rewriter.h" -namespace bv { +namespace sls { - sls_eval::sls_eval(ast_manager& m): - m(m), + bv_eval::bv_eval(sls::bv_terms& terms, sls::context& ctx): + m(ctx.get_manager()), + ctx(ctx), + terms(terms), bv(m), - m_fix(*this) + m_fix(*this, terms, ctx) {} - void sls_eval::init_eval(expr_ref_vector const& es, std::function const& eval) { - auto& terms = sort_assertions(es); - for (expr* e : terms) { - if (!is_app(e)) - continue; - app* a = to_app(e); - if (bv.is_bv(e)) - add_bit_vector(a); - if (a->get_family_id() == basic_family_id) - init_eval_basic(a); - else if (a->get_family_id() == bv.get_family_id()) - init_eval_bv(a); - else if (is_uninterp(e)) { - if (bv.is_bv(e)) { - auto& v = wval(e); - for (unsigned i = 0; i < v.bw; ++i) - m_tmp.set(i, eval(e, i)); - v.set_repair(random_bool(), m_tmp); - } - else if (m.is_bool(e)) - m_eval.setx(e->get_id(), eval(e, 0), false); - } - else { - TRACE("sls", tout << "Unhandled expression " << mk_pp(e, m) << "\n"); - } - } - terms.reset(); - } - /** - * Sort all sub-expressions by depth, smallest first. - */ - ptr_vector& sls_eval::sort_assertions(expr_ref_vector const& es) { - expr_fast_mark1 mark; - for (expr* e : es) { - if (!mark.is_marked(e)) { - mark.mark(e); - m_todo.push_back(e); - } + void bv_eval::register_term(expr* e) { + if (!is_app(e)) + return; + app* a = to_app(e); + add_bit_vector(a); + if (a->get_family_id() == bv.get_family_id()) + init_eval_bv(a); + else if (bv.is_bv(e)) { + auto& v = wval(e); + for (unsigned i = 0; i < v.bw; ++i) + m_tmp.set(i, false); + v.set_repair(random_bool(), m_tmp); } - for (unsigned i = 0; i < m_todo.size(); ++i) { - auto e = m_todo[i]; - if (!is_app(e)) - continue; - for (expr* arg : *to_app(e)) { - if (!mark.is_marked(arg)) { - mark.mark(arg); - m_todo.push_back(arg); - } - } + if (bv.is_bv(e)) { + auto& v = wval(e); + v.bits().copy_to(v.nw, v.eval); } - std::stable_sort(m_todo.begin(), m_todo.end(), [&](expr* a, expr* b) { return get_depth(a) < get_depth(b); }); - return m_todo; } - bool sls_eval::add_bit_vector(app* e) { + void bv_eval::add_bit_vector(app* e) { + if (!bv.is_bv(e)) + return; m_values.reserve(e->get_id() + 1); if (m_values.get(e->get_id())) - return false; + return; auto v = alloc_valuation(e); m_values.set(e->get_id(), v); expr* x, * y; @@ -91,17 +62,18 @@ namespace bv { else if (bv.is_bv_ashr(e, x, y) && bv.is_numeral(y, val) && val.is_unsigned() && val.get_unsigned() <= bv.get_bv_size(e)) v->set_signed(val.get_unsigned()); - return true; + return; } - sls_valuation* sls_eval::alloc_valuation(app* e) { + sls::bv_valuation* bv_eval::alloc_valuation(app* e) { auto bit_width = bv.get_bv_size(e); - auto* r = alloc(sls_valuation, bit_width); + auto* r = alloc(sls::bv_valuation, bit_width); while (m_tmp.size() < 2 * r->nw) { m_tmp.push_back(0); m_tmp2.push_back(0); m_tmp3.push_back(0); m_tmp4.push_back(0); + m_mul_tmp.push_back(0); m_zero.push_back(0); m_one.push_back(0); m_a.push_back(0); @@ -115,90 +87,18 @@ namespace bv { return r; } - void sls_eval::init_eval_basic(app* e) { - auto id = e->get_id(); - if (m.is_bool(e)) - m_eval.setx(id, bval1(e), false); - else if (m.is_ite(e)) { - SASSERT(bv.is_bv(e->get_arg(1))); - auto& val = wval(e); - auto& val_th = wval(e->get_arg(1)); - auto& val_el = wval(e->get_arg(2)); - if (bval0(e->get_arg(0))) - val.set(val_th.bits()); - else - val.set(val_el.bits()); - } - else { - UNREACHABLE(); - } - } - void sls_eval::init_eval_bv(app* e) { + void bv_eval::init_eval_bv(app* e) { if (bv.is_bv(e)) - eval(e).commit_eval(); - else if (m.is_bool(e)) - m_eval.setx(e->get_id(), bval1_bv(e), false); - } - - bool sls_eval::bval1_basic(app* e) const { - SASSERT(m.is_bool(e)); - SASSERT(e->get_family_id() == basic_family_id); - - switch (e->get_decl_kind()) { - case OP_TRUE: - return true; - case OP_FALSE: - return false; - case OP_AND: - return all_of(*to_app(e), [&](expr* arg) { return bval0(arg); }); - case OP_OR: - return any_of(*to_app(e), [&](expr* arg) { return bval0(arg); }); - case OP_NOT: - return !bval0(e->get_arg(0)); - case OP_XOR: { - bool r = false; - for (auto* arg : *to_app(e)) - r ^= bval0(arg); - return r; - } - case OP_IMPLIES: { - auto a = e->get_arg(0); - auto b = e->get_arg(1); - return !bval0(a) || bval0(b); - } - case OP_ITE: { - auto c = bval0(e->get_arg(0)); - return bval0(c ? e->get_arg(1) : e->get_arg(2)); - } - case OP_EQ: { - auto a = e->get_arg(0); - auto b = e->get_arg(1); - if (m.is_bool(a)) - return bval0(a) == bval0(b); - else if (bv.is_bv(a)) { - auto const& va = wval(a); - auto const& vb = wval(b); - return va.eq(vb); - } - return m.are_equal(a, b); - } - case OP_DISTINCT: - default: - verbose_stream() << mk_bounded_pp(e, m) << "\n"; - UNREACHABLE(); - break; - } - UNREACHABLE(); - return false; + eval(e).commit_eval(); } - bool sls_eval::can_eval1(app* e) const { - expr* x, * y, * z; + bool bv_eval::can_eval1(app* e) const { + expr* x, * y; if (m.is_eq(e, x, y)) - return m.is_bool(x) || bv.is_bv(x); - if (m.is_ite(e, x, y, z)) - return m.is_bool(y) || bv.is_bv(y); + return bv.is_bv(x); + if (m.is_ite(e)) + return bv.is_bv(e->get_arg(0)); if (e->get_family_id() == bv.get_fid()) { switch (e->get_decl_kind()) { case OP_BNEG_OVFL: @@ -212,29 +112,29 @@ namespace bv { return true; } } - if (e->get_family_id() == basic_family_id) - return true; if (is_uninterp_const(e)) - return m.is_bool(e) || bv.is_bv(e); + return bv.is_bv(e); return false; } - bool sls_eval::bval1_bv(app* e) const { + bool bv_eval::bval1_bv(app* e, bool use_current) const { SASSERT(m.is_bool(e)); SASSERT(e->get_family_id() == bv.get_fid()); + bool use_current1 = use_current || (e->get_num_args() > 0 && !m_on_restore.is_marked(e->get_arg(0))); + bool use_current2 = use_current || (e->get_num_args() > 1 && !m_on_restore.is_marked(e->get_arg(1))); auto ucompare = [&](std::function const& f) { auto& a = wval(e->get_arg(0)); auto& b = wval(e->get_arg(1)); - return f(mpn.compare(a.bits().data(), a.nw, b.bits().data(), b.nw)); + return f(mpn.compare(a.tmp_bits(use_current1).data(), a.nw, b.tmp_bits(use_current2).data(), b.nw)); }; // x x + 2^{bw-1} const& f) { auto& a = wval(e->get_arg(0)); auto& b = wval(e->get_arg(1)); - add_p2_1(a, m_tmp); - add_p2_1(b, m_tmp2); + add_p2_1(a, use_current1, m_tmp); + add_p2_1(b, use_current2, m_tmp2); return f(mpn.compare(m_tmp.data(), a.nw, m_tmp2.data(), b.nw)); }; @@ -242,7 +142,7 @@ namespace bv { SASSERT(e->get_num_args() == 2); auto const& a = wval(e->get_arg(0)); auto const& b = wval(e->get_arg(1)); - return a.set_mul(m_tmp2, a.bits(), b.bits()); + return a.set_mul(m_tmp2, a.tmp_bits(use_current1), b.tmp_bits(use_current2)); }; switch (e->get_decl_kind()) { @@ -267,7 +167,7 @@ namespace bv { unsigned idx; VERIFY(bv.is_bit2bool(e, child, idx)); auto& a = wval(child); - return a.get_bit(idx); + return a.tmp_bits(use_current1).get(idx); } case OP_BUMUL_NO_OVFL: return !umul_overflow(); @@ -277,7 +177,7 @@ namespace bv { SASSERT(e->get_num_args() == 2); auto const& a = wval(e->get_arg(0)); auto const& b = wval(e->get_arg(1)); - return a.set_add(m_tmp, a.bits(), b.bits()); + return a.set_add(m_tmp, a.tmp_bits(use_current1), b.tmp_bits(use_current1)); } case OP_BNEG_OVFL: case OP_BSADD_OVFL: @@ -294,22 +194,45 @@ namespace bv { return false; } - bool sls_eval::bval1(app* e) const { - if (e->get_family_id() == basic_family_id) - return bval1_basic(e); + bool bv_eval::bval1(app* e) const { if (e->get_family_id() == bv.get_fid()) - return bval1_bv(e); - SASSERT(is_uninterp_const(e)); - return bval0(e); + return bval1_bv(e, true); + expr* x, * y; + if (m.is_eq(e, x, y) && bv.is_bv(x)) { + return wval(x).bits() == wval(y).bits(); + } + verbose_stream() << mk_bounded_pp(e, m) << "\n"; + UNREACHABLE(); + return false; } - sls_valuation& sls_eval::eval(app* e) const { + bool bv_eval::bval1_tmp(app* e) const { + if (e->get_family_id() == bv.get_fid()) + return bval1_bv(e, false); + expr* x, * y; + if (m.is_eq(e, x, y) && bv.is_bv(x)) { + bool use_current1 = !m_on_restore.is_marked(x); + bool use_current2 = !m_on_restore.is_marked(y); + return wval(x).tmp_bits(use_current1) == wval(y).tmp_bits(use_current2); + } + verbose_stream() << mk_bounded_pp(e, m) << "\n"; + UNREACHABLE(); + return false; + } + + // unsigned ddt_orig(expr* e); + + sls::bv_valuation& bv_eval::eval(app* e) const { auto& val = *m_values[e->get_id()]; eval(e, val); return val; } - void sls_eval::eval(app* e, sls_valuation& val) const { + void bv_eval::set(expr* e, sls::bv_valuation const& val) { + m_values[e->get_id()]->set(val.bits()); + } + + void bv_eval::eval(app* e, sls::bv_valuation& val) const { SASSERT(bv.is_bv(e)); if (m.is_ite(e)) { SASSERT(bv.is_bv(e->get_arg(1))); @@ -380,31 +303,46 @@ namespace bv { break; } case OP_BAND: { - SASSERT(e->get_num_args() == 2); + SASSERT(e->get_num_args() >= 2); auto const& a = wval(e->get_arg(0)); auto const& b = wval(e->get_arg(1)); for (unsigned i = 0; i < a.nw; ++i) val.eval[i] = a.bits()[i] & b.bits()[i]; + for (unsigned j = 2; j < e->get_num_args(); ++j) { + auto const& c = wval(e->get_arg(j)); + for (unsigned i = 0; i < a.nw; ++i) + val.eval[i] &= c.bits()[i]; + } break; } case OP_BOR: { - SASSERT(e->get_num_args() == 2); + SASSERT(e->get_num_args() >= 2); auto const& a = wval(e->get_arg(0)); auto const& b = wval(e->get_arg(1)); for (unsigned i = 0; i < a.nw; ++i) val.eval[i] = a.bits()[i] | b.bits()[i]; + for (unsigned j = 2; j < e->get_num_args(); ++j) { + auto const& c = wval(e->get_arg(j)); + for (unsigned i = 0; i < a.nw; ++i) + val.eval[i] |= c.bits()[i]; + } break; } case OP_BXOR: { - SASSERT(e->get_num_args() == 2); + SASSERT(e->get_num_args() >= 2); auto const& a = wval(e->get_arg(0)); auto const& b = wval(e->get_arg(1)); for (unsigned i = 0; i < a.nw; ++i) val.eval[i] = a.bits()[i] ^ b.bits()[i]; + for (unsigned j = 2; j < e->get_num_args(); ++j) { + auto const& c = wval(e->get_arg(j)); + for (unsigned i = 0; i < a.nw; ++i) + val.eval[i] ^= c.bits()[i]; + } break; } case OP_BNAND: { - SASSERT(e->get_num_args() == 2); + VERIFY(e->get_num_args() == 2); auto const& a = wval(e->get_arg(0)); auto const& b = wval(e->get_arg(1)); for (unsigned i = 0; i < a.nw; ++i) @@ -412,10 +350,15 @@ namespace bv { break; } case OP_BADD: { - SASSERT(e->get_num_args() == 2); + SASSERT(e->get_num_args() >= 2); auto const& a = wval(e->get_arg(0)); auto const& b = wval(e->get_arg(1)); - val.set_add(val.eval, a.bits(), b.bits()); + for (unsigned i = 0; i < a.nw; ++i) + val.set_add(val.eval, a.bits(), b.bits()); + for (unsigned j = 2; j < e->get_num_args(); ++j) { + auto const& c = wval(e->get_arg(j)); + val.set_add(val.eval, val.eval, c.bits()); + } break; } case OP_BSUB: { @@ -426,21 +369,24 @@ namespace bv { break; } case OP_BMUL: { - SASSERT(e->get_num_args() == 2); + SASSERT(e->get_num_args() > 1); auto const& a = wval(e->get_arg(0)); auto const& b = wval(e->get_arg(1)); - val.set_mul(m_tmp2, a.bits(), b.bits()); - val.set(m_tmp2); + val.set_mul(val.eval, a.bits(), b.bits(), false); + for (unsigned j = 2; j < e->get_num_args(); ++j) { + auto const& c = wval(e->get_arg(j)); + val.set_mul(val.eval, val.eval, c.bits(), false); + } break; } case OP_CONCAT: { - SASSERT(e->get_num_args() == 2); - auto const& a = wval(e->get_arg(0)); - auto const& b = wval(e->get_arg(1)); - for (unsigned i = 0; i < b.bw; ++i) - val.eval.set(i, b.get_bit(i)); - for (unsigned i = 0; i < a.bw; ++i) - val.eval.set(i + b.bw, a.get_bit(i)); + unsigned bw = 0; + for (unsigned i = e->get_num_args(); i-- > 0;) { + auto const& a = wval(e->get_arg(i)); + for (unsigned j = 0; j < a.bw; ++j) + val.eval.set(j + bw, a.get_bit(j)); + bw += a.bw; + } break; } case OP_EXTRACT: { @@ -659,11 +605,19 @@ namespace bv { val.set(val.eval, 0); break; } + case OP_INT2BV: { + expr_ref v = ctx.get_value(e->get_arg(0)); + th_rewriter rw(m); + v = bv.mk_int2bv(bv.get_bv_size(e), v); + rw(v); + rational r; + VERIFY(bv.is_numeral(v, r)); + val.set_value(val.eval, r); + break; + } case OP_BREDAND: case OP_BREDOR: case OP_BXNOR: - case OP_INT2BV: - verbose_stream() << mk_bounded_pp(e, m) << "\n"; NOT_IMPLEMENTED_YET(); break; @@ -695,64 +649,99 @@ namespace bv { val.clear_overflow_bits(val.eval); } - digit_t sls_eval::random_bits() { - return sls_valuation::random_bits(m_rand); + digit_t bv_eval::random_bits() { + return sls::bv_valuation::random_bits(m_rand); } - bool sls_eval::try_repair(app* e, unsigned i) { - if (is_fixed0(e->get_arg(i))) - return false; - else if (e->get_family_id() == basic_family_id) - return try_repair_basic(e, i); - if (e->get_family_id() == bv.get_family_id()) - return try_repair_bv(e, i); - return false; - } - bool sls_eval::try_repair_basic(app* e, unsigned i) { - switch (e->get_decl_kind()) { - case OP_AND: - return try_repair_and_or(e, i); - case OP_OR: - return try_repair_and_or(e, i); - case OP_NOT: - return try_repair_not(e); - case OP_FALSE: - return false; - case OP_TRUE: + bool bv_eval::is_uninterpreted(app* e) const { + if (is_uninterp(e)) + return true; + if (e->get_family_id() != bv.get_family_id()) return false; - case OP_EQ: - return try_repair_eq(e, i); - case OP_IMPLIES: - return try_repair_implies(e, i); - case OP_XOR: - return try_repair_xor(e, i); - case OP_ITE: - return try_repair_ite(e, i); + switch (e->get_decl_kind()) { + case OP_BSREM: + case OP_BSREM_I: + case OP_BSREM0: + case OP_BSMOD: + case OP_BSMOD_I: + case OP_BSMOD0: + case OP_BSDIV: + case OP_BSDIV_I: + case OP_BSDIV0: + return true; default: - UNREACHABLE(); return false; } } + + bool bv_eval::repair_down(app* e, unsigned i) { + expr* arg = e->get_arg(i); + if (m.is_value(arg)) + return false; + if (e->get_family_id() == bv.get_family_id() && try_repair_bv(e, i)) { + commit_eval(e, to_app(arg)); + IF_VERBOSE(11, verbose_stream() << "repair " << mk_bounded_pp(e, m) << " : " << mk_bounded_pp(arg, m) << " := " << wval(arg) << "\n";); + ctx.new_value_eh(arg); + return true; + } + if (m.is_eq(e) && bv.is_bv(arg) && try_repair_eq(e, i)) { + commit_eval(e, to_app(arg)); + IF_VERBOSE(11, verbose_stream() << mk_bounded_pp(arg, m) << " := " << wval(arg) << "\n";); + ctx.new_value_eh(arg); + return true; + } + if (m.is_eq(e) && bv.is_bv(arg)) { + return try_repair_eq_lookahead(e); + } + return false; + } - bool sls_eval::try_repair_bv(app* e, unsigned i) { + bool bv_eval::try_repair_bv(app* e, unsigned i) { switch (e->get_decl_kind()) { case OP_BAND: - return try_repair_band(eval_value(e), wval(e, i), wval(e, 1 - i)); + SASSERT(e->get_num_args() >= 2); + if (e->get_num_args() == 2) + return try_repair_band(assign_value(e), wval(e, i), wval(e, 1 - i)); + else + return try_repair_band(e, i); case OP_BOR: - return try_repair_bor(eval_value(e), wval(e, i), wval(e, 1 - i)); + SASSERT(e->get_num_args() >= 2); + if (e->get_num_args() == 2) + return try_repair_bor(assign_value(e), wval(e, i), wval(e, 1 - i)); + else + return try_repair_bor(e, i); case OP_BXOR: - return try_repair_bxor(eval_value(e), wval(e, i), wval(e, 1 - i)); + SASSERT(e->get_num_args() >= 2); + if (e->get_num_args() == 2) + return try_repair_bxor(assign_value(e), wval(e, i), wval(e, 1 - i)); + else + return try_repair_bxor(e, i); case OP_BADD: - return try_repair_add(eval_value(e), wval(e, i), wval(e, 1 - i)); + SASSERT(e->get_num_args() >= 2); + if (e->get_num_args() == 2) + return try_repair_add(assign_value(e), wval(e, i), wval(e, 1 - i)); + else + return try_repair_add(e, i); case OP_BSUB: - return try_repair_sub(eval_value(e), wval(e, 0), wval(e, 1), i); + return try_repair_sub(assign_value(e), wval(e, 0), wval(e, 1), i); case OP_BMUL: - return try_repair_mul(eval_value(e), wval(e, i), wval(e, 1 - i)); + SASSERT(e->get_num_args() >= 2); + if (e->get_num_args() == 2) + return try_repair_mul(assign_value(e), wval(e, i), assign_value(to_app(e->get_arg(1 - i)))); + else { + auto const& a = wval(e, 0); + auto f = [&](bvect& out, bvval const& c) { + a.set_mul(out, out, c.bits()); + }; + fold_oper(m_mul_tmp, e, i, f); + m_mul_tmp.set_bw(a.bw); + return try_repair_mul(assign_value(e), wval(e, i), m_mul_tmp); + } case OP_BNOT: - return try_repair_bnot(eval_value(e), wval(e, i)); + return try_repair_bnot(assign_value(e), wval(e, i)); case OP_BNEG: - return try_repair_bneg(eval_value(e), wval(e, i)); + return try_repair_bneg(assign_value(e), wval(e, i)); case OP_BIT0: return false; case OP_BIT1: @@ -760,7 +749,7 @@ namespace bv { case OP_BV2INT: return false; case OP_INT2BV: - return false; + return try_repair_int2bv(assign_value(e), e->get_arg(0)); case OP_ULEQ: if (i == 0) return try_repair_ule(bval0(e), wval(e, i), wval(e, 1 - i)); @@ -802,11 +791,11 @@ namespace bv { else return try_repair_sle(!bval0(e), wval(e, i), wval(e, 1 - i)); case OP_BASHR: - return try_repair_ashr(eval_value(e), wval(e, 0), wval(e, 1), i); + return try_repair_ashr(assign_value(e), wval(e, 0), wval(e, 1), i); case OP_BLSHR: - return try_repair_lshr(eval_value(e), wval(e, 0), wval(e, 1), i); + return try_repair_lshr(assign_value(e), wval(e, 0), wval(e, 1), i); case OP_BSHL: - return try_repair_shl(eval_value(e), wval(e, 0), wval(e, 1), i); + return try_repair_shl(assign_value(e), wval(e, 0), wval(e, 1), i); case OP_BIT2BOOL: { unsigned idx; expr* arg; @@ -817,37 +806,37 @@ namespace bv { case OP_BUDIV: case OP_BUDIV_I: case OP_BUDIV0: - return try_repair_udiv(eval_value(e), wval(e, 0), wval(e, 1), i); + return try_repair_udiv(assign_value(e), wval(e, 0), wval(e, 1), i); case OP_BUREM: case OP_BUREM_I: case OP_BUREM0: - return try_repair_urem(eval_value(e), wval(e, 0), wval(e, 1), i); + return try_repair_urem(assign_value(e), wval(e, 0), wval(e, 1), i); case OP_ROTATE_LEFT: - return try_repair_rotate_left(eval_value(e), wval(e, 0), e->get_parameter(0).get_int()); + return try_repair_rotate_left(assign_value(e), wval(e, 0), e->get_parameter(0).get_int()); case OP_ROTATE_RIGHT: - return try_repair_rotate_left(eval_value(e), wval(e, 0), wval(e).bw - e->get_parameter(0).get_int()); + return try_repair_rotate_left(assign_value(e), wval(e, 0), wval(e).bw - e->get_parameter(0).get_int()); case OP_EXT_ROTATE_LEFT: - return try_repair_rotate_left(eval_value(e), wval(e, 0), wval(e, 1), i); + return try_repair_rotate_left(assign_value(e), wval(e, 0), wval(e, 1), i); case OP_EXT_ROTATE_RIGHT: - return try_repair_rotate_right(eval_value(e), wval(e, 0), wval(e, 1), i); + return try_repair_rotate_right(assign_value(e), wval(e, 0), wval(e, 1), i); case OP_ZERO_EXT: - return try_repair_zero_ext(eval_value(e), wval(e, 0)); + return try_repair_zero_ext(assign_value(e), wval(e, 0)); case OP_SIGN_EXT: - return try_repair_sign_ext(eval_value(e), wval(e, 0)); - case OP_CONCAT: - return try_repair_concat(eval_value(e), wval(e, 0), wval(e, 1), i); + return try_repair_sign_ext(assign_value(e), wval(e, 0)); + case OP_CONCAT: + return try_repair_concat(e, i); case OP_EXTRACT: { unsigned hi, lo; expr* arg; VERIFY(bv.is_extract(e, lo, hi, arg)); - return try_repair_extract(eval_value(e), wval(arg), lo); + return try_repair_extract(assign_value(e), wval(arg), lo); } case OP_BUMUL_NO_OVFL: return try_repair_umul_ovfl(!bval0(e), wval(e, 0), wval(e, 1), i); case OP_BUMUL_OVFL: return try_repair_umul_ovfl(bval0(e), wval(e, 0), wval(e, 1), i); case OP_BCOMP: - return try_repair_comp(eval_value(e), wval(e, 0), wval(e, 1), i); + return try_repair_comp(assign_value(e), wval(e, 0), wval(e, 1), i); case OP_BUADD_OVFL: case OP_BNAND: @@ -871,39 +860,20 @@ namespace bv { case OP_BSDIV: case OP_BSDIV_I: case OP_BSDIV0: - // these are currently compiled to udiv and urem. UNREACHABLE(); - return false; + // these are currently compiled to udiv and urem. + // there is an equation that enforces equality between the semantics + // of these operators. + return true; default: return false; } } - bool sls_eval::try_repair_and_or(app* e, unsigned i) { - auto b = bval0(e); - auto child = e->get_arg(i); - if (b == bval0(child)) - return false; - m_eval[child->get_id()] = b; - return true; - } - - bool sls_eval::try_repair_not(app* e) { - auto child = e->get_arg(0); - m_eval[child->get_id()] = !bval0(e); - return true; - } - - bool sls_eval::try_repair_eq(app* e, unsigned i) { + bool bv_eval::try_repair_eq(app* e, unsigned i) { auto child = e->get_arg(i); auto is_true = bval0(e); - if (m.is_bool(child)) { - SASSERT(!is_fixed0(child)); - auto bv = bval0(e->get_arg(1 - i)); - m_eval[child->get_id()] = is_true == bv; - return true; - } - else if (bv.is_bv(child)) { + if (bv.is_bv(child)) { auto & a = wval(e->get_arg(i)); auto & b = wval(e->get_arg(1 - i)); return try_repair_eq(is_true, a, b); @@ -911,72 +881,100 @@ namespace bv { return false; } - bool sls_eval::try_repair_eq(bool is_true, bvval& a, bvval const& b) { + bool bv_eval::try_repair_eq_lookahead(app* e) { + return false; + auto is_true = bval0(e); + if (!is_true) + return false; + auto const& uninterp = terms.uninterp_occurs(e); + if (uninterp.empty()) + return false; +// for (auto e : uninterp) +// verbose_stream() << mk_bounded_pp(e, m) << " "; +// verbose_stream() << "\n"; + + expr* t = uninterp[m_rand() % uninterp.size()]; + + auto& v = wval(t); + if (v.set_random(m_rand)) { + //verbose_stream() << "set random " << mk_bounded_pp(t, m) << "\n"; + ctx.new_value_eh(t); + return true; + } + return false; + + + for (auto e : uninterp) { + auto& v = wval(e); + v.get_variant(m_tmp, m_rand); + auto d = lookahead(e, m_tmp); + //verbose_stream() << mk_bounded_pp(e, m) << " " << d << "\n"; + } + return false; + } + + bool bv_eval::try_repair_eq(bool is_true, bvval& a, bvval const& b) { if (is_true) { +#if 0 + if (bv.is_bv_add(t)) { + bvval tmp(b); + unsigned start = m_rand(); + unsigned sz = to_app(t)->get_num_args(); + for (unsigned i = 0; i < sz; ++i) { + unsigned j = (start + i) % sz; + for (unsigned k = 0; k < sz; ++k) { + if (k == j) + continue; + auto& c = wval(to_app(t)->get_arg(k)); + set_sub(tmp, tmp, c.bits()); + } + + auto& c = wval(to_app(t)->get_arg(j)); + verbose_stream() << "TRY " << c << " := " << tmp << "\n"; + + + } + } +#endif if (m_rand(20) != 0) if (a.try_set(b.bits())) return true; - - return a.set_random(m_rand); + + if (m_rand(20) == 0) + return a.set_random(m_rand); + return false; } else { bool try_above = m_rand(2) == 0; + m_tmp.set_bw(a.bw); if (try_above) { a.set_add(m_tmp, b.bits(), m_one); - if (!a.is_zero(m_tmp) && a.set_random_at_least(m_tmp, m_rand)) + if (a.set_random_at_least(m_tmp, m_rand) && m_tmp != b.bits()) return true; } a.set_sub(m_tmp, b.bits(), m_one); - if (!a.is_zero(m_tmp) && a.set_random_at_most(m_tmp, m_rand)) + if (a.set_random_at_most(m_tmp, m_rand) && m_tmp != b.bits()) return true; if (!try_above) { a.set_add(m_tmp, b.bits(), m_one); - if (!a.is_zero(m_tmp) && a.set_random_at_least(m_tmp, m_rand)) + if (a.set_random_at_least(m_tmp, m_rand) && m_tmp != b.bits()) return true; } return false; } } - bool sls_eval::try_repair_xor(app* e, unsigned i) { - bool ev = bval0(e); - bool bv = bval0(e->get_arg(1 - i)); - auto child = e->get_arg(i); - m_eval[child->get_id()] = ev != bv; - return true; - } - - bool sls_eval::try_repair_ite(app* e, unsigned i) { - auto child = e->get_arg(i); - bool c = bval0(e->get_arg(0)); - if (i == 0) { - m_eval[child->get_id()] = !c; - return true; - } - if (c != (i == 1)) - return false; - if (m.is_bool(e)) { - m_eval[child->get_id()] = bval0(e); - return true; - } - if (bv.is_bv(e)) - return wval(child).try_set(wval(e).bits()); - return false; - } - - bool sls_eval::try_repair_implies(app* e, unsigned i) { - auto child = e->get_arg(i); - bool ev = bval0(e); - bool av = bval0(child); - bool bv = bval0(e->get_arg(1 - i)); - if (i == 0) { - if (ev == (!av || bv)) - return false; + void bv_eval::fold_oper(bvect& out, app* t, unsigned i, std::function const& f) { + auto i2 = i == 0 ? 1 : 0; + auto const& c = wval(t->get_arg(i2)); + for (unsigned j = 0; j < c.nw; ++j) + out[j] = c.bits()[j]; + for (unsigned k = 1; k < t->get_num_args(); ++k) { + if (k == i || k == i2) + continue; + bvval const& c = wval(t->get_arg(k)); + f(out, c); } - else if (ev != (!bv || av)) - return false; - m_eval[child->get_id()] = ev; - return true; } // @@ -986,44 +984,109 @@ namespace bv { // e[i] = 0 & b[i] = 0 -> a[i] = random // a := e[i] | (~b[i] & a[i]) - bool sls_eval::try_repair_band(bvect const& e, bvval& a, bvval const& b) { + bool bv_eval::try_repair_band(bvect const& e, bvval& a, bvval const& b) { for (unsigned i = 0; i < a.nw; ++i) m_tmp[i] = ~a.fixed[i] & (e[i] | (~b.bits()[i] & random_bits())); return a.set_repair(random_bool(), m_tmp); } + bool bv_eval::try_repair_band(app* t, unsigned i) { + bvect const& e = assign_value(t); + auto f = [&](bvect& out, bvval const& c) { + for (unsigned j = 0; j < c.nw; ++j) + out[j] &= c.bits()[j]; + }; + fold_oper(m_tmp2, t, i, f); + + bvval& a = wval(t, i); + for (unsigned j = 0; j < a.nw; ++j) + m_tmp[j] = ~a.fixed[j] & (e[j] | (~m_tmp2[j] & random_bits())); + + return a.set_repair(random_bool(), m_tmp); + } + // // e = a | b // set a[i] to 1 where b[i] = 0, e[i] = 1 // set a[i] to 0 where e[i] = 0, a[i] = 1 // - bool sls_eval::try_repair_bor(bvect const& e, bvval& a, bvval const& b) { + bool bv_eval::try_repair_bor(bvect const& e, bvval& a, bvval const& b) { for (unsigned i = 0; i < a.nw; ++i) m_tmp[i] = e[i] & (~b.bits()[i] | random_bits()); return a.set_repair(random_bool(), m_tmp); } - bool sls_eval::try_repair_bxor(bvect const& e, bvval& a, bvval const& b) { + bool bv_eval::try_repair_bor(app* t, unsigned i) { + bvect const& e = assign_value(t); + auto f = [&](bvect& out, bvval const& c) { + for (unsigned j = 0; j < c.nw; ++j) + out[j] |= c.bits()[j]; + }; + fold_oper(m_tmp2, t, i, f); + bvval& a = wval(t, i); + m_tmp.set_bw(a.bw); + for (unsigned j = 0; j < a.nw; ++j) + m_tmp[j] = e[j] & (~m_tmp2[j] | random_bits()); + + //verbose_stream() << wval(t) << " " << m_tmp << "\n"; + return a.set_repair(random_bool(), m_tmp); + } + + bool bv_eval::try_repair_bxor(bvect const& e, bvval& a, bvval const& b) { for (unsigned i = 0; i < a.nw; ++i) m_tmp[i] = e[i] ^ b.bits()[i]; return a.set_repair(random_bool(), m_tmp); } + + bool bv_eval::try_repair_bxor(app* t, unsigned i) { + bvect const& e = assign_value(t); + auto f = [&](bvect& out, bvval const& c) { + for (unsigned j = 0; j < c.nw; ++j) + out[j] ^= c.bits()[j]; + }; + fold_oper(m_tmp2, t, i, f); + + bvval& a = wval(t, i); + for (unsigned j = 0; j < a.nw; ++j) + m_tmp[j] = e[j] ^ m_tmp2[j]; + + return a.set_repair(random_bool(), m_tmp); + } + + // // first try to set a := e - b // If this fails, set a to a random value // - bool sls_eval::try_repair_add(bvect const& e, bvval& a, bvval const& b) { + bool bv_eval::try_repair_add(bvect const& e, bvval& a, bvval const& b) { if (m_rand(20) != 0) { + m_tmp.set_bw(a.bw); a.set_sub(m_tmp, e, b.bits()); + // verbose_stream() << "set-sub " << e << " - " << b.bits() << " = " << m_tmp << "\n"; if (a.try_set(m_tmp)) return true; } return a.set_random(m_rand); } - bool sls_eval::try_repair_sub(bvect const& e, bvval& a, bvval & b, unsigned i) { + bool bv_eval::try_repair_add(app* t, unsigned i) { + bvval& a = wval(t, i); + bvect const& e = assign_value(t); + if (m_rand(20) != 0) { + auto f = [&](bvect& out, bvval const& c) { + a.set_add(m_tmp2, m_tmp2, c.bits()); + }; + fold_oper(m_tmp2, t, i, f); + a.set_sub(m_tmp, e, m_tmp2); + if (a.try_set(m_tmp)) + return true; + } + return a.set_random(m_rand); + } + + bool bv_eval::try_repair_sub(bvect const& e, bvval& a, bvval & b, unsigned i) { if (m_rand(20) != 0) { if (i == 0) // e = a - b -> a := e + b @@ -1035,18 +1098,21 @@ namespace bv { return true; } // fall back to a random value - return a.set_random(m_rand); + return i == 0 ? a.set_random(m_rand) : b.set_random(m_rand); } /** * e = a*b, then a = e * b^-1 * 8*e = a*(2b), then a = 4e*b^-1 */ - bool sls_eval::try_repair_mul(bvect const& e, bvval& a, bvval const& b) { - unsigned parity_e = b.parity(e); - unsigned parity_b = b.parity(b.bits()); + bool bv_eval::try_repair_mul(bvect const& e, bvval& a, bvect const& b) { + // verbose_stream() << e << " := " << a << " * " << b << "\n"; + unsigned parity_e = a.parity(e); + unsigned parity_b = a.parity(b); - if (b.is_zero(e)) { + if (a.is_zero(b)) { + if (a.try_set(e)) + return true; a.get_variant(m_tmp, m_rand); if (m_rand(10) != 0) for (unsigned i = 0; i < b.bw - parity_b; ++i) @@ -1054,11 +1120,13 @@ namespace bv { return a.set_repair(random_bool(), m_tmp); } - if (b.is_zero() || m_rand(20) == 0) { + if (m_rand(20) == 0) { a.get_variant(m_tmp, m_rand); return a.set_repair(random_bool(), m_tmp); } + + #if 0 verbose_stream() << "solve for " << e << "\n"; @@ -1082,14 +1150,18 @@ namespace bv { // x*ta + y*tb = x - b.get(y); + y.set_bw(a.bw); + b.copy_to(a.nw, y); + //verbose_stream() << "a.nw " << a.nw << " b.nw " << b.nw << " b " << b << " y.nw " << y.nw << " y " << y << "\n"; if (parity_b > 0) { - b.shift_right(y, parity_b); + a.shift_right(y, parity_b); + #if 0 for (unsigned i = parity_b; i < b.bw; ++i) y.set(i, m_rand(2) == 0); #endif } + //verbose_stream() << parity_b << " y " << y << "\n"; y[a.nw] = 0; x[a.nw] = 0; @@ -1129,30 +1201,30 @@ namespace bv { tb.set_bw(0); #if Z3DEBUG - b.get(y); + b.copy_to(a.nw, y); if (parity_b > 0) - b.shift_right(y, parity_b); - a.set_mul(m_tmp, tb, y); + a.shift_right(y, parity_b); + a.set_mul(m_tmp, tb, y, false); SASSERT(a.is_one(m_tmp)); #endif e.copy_to(b.nw, m_tmp2); if (parity_e > 0 && parity_b > 0) - b.shift_right(m_tmp2, std::min(parity_b, parity_e)); + a.shift_right(m_tmp2, std::min(parity_b, parity_e)); a.set_mul(m_tmp, tb, m_tmp2); if (a.set_repair(random_bool(), m_tmp)) return true; - + return a.set_random(m_rand); } - bool sls_eval::try_repair_bnot(bvect const& e, bvval& a) { + bool bv_eval::try_repair_bnot(bvect const& e, bvval& a) { for (unsigned i = 0; i < a.nw; ++i) m_tmp[i] = ~e[i]; a.clear_overflow_bits(m_tmp); return a.try_set(m_tmp); } - bool sls_eval::try_repair_bneg(bvect const& e, bvval& a) { + bool bv_eval::try_repair_bneg(bvect const& e, bvval& a) { a.set_sub(m_tmp, m_zero, e); return a.try_set(m_tmp); } @@ -1166,7 +1238,7 @@ namespace bv { // infeasible if b + 1 = p2 // solve for x >=s b + 1 // - bool sls_eval::try_repair_sle(bool e, bvval& a, bvval const& b) { + bool bv_eval::try_repair_sle(bool e, bvval& a, bvval const& b) { auto& p2 = m_b; b.set_zero(p2); p2.set(b.bw - 1, true); @@ -1192,7 +1264,7 @@ namespace bv { // infeasible if b = 0 // solve for x <=s b - 1 // - bool sls_eval::try_repair_sge(bool e, bvval& a, bvval const& b) { + bool bv_eval::try_repair_sge(bool e, bvval& a, bvval const& b) { auto& p2 = m_b; b.set_zero(p2); p2.set(b.bw - 1, true); @@ -1223,7 +1295,7 @@ namespace bv { // or // x := random p2 <= x <= b if c < p2 (b >= p2) // - bool sls_eval::try_repair_sle(bvval& a, bvect const& b, bvect const& p2) { + bool bv_eval::try_repair_sle(bvval& a, bvect const& b, bvect const& p2) { bool r = false; if (b < p2) { bool coin = m_rand(2) == 0; @@ -1248,42 +1320,44 @@ namespace bv { // x := random b <= x or x < p2 if d < p2 // - bool sls_eval::try_repair_sge(bvval& a, bvect const& b, bvect const& p2) { + bool bv_eval::try_repair_sge(bvval& a, bvect const& b, bvect const& p2) { auto& p2_1 = m_tmp4; a.set_sub(p2_1, p2, m_one); p2_1.set_bw(a.bw); bool r = false; - if (p2 < b) + if (b < p2) // random b <= x < p2 r = a.set_random_in_range(b, p2_1, m_rand); else { // random b <= x or x < p2 bool coin = m_rand(2) == 0; if (coin) - r = a.set_random_at_most(p2_1,m_rand); + r = a.set_random_at_most(p2_1, m_rand); if (!r) - r = a.set_random_at_least(b, m_rand); + r = a.set_random_at_least(b, m_rand); if (!r && !coin) - r = a.set_random_at_most(p2_1, m_rand); + r = a.set_random_at_most(p2_1, m_rand); } p2_1.set_bw(0); return r; } - void sls_eval::add_p2_1(bvval const& a, bvect& t) const { + void bv_eval::add_p2_1(bvval const& a, bool use_current, bvect& t) const { m_zero.set(a.bw - 1, true); - a.set_add(t, a.bits(), m_zero); + a.set_add(t, a.tmp_bits(use_current), m_zero); m_zero.set(a.bw - 1, false); a.clear_overflow_bits(t); } - bool sls_eval::try_repair_ule(bool e, bvval& a, bvval const& b) { + bool bv_eval::try_repair_ule(bool e, bvval& a, bvval const& b) { + //verbose_stream() << "try-repair-ule " << e << " " << a << " " << b << "\n"; if (e) { // a <= t return a.set_random_at_most(b.bits(), m_rand); } else { // a > t + m_tmp.set_bw(a.bw); a.set_add(m_tmp, b.bits(), m_one); if (a.is_zero(m_tmp)) return false; @@ -1291,13 +1365,14 @@ namespace bv { } } - bool sls_eval::try_repair_uge(bool e, bvval& a, bvval const& b) { + bool bv_eval::try_repair_uge(bool e, bvval& a, bvval const& b) { if (e) { // a >= t return a.set_random_at_least(b.bits(), m_rand); } else { // a < t + m_tmp.set_bw(a.bw); if (b.is_zero()) return false; a.set_sub(m_tmp, b.bits(), m_one); @@ -1305,11 +1380,11 @@ namespace bv { } } - bool sls_eval::try_repair_bit2bool(bvval& a, unsigned idx) { + bool bv_eval::try_repair_bit2bool(bvval& a, unsigned idx) { return a.try_set_bit(idx, !a.get_bit(idx)); } - bool sls_eval::try_repair_shl(bvect const& e, bvval& a, bvval& b, unsigned i) { + bool bv_eval::try_repair_shl(bvect const& e, bvval& a, bvval& b, unsigned i) { if (i == 0) { unsigned sh = b.to_nat(b.bw); if (sh == 0) @@ -1332,23 +1407,39 @@ namespace bv { } } else { - // NB. blind sub-range of possible values for b SASSERT(i == 1); - unsigned sh = m_rand(a.bw + 1); - b.set(m_tmp, sh); - return b.try_set(m_tmp); + if (a.is_zero()) + return b.set_random(m_rand); + + unsigned start = m_rand(); + for (unsigned j = 0; j <= a.bw; ++j) { + unsigned sh = (j + start) % (a.bw + 1); + m_tmp.set_bw(a.bw); + m_tmp2.set_bw(a.bw); + b.set(m_tmp, sh); + if (!b.can_set(m_tmp)) + continue; + m_tmp2.set_shift_left(a.bits(), m_tmp); + if (m_tmp2 == e && b.try_set(m_tmp)) + return true; + } + + if (m_rand(2) == 0) + return false; + + return b.set_random(m_rand); } return false; } - bool sls_eval::try_repair_ashr(bvect const& e, bvval & a, bvval& b, unsigned i) { + bool bv_eval::try_repair_ashr(bvect const& e, bvval & a, bvval& b, unsigned i) { if (i == 0) return try_repair_ashr0(e, a, b); else return try_repair_ashr1(e, a, b); } - bool sls_eval::try_repair_lshr(bvect const& e, bvval& a, bvval& b, unsigned i) { + bool bv_eval::try_repair_lshr(bvect const& e, bvval& a, bvval& b, unsigned i) { if (i == 0) return try_repair_lshr0(e, a, b); else @@ -1362,7 +1453,7 @@ namespace bv { * - e = 0 -> a := random * - e > 0 -> a := random with msb(a) >= msb(e) */ - bool sls_eval::try_repair_lshr0(bvect const& e, bvval& a, bvval const& b) { + bool bv_eval::try_repair_lshr0(bvect const& e, bvval& a, bvval const& b) { auto& t = m_tmp; // t := e << b @@ -1441,7 +1532,7 @@ namespace bv { * - e = 0: b := random * - e > 0: b := random >= clz(e) */ - bool sls_eval::try_repair_lshr1(bvect const& e, bvval const& a, bvval& b) { + bool bv_eval::try_repair_lshr1(bvect const& e, bvval const& a, bvval& b) { auto& t = m_tmp; auto clza = a.clz(a.bits()); @@ -1495,7 +1586,7 @@ namespace bv { * weak: * */ - bool sls_eval::try_repair_ashr0(bvect const& e, bvval& a, bvval const& b) { + bool bv_eval::try_repair_ashr0(bvect const& e, bvval& a, bvval const& b) { auto& t = m_tmp; t.set_bw(b.bw); auto n = b.msb(b.bits()); @@ -1556,7 +1647,7 @@ namespace bv { * - e > 0: b := random >= clz(e) */ - bool sls_eval::try_repair_ashr1(bvect const& e, bvval const& a, bvval& b) { + bool bv_eval::try_repair_ashr1(bvect const& e, bvval const& a, bvval& b) { auto& t = m_tmp; auto clza = a.clz(a.bits()); @@ -1598,7 +1689,7 @@ namespace bv { return b.set_repair(random_bool(), t); } - bool sls_eval::try_repair_comp(bvect const& e, bvval& a, bvval& b, unsigned i) { + bool bv_eval::try_repair_comp(bvect const& e, bvval& a, bvval& b, unsigned i) { SASSERT(e[0] == 0 || e[0] == 1); SASSERT(e.bw == 1); return try_repair_eq(e[0] == 1, i == 0 ? a : b, i == 0 ? b : a); @@ -1609,7 +1700,7 @@ namespace bv { // b = 0 => e = -1 // nothing to repair on a // e != -1 => max(a) >=u e - bool sls_eval::try_repair_udiv(bvect const& e, bvval& a, bvval& b, unsigned i) { + bool bv_eval::try_repair_udiv(bvect const& e, bvval& a, bvval& b, unsigned i) { if (i == 0) { if (a.is_zero(e) && a.is_ones(a.fixed) && a.is_ones()) return false; @@ -1674,7 +1765,7 @@ namespace bv { // (s != t => exists y . (mcb(x, y) and y >u t and (s - t) mod y = 0) - bool sls_eval::try_repair_urem(bvect const& e, bvval& a, bvval& b, unsigned i) { + bool bv_eval::try_repair_urem(bvect const& e, bvval& a, bvval& b, unsigned i) { if (i == 0) { if (b.is_zero()) { @@ -1719,21 +1810,21 @@ namespace bv { } } - bool sls_eval::add_overflow_on_fixed(bvval const& a, bvect const& t) { + bool bv_eval::add_overflow_on_fixed(bvval const& a, bvect const& t) { a.set(m_tmp3, m_zero); for (unsigned i = 0; i < a.nw; ++i) m_tmp3[i] = a.fixed[i] & a.bits()[i]; return a.set_add(m_tmp4, t, m_tmp3); } - bool sls_eval::mul_overflow_on_fixed(bvval const& a, bvect const& t) { + bool bv_eval::mul_overflow_on_fixed(bvval const& a, bvect const& t) { a.set(m_tmp3, m_zero); for (unsigned i = 0; i < a.nw; ++i) m_tmp3[i] = a.fixed[i] & a.bits()[i]; return a.set_mul(m_tmp4, m_tmp3, t); } - bool sls_eval::try_repair_rotate_left(bvect const& e, bvval& a, unsigned n) const { + bool bv_eval::try_repair_rotate_left(bvect const& e, bvval& a, unsigned n) const { // a := rotate_right(e, n) n = (a.bw - n) % a.bw; for (unsigned i = a.bw - n; i < a.bw; ++i) @@ -1743,7 +1834,7 @@ namespace bv { return a.set_repair(true, m_tmp); } - bool sls_eval::try_repair_rotate_left(bvect const& e, bvval& a, bvval& b, unsigned i) { + bool bv_eval::try_repair_rotate_left(bvect const& e, bvval& a, bvval& b, unsigned i) { if (i == 0) { rational n = b.get_value(); n = mod(n, rational(b.bw)); @@ -1757,7 +1848,7 @@ namespace bv { } } - bool sls_eval::try_repair_rotate_right(bvect const& e, bvval& a, bvval& b, unsigned i) { + bool bv_eval::try_repair_rotate_right(bvect const& e, bvval& a, bvval& b, unsigned i) { if (i == 0) { rational n = b.get_value(); n = mod(b.bw - n, rational(b.bw)); @@ -1771,7 +1862,7 @@ namespace bv { } } - bool sls_eval::try_repair_umul_ovfl(bool e, bvval& a, bvval& b, unsigned i) { + bool bv_eval::try_repair_umul_ovfl(bool e, bvval& a, bvval& b, unsigned i) { if (e) { // maximize if (i == 0) { @@ -1800,7 +1891,7 @@ namespace bv { // prefix of e must be 1s or 0 and match bit position of last bit in a. // set a to suffix of e, matching signs. // - bool sls_eval::try_repair_sign_ext(bvect const& e, bvval& a) { + bool bv_eval::try_repair_sign_ext(bvect const& e, bvval& a) { for (unsigned i = a.bw; i < e.bw; ++i) if (e.get(i) != e.get(a.bw - 1)) return false; @@ -1814,7 +1905,7 @@ namespace bv { // // prefix of e must be 0s. // - bool sls_eval::try_repair_zero_ext(bvect const& e, bvval& a) { + bool bv_eval::try_repair_zero_ext(bvect const& e, bvval& a) { for (unsigned i = a.bw; i < e.bw; ++i) if (e.get(i)) return false; @@ -1825,21 +1916,16 @@ namespace bv { return a.try_set(m_tmp); } - bool sls_eval::try_repair_concat(bvect const& e, bvval& a, bvval& b, unsigned idx) { - bool r = false; - if (idx == 0) { - for (unsigned i = 0; i < a.bw; ++i) - m_tmp.set(i, e.get(i + b.bw)); - a.clear_overflow_bits(m_tmp); - r = a.try_set(m_tmp); - } - else { - for (unsigned i = 0; i < b.bw; ++i) - m_tmp.set(i, e.get(i)); - b.clear_overflow_bits(m_tmp); - r = b.try_set(m_tmp); - } - return r; + bool bv_eval::try_repair_concat(app* e, unsigned idx) { + unsigned bw = 0; + auto& ve = assign_value(e); + for (unsigned j = e->get_num_args() - 1; j > idx; --j) + bw += bv.get_bv_size(e->get_arg(j)); + auto& a = wval(e, idx); + for (unsigned i = 0; i < a.bw; ++i) + m_tmp.set(i, ve.get(i + bw)); + a.clear_overflow_bits(m_tmp); + return a.try_set(m_tmp); } // @@ -1847,8 +1933,9 @@ namespace bv { // for the randomized assignment, // set a outside of [hi:lo] to random values with preference to 0 or 1 bits // - bool sls_eval::try_repair_extract(bvect const& e, bvval& a, unsigned lo) { - if (m_rand(m_config.m_prob_randomize_extract) <= 100) { + bool bv_eval::try_repair_extract(bvect const& e, bvval& a, unsigned lo) { + // verbose_stream() << "repair extract a[" << lo << ":" << lo + e.bw - 1 << "] = " << e << "\n"; + if (false && m_rand(m_config.m_prob_randomize_extract) <= 100) { a.get_variant(m_tmp, m_rand); if (0 == (m_rand(2))) { auto bit = 0 == (m_rand(2)); @@ -1865,12 +1952,26 @@ namespace bv { a.get(m_tmp); for (unsigned i = 0; i < e.bw; ++i) m_tmp.set(i + lo, e.get(i)); - if (a.try_set(m_tmp)) + m_tmp.set_bw(a.bw); + // verbose_stream() << a << " := " << m_tmp << "\n"; + if (m_rand(20) != 0 && a.try_set(m_tmp)) return true; - return a.set_random(m_rand); + if (m_rand(20) != 0) + return false; + bool ok = a.set_random(m_rand); + // verbose_stream() << "set random " << ok << " " << a << "\n"; + return ok; } - void sls_eval::set_div(bvect const& a, bvect const& b, unsigned bw, + bool bv_eval::try_repair_int2bv(bvect const& e, expr* arg) { + rational r = e.get_value(e.nw); + arith_util a(m); + expr_ref intval(a.mk_int(r), m); + ctx.set_value(arg, intval); + return true; + } + + void bv_eval::set_div(bvect const& a, bvect const& b, unsigned bw, bvect& quot, bvect& rem) const { unsigned nw = (bw + 8 * sizeof(digit_t) - 1) / (8 * sizeof(digit_t)); unsigned bnw = nw; @@ -1890,79 +1991,75 @@ namespace bv { } } - bool sls_eval::repair_up(expr* e) { - if (!is_app(e)) + bool bv_eval::repair_up(expr* e) { + if (!is_app(e) || !can_eval1(to_app(e))) return false; if (m.is_bool(e)) { - auto b = bval1(to_app(e)); - if (is_fixed0(e)) - return b == bval0(e); - m_eval[e->get_id()] = b; + bool b = bval1(to_app(e)); + auto v = ctx.atom2bool_var(e); + if (v == sat::null_bool_var) + ctx.set_value(e, b ? m.mk_true() : m.mk_false()); + else if (ctx.is_true(v) != b) + ctx.flip(v); return true; } - if (bv.is_bv(e)) { - auto& v = eval(to_app(e)); - - for (unsigned i = 0; i < v.nw; ++i) - if (0 != (v.fixed[i] & (v.bits()[i] ^ v.eval[i]))) { - v.bits().copy_to(v.nw, v.eval); - return false; - } - if (v.commit_eval()) - return true; - v.bits().copy_to(v.nw, v.eval); + + if (!bv.is_bv(e)) return false; + auto& v = eval(to_app(e)); + + for (unsigned i = 0; i < v.nw; ++i) { + if (0 != (v.fixed[i] & (v.bits()[i] ^ v.eval[i]))) { + //verbose_stream() << "Fail to set " << mk_bounded_pp(e, m) << " " << i << " " << v.fixed[i] << " " << v.bits()[i] << " " << v.eval[i] << "\n"; + v.bits().copy_to(v.nw, v.eval); + + return false; + } } + + if (v.eval == v.bits()) + return true; + //verbose_stream() << "repair up " << mk_bounded_pp(e, m) << " " << v << "\n"; + if (v.commit_eval()) { + ctx.new_value_eh(e); + return true; + } + //verbose_stream() << "could not repair up " << mk_bounded_pp(e, m) << " " << v << "\n"; + //for (expr* arg : *to_app(e)) + // verbose_stream() << mk_bounded_pp(arg, m) << " " << wval(arg) << "\n"; + v.bits().copy_to(v.nw, v.eval); return false; } - sls_valuation& sls_eval::wval(expr* e) const { + sls::bv_valuation& bv_eval::wval(expr* e) const { // if (!m_values[e->get_id()]) verbose_stream() << mk_bounded_pp(e, m) << "\n"; return *m_values[e->get_id()]; } - void sls_eval::init_eval(app* t) { - if (m.is_bool(t)) - set(t, bval1(t)); - else if (bv.is_bv(t)) { - auto& v = wval(t); - v.bits().copy_to(v.nw, v.eval); - } - } - void sls_eval::commit_eval(app* e) { - if (m.is_bool(e)) { - set(e, bval1(e)); - } - else { - VERIFY(wval(e).commit_eval()); - } - } - - void sls_eval::set_random(app* e) { - if (bv.is_bv(e)) - eval(e).set_random(m_rand); + void bv_eval::commit_eval(expr* p, app* e) { + if (!bv.is_bv(e)) + return; + SASSERT(wval(e).commit_eval()); + VERIFY(wval(e).commit_eval()); } - bool sls_eval::eval_is_correct(app* e) { - if (!can_eval1(e)) - return false; - if (m.is_bool(e)) - return bval0(e) == bval1(e); + void bv_eval::set_random(app* e) { if (bv.is_bv(e)) { - auto const& v = wval(e); - return v.eval == v.bits(); + auto& v = wval(e); + if (v.set_random(m_rand)) + VERIFY(v.commit_eval()); } - UNREACHABLE(); - return false; } - bool sls_eval::re_eval_is_correct(app* e) { + bool bv_eval::eval_is_correct(app* e) { if (!can_eval1(e)) return false; if (m.is_bool(e)) - return bval0(e) ==bval1(e); + return bval0(e) == bval1(e); if (bv.is_bv(e)) { + if (m.is_ite(e)) + return true; auto const& v = eval(e); return v.eval == v.bits(); } @@ -1970,7 +2067,7 @@ namespace bv { return false; } - expr_ref sls_eval::get_value(app* e) { + expr_ref bv_eval::get_value(app* e) { if (m.is_bool(e)) return expr_ref(m.mk_bool_val(bval0(e)), m); else if (bv.is_bv(e)) { @@ -1981,23 +2078,92 @@ namespace bv { return expr_ref(m); } - std::ostream& sls_eval::display(std::ostream& out, expr_ref_vector const& es) { - auto& terms = sort_assertions(es); + std::ostream& bv_eval::display(std::ostream& out) const { + auto& terms = ctx.subterms(); for (expr* e : terms) { + if (!bv.is_bv(e)) + continue; out << e->get_id() << ": " << mk_bounded_pp(e, m, 1) << " "; if (is_fixed0(e)) out << "f "; display_value(out, e) << "\n"; } - terms.reset(); return out; } - std::ostream& sls_eval::display_value(std::ostream& out, expr* e) { + std::ostream& bv_eval::display_value(std::ostream& out, expr* e) const { if (bv.is_bv(e)) return out << wval(e); - if (m.is_bool(e)) - return out << (bval0(e)?"T":"F"); return out << "?"; } + + double bv_eval::lookahead(expr* e, bvect const& new_value) { + SASSERT(bv.is_bv(e)); + SASSERT(is_uninterp(e)); + SASSERT(m_restore.empty()); + + bool has_tabu = false; + double break_count = 0, make_count = 0; + + wval(e).eval = new_value; + if (!insert_update(e)) { + restore_lookahead(); + return -1000000; + } + insert_update_stack(e); + unsigned max_depth = get_depth(e); + for (unsigned depth = max_depth; depth <= max_depth; ++depth) { + for (unsigned i = 0; !has_tabu && i < m_update_stack[depth].size(); ++i) { + e = m_update_stack[depth][i]; + if (bv.is_bv(e)) { + auto& v = eval(to_app(e)); + if (insert_update(e)) { + for (auto p : ctx.parents(e)) { + insert_update_stack(p); + max_depth = std::max(max_depth, get_depth(p)); + } + } + else + has_tabu = true; + } + else if (m.is_bool(e) && can_eval1(to_app(e))) { + bool is_true = ctx.is_true(e); + bool is_true_new = bval1(to_app(e)); + bool is_true_old = bval1_tmp(to_app(e)); + // verbose_stream() << "parent " << mk_bounded_pp(e, m) << " " << is_true << " " << is_true_new << " " << is_true_old << "\n"; + if (is_true == is_true_new && is_true_new != is_true_old) + ++make_count; + if (is_true == is_true_old && is_true_new != is_true_old) + ++break_count; + } + } + m_update_stack[depth].reset(); + } + restore_lookahead(); + if (has_tabu) + return -10000; + return make_count - break_count; + } + + bool bv_eval::insert_update(expr* e) { + m_restore.push_back(e); + m_on_restore.mark(e); + auto& v = wval(e); + v.save_value(); + return v.commit_eval(); + } + + void bv_eval::insert_update_stack(expr* e) { + unsigned depth = get_depth(e); + m_update_stack.reserve(depth + 1); + if (!m_update_stack[depth].contains(e)) + m_update_stack[depth].push_back(e); + } + + void bv_eval::restore_lookahead() { + for (auto e : m_restore) + wval(e).restore_value(); + m_restore.reset(); + m_on_restore.reset(); + } } diff --git a/src/ast/sls/bv_sls_eval.h b/src/ast/sls/sls_bv_eval.h similarity index 62% rename from src/ast/sls/bv_sls_eval.h rename to src/ast/sls/sls_bv_eval.h index 4384660e79a..a4587b987b6 100644 --- a/src/ast/sls/bv_sls_eval.h +++ b/src/ast/sls/sls_bv_eval.h @@ -17,79 +17,81 @@ Module Name: #pragma once #include "ast/ast.h" -#include "ast/sls/sls_valuation.h" -#include "ast/sls/bv_sls_fixed.h" +#include "ast/sls/sls_bv_valuation.h" +#include "ast/sls/sls_bv_fixed.h" +#include "ast/sls/sls_context.h" #include "ast/bv_decl_plugin.h" -namespace bv { + +namespace sls { - class sls_fixed; + class bv_terms; - class sls_eval_plugin { - public: - virtual ~sls_eval_plugin() {} - - }; - class sls_eval { + using bvect = sls::bvect; + + class bv_eval { struct config { unsigned m_prob_randomize_extract = 50; }; - friend class sls_fixed; + friend class sls::bv_fixed; friend class sls_test; ast_manager& m; + sls::context& ctx; + sls::bv_terms& terms; bv_util bv; - sls_fixed m_fix; + sls::bv_fixed m_fix; mutable mpn_manager mpn; ptr_vector m_todo; random_gen m_rand; config m_config; + bool_vector m_fixed; + - scoped_ptr_vector m_plugins; - - - - scoped_ptr_vector m_values; // expr-id -> bv valuation - bool_vector m_eval; // expr-id -> boolean valuation - bool_vector m_fixed; // expr-id -> is Boolean fixed + scoped_ptr_vector m_values; // expr-id -> bv valuation - mutable bvect m_tmp, m_tmp2, m_tmp3, m_tmp4, m_zero, m_one, m_minus_one; + mutable bvect m_tmp, m_tmp2, m_tmp3, m_tmp4, m_mul_tmp, m_zero, m_one, m_minus_one; bvect m_a, m_b, m_nextb, m_nexta, m_aux; - using bvval = sls_valuation; + using bvval = sls::bv_valuation; - - void init_eval_basic(app* e); void init_eval_bv(app* e); + + ptr_vector m_restore; + vector> m_update_stack; + expr_mark m_on_restore; + void insert_update_stack(expr* e); + bool insert_update(expr* e); + double lookahead(expr* e, bvect const& new_value); + void restore_lookahead(); /** * Register e as a bit-vector. * Return true if not already registered, false if already registered. */ - bool add_bit_vector(app* e); - sls_valuation* alloc_valuation(app* e); + void add_bit_vector(app* e); + sls::bv_valuation* alloc_valuation(app* e); - bool bval1_basic(app* e) const; - bool bval1_bv(app* e) const; + bool bval1_bv(app* e, bool use_current) const; + bool bval1_tmp(app* e) const; + + void fold_oper(bvect& out, app* e, unsigned i, std::function const& f); /** * Repair operations */ - bool try_repair_basic(app* e, unsigned i); bool try_repair_bv(app * e, unsigned i); - bool try_repair_and_or(app* e, unsigned i); - bool try_repair_not(app* e); - bool try_repair_eq(app* e, unsigned i); - bool try_repair_xor(app* e, unsigned i); - bool try_repair_ite(app* e, unsigned i); - bool try_repair_implies(app* e, unsigned i); bool try_repair_band(bvect const& e, bvval& a, bvval const& b); + bool try_repair_band(app* t, unsigned i); bool try_repair_bor(bvect const& e, bvval& a, bvval const& b); + bool try_repair_bor(app* t, unsigned i); bool try_repair_add(bvect const& e, bvval& a, bvval const& b); + bool try_repair_add(app* t, unsigned i); bool try_repair_sub(bvect const& e, bvval& a, bvval& b, unsigned i); - bool try_repair_mul(bvect const& e, bvval& a, bvval const& b); + bool try_repair_mul(bvect const& e, bvval& a, bvect const& b); bool try_repair_bxor(bvect const& e, bvval& a, bvval const& b); + bool try_repair_bxor(app* t, unsigned i); bool try_repair_bnot(bvect const& e, bvval& a); bool try_repair_bneg(bvect const& e, bvval& a); bool try_repair_ule(bool e, bvval& a, bvval const& b); @@ -116,11 +118,14 @@ namespace bv { bool try_repair_umul_ovfl(bool e, bvval& a, bvval& b, unsigned i); bool try_repair_zero_ext(bvect const& e, bvval& a); bool try_repair_sign_ext(bvect const& e, bvval& a); - bool try_repair_concat(bvect const& e, bvval& a, bvval& b, unsigned i); + bool try_repair_concat(app* e, unsigned i); bool try_repair_extract(bvect const& e, bvval& a, unsigned lo); bool try_repair_comp(bvect const& e, bvval& a, bvval& b, unsigned i); bool try_repair_eq(bool is_true, bvval& a, bvval const& b); - void add_p2_1(bvval const& a, bvect& t) const; + bool try_repair_eq(app* e, unsigned i); + bool try_repair_eq_lookahead(app* e); + bool try_repair_int2bv(bvect const& e, expr* arg); + void add_p2_1(bvval const& a, bool use_current, bvect& t) const; bool add_overflow_on_fixed(bvval const& a, bvect const& t); bool mul_overflow_on_fixed(bvval const& a, bvect const& t); @@ -130,66 +135,58 @@ namespace bv { digit_t random_bits(); bool random_bool() { return m_rand() % 2 == 0; } - sls_valuation& wval(app* e, unsigned i) { return wval(e->get_arg(i)); } + sls::bv_valuation& wval(app* e, unsigned i) { return wval(e->get_arg(i)); } - void eval(app* e, sls_valuation& val) const; + void eval(app* e, sls::bv_valuation& val) const; - bvect const& eval_value(app* e) const { return wval(e).eval; } + bvect const& assign_value(app* e) const { return wval(e).bits(); } - public: - sls_eval(ast_manager& m); - void init_eval(expr_ref_vector const& es, std::function const& eval); + /** + * Retrieve evaluation based on immediate children. + */ + + bool can_eval1(app* e) const; + + void commit_eval(expr* p, app* e); - void tighten_range(expr_ref_vector const& es) { m_fix.init(es); } + public: + bv_eval(sls::bv_terms& terms, sls::context& ctx); + + void init() { m_fix.init(); } - ptr_vector& sort_assertions(expr_ref_vector const& es); + void register_term(expr* e); /** * Retrieve evaluation based on cache. * bval - Boolean values * wval - Word (bit-vector) values - */ - - bool bval0(expr* e) const { return m_eval[e->get_id()]; } + */ - sls_valuation& wval(expr* e) const; + sls::bv_valuation& wval(expr* e) const; - bool is_fixed0(expr* e) const { return m_fixed.get(e->get_id(), false); } + void set(expr* e, sls::bv_valuation const& val); - /** - * Retrieve evaluation based on immediate children. - */ - bool bval1(app* e) const; - bool can_eval1(app* e) const; + bool is_fixed0(expr* e) const { return m_fixed.get(e->get_id(), false); } - sls_valuation& eval(app* e) const; - - void commit_eval(app* e); - - void init_eval(app* e); + sls::bv_valuation& eval(app* e) const; void set_random(app* e); bool eval_is_correct(app* e); - bool re_eval_is_correct(app* e); + bool is_uninterpreted(app* e) const; expr_ref get_value(app* e); - /** - * Override evaluaton. - */ - - void set(expr* e, bool b) { - m_eval[e->get_id()] = b; - } - + bool bval0(expr* e) const { return ctx.is_true(e); } + bool bval1(app* e) const; + /* - * Try to invert value of child to repair value assignment of parent. - */ + * Try to invert value of child to repair value assignment of parent. + */ - bool try_repair(app* e, unsigned i); + bool repair_down(app* e, unsigned i); /* * Propagate repair up to parent @@ -197,8 +194,8 @@ namespace bv { bool repair_up(expr* e); - std::ostream& display(std::ostream& out, expr_ref_vector const& es); + std::ostream& display(std::ostream& out) const; - std::ostream& display_value(std::ostream& out, expr* e); + std::ostream& display_value(std::ostream& out, expr* e) const; }; } diff --git a/src/ast/sls/sls_evaluator.h b/src/ast/sls/sls_bv_evaluator.h similarity index 99% rename from src/ast/sls/sls_evaluator.h rename to src/ast/sls/sls_bv_evaluator.h index 2ee03c928f4..d0a71f15879 100644 --- a/src/ast/sls/sls_evaluator.h +++ b/src/ast/sls/sls_bv_evaluator.h @@ -22,7 +22,7 @@ Module Name: #include "model/model_evaluator.h" #include "ast/sls/sls_powers.h" -#include "ast/sls/sls_tracker.h" +#include "ast/sls/sls_bv_tracker.h" class sls_evaluator { ast_manager & m_manager; diff --git a/src/ast/sls/bv_sls_fixed.cpp b/src/ast/sls/sls_bv_fixed.cpp similarity index 60% rename from src/ast/sls/bv_sls_fixed.cpp rename to src/ast/sls/sls_bv_fixed.cpp index 9f897a7bdca..784576992f1 100644 --- a/src/ast/sls/bv_sls_fixed.cpp +++ b/src/ast/sls/sls_bv_fixed.cpp @@ -13,56 +13,52 @@ Module Name: #include "ast/ast_pp.h" #include "ast/ast_ll_pp.h" -#include "ast/sls/bv_sls_fixed.h" -#include "ast/sls/bv_sls_eval.h" +#include "ast/sls/sls_bv_fixed.h" +#include "ast/sls/sls_bv_terms.h" +#include "ast/sls/sls_bv_eval.h" -namespace bv { +namespace sls { - sls_fixed::sls_fixed(sls_eval& ev): + bv_fixed::bv_fixed(bv_eval& ev, bv_terms& terms, sls::context& ctx): ev(ev), + terms(terms), m(ev.m), - bv(ev.bv) + bv(ev.bv), + ctx(ctx) {} - void sls_fixed::init(expr_ref_vector const& es) { - ev.sort_assertions(es); - for (expr* e : ev.m_todo) { - if (!is_app(e)) + void bv_fixed::init() { + + for (auto e : ctx.subterms()) + set_fixed(e); + + //ctx.display(verbose_stream()); + + for (auto lit : ctx.unit_literals()) { + auto a = ctx.atom(lit.var()); + if (!a) continue; - app* a = to_app(e); - ev.m_fixed.setx(a->get_id(), is_fixed1(a), false); - if (a->get_family_id() == basic_family_id) - init_fixed_basic(a); - else if (a->get_family_id() == bv.get_family_id()) - init_fixed_bv(a); - else - ; + if (is_app(a)) + init_range(to_app(a), lit.sign()); + ev.m_fixed.setx(a->get_id(), true, false); } - init_ranges(es); - ev.m_todo.reset(); - } + //ctx.display(verbose_stream()); + for (auto e : ctx.subterms()) + propagate_range_up(e); - void sls_fixed::init_ranges(expr_ref_vector const& es) { - for (expr* e : es) { - bool sign = m.is_not(e, e); - if (is_app(e)) - init_range(to_app(e), sign); - } - - for (expr* e : ev.m_todo) - propagate_range_up(e); + //ctx.display(verbose_stream()); } - void sls_fixed::propagate_range_up(expr* e) { + void bv_fixed::propagate_range_up(expr* e) { expr* t, * s; rational v; if (bv.is_concat(e, t, s)) { - auto& vals = wval(s); + auto& vals = ev.wval(s); if (vals.lo() != vals.hi() && (vals.lo() < vals.hi() || vals.hi() == 0)) // lo <= e add_range(e, vals.lo(), rational::zero(), false); - auto valt = wval(t); + auto valt = ev.wval(t); if (valt.lo() != valt.hi() && (valt.lo() < valt.hi() || valt.hi() == 0)) { // (2^|s|) * lo <= e < (2^|s|) * hi auto p = rational::power_of_two(bv.get_bv_size(s)); @@ -70,12 +66,12 @@ namespace bv { } } else if (bv.is_bv_add(e, s, t) && bv.is_numeral(s, v)) { - auto& val = wval(t); + auto& val = ev.wval(t); if (val.lo() != val.hi()) add_range(e, v + val.lo(), v + val.hi(), false); } else if (bv.is_bv_add(e, t, s) && bv.is_numeral(s, v)) { - auto& val = wval(t); + auto& val = ev.wval(t); if (val.lo() != val.hi()) add_range(e, v + val.lo(), v + val.hi(), false); } @@ -83,7 +79,7 @@ namespace bv { // x in [lo, hi[ => -x in [-hi + 1, -lo + 1[ else if (bv.is_bv_mul(e, s, t) && bv.is_numeral(s, v) && v + 1 == rational::power_of_two(bv.get_bv_size(e))) { - auto& val = wval(t); + auto& val = ev.wval(t); if (val.lo() != val.hi()) add_range(e, -val.hi() + 1, - val.lo() + 1, false); } @@ -91,7 +87,7 @@ namespace bv { // s <=s t <=> s + K <= t + K, K = 2^{bw-1} - bool sls_fixed::init_range(app* e, bool sign) { + bool bv_fixed::init_range(app* e, bool sign) { expr* s, * t, * x, * y; rational a, b; unsigned idx; @@ -149,7 +145,7 @@ namespace bv { return true; } else if (bv.is_bit2bool(e, s, idx)) { - auto& val = wval(s); + auto& val = ev.wval(s); val.try_set_bit(idx, !sign); val.fixed.set(idx, true); val.tighten_range(); @@ -159,17 +155,17 @@ namespace bv { return false; } - bool sls_fixed::init_eq(expr* t, rational const& a, bool sign) { + bool bv_fixed::init_eq(expr* t, rational const& a, bool sign) { unsigned lo, hi; rational b(0); - // verbose_stream() << mk_bounded_pp(t, m) << " == " << a << "\n"; expr* s = nullptr; - if (sign) + if (sign && true) // 1 <= t - a init_range(nullptr, rational(1), t, -a, false); - else + if (!sign) // t - a <= 0 init_range(t, -a, nullptr, rational::zero(), false); + if (!sign && bv.is_bv_not(t, s)) { for (unsigned i = 0; i < bv.get_bv_size(s); ++i) if (!a.get_bit(i)) @@ -187,20 +183,21 @@ namespace bv { } if (bv.is_extract(t, lo, hi, s)) { if (hi == lo) { - sign = sign ? a == 1 : a == 0; - auto& val = wval(s); - if (val.try_set_bit(lo, !sign)) - val.fixed.set(lo, true); + auto sign1 = sign ? a == 1 : a == 0; + auto& val = ev.wval(s); + if (val.try_set_bit(lo, !sign1)) + val.fixed.set(lo, true); + val.tighten_range(); + } else if (!sign) { - auto& val = wval(s); + auto& val = ev.wval(s); for (unsigned i = lo; i <= hi; ++i) if (val.try_set_bit(i, a.get_bit(i - lo))) val.fixed.set(i, true); val.tighten_range(); - // verbose_stream() << lo << " " << hi << " " << val << " := " << a << "\n"; - } + } if (!sign && hi + 1 == bv.get_bv_size(s)) { // s < 2^lo * (a + 1) @@ -223,7 +220,7 @@ namespace bv { // a < x + b <=> ! (x + b <= a) <=> x not in [-a, b - a [ <=> x in [b - a, -a [ a != -1 // x + a < x + b <=> ! (x + b <= x + a) <=> x in [-a, -b [ a != b // - bool sls_fixed::init_range(expr* x, rational const& a, expr* y, rational const& b, bool sign) { + bool bv_fixed::init_range(expr* x, rational const& a, expr* y, rational const& b, bool sign) { if (!x && !y) return false; if (!x) @@ -235,8 +232,8 @@ namespace bv { return false; } - bool sls_fixed::add_range(expr* e, rational lo, rational hi, bool sign) { - auto& v = wval(e); + bool bv_fixed::add_range(expr* e, rational lo, rational hi, bool sign) { + auto& v = ev.wval(e); lo = mod(lo, rational::power_of_two(bv.get_bv_size(e))); hi = mod(hi, rational::power_of_two(bv.get_bv_size(e))); if (lo == hi) @@ -262,7 +259,7 @@ namespace bv { return true; } - void sls_fixed::get_offset(expr* e, expr*& x, rational& offset) { + void bv_fixed::get_offset(expr* e, expr*& x, rational& offset) { expr* s, * t; x = e; offset = 0; @@ -285,177 +282,173 @@ namespace bv { x = nullptr; } - sls_valuation& sls_fixed::wval(expr* e) { - return ev.wval(e); - } - - void sls_fixed::init_fixed_basic(app* e) { - if (bv.is_bv(e) && m.is_ite(e)) { - auto& val = wval(e); - auto& val_th = wval(e->get_arg(1)); - auto& val_el = wval(e->get_arg(2)); - for (unsigned i = 0; i < val.nw; ++i) - val.fixed[i] = val_el.fixed[i] & val_th.fixed[i] & ~(val_el.bits(i) ^ val_th.bits(i)); - } - } - - void sls_fixed::init_fixed_bv(app* e) { - if (bv.is_bv(e)) - set_fixed_bw(e); - } - - bool sls_fixed::is_fixed1(app* e) const { + bool bv_fixed::is_fixed1(app* e) const { if (is_uninterp(e)) return false; - if (e->get_family_id() == basic_family_id) - return is_fixed1_basic(e); return all_of(*e, [&](expr* arg) { return ev.is_fixed0(arg); }); } - - bool sls_fixed::is_fixed1_basic(app* e) const { - switch (e->get_decl_kind()) { - case OP_TRUE: - case OP_FALSE: - return true; - case OP_AND: - return any_of(*e, [&](expr* arg) { return ev.is_fixed0(arg) && !ev.bval0(e); }); - case OP_OR: - return any_of(*e, [&](expr* arg) { return ev.is_fixed0(arg) && ev.bval0(e); }); - default: - return all_of(*e, [&](expr* arg) { return ev.is_fixed0(arg); }); - } - } - void sls_fixed::set_fixed_bw(app* e) { - SASSERT(bv.is_bv(e)); - SASSERT(e->get_family_id() == bv.get_fid()); - auto& v = ev.wval(e); - if (all_of(*e, [&](expr* arg) { return ev.is_fixed0(arg); })) { - for (unsigned i = 0; i < v.bw; ++i) - v.fixed.set(i, true); + void bv_fixed::set_fixed(expr* _e) { + if (!is_app(_e)) + return; + auto e = to_app(_e); + + if (e->get_family_id() == bv.get_family_id() && all_of(*e, [&](expr* arg) { return ev.is_fixed0(arg); })) { + if (bv.is_bv(e)) { + auto& v = ev.wval(e); + for (unsigned i = 0; i < v.bw; ++i) + v.fixed.set(i, true); + } ev.m_fixed.setx(e->get_id(), true, false); return; } + + if (!bv.is_bv(e)) + return; + auto& v = ev.wval(e); + + if (m.is_ite(e)) { + auto& val_th = ev.wval(e->get_arg(1)); + auto& val_el = ev.wval(e->get_arg(2)); + for (unsigned i = 0; i < v.nw; ++i) + v.fixed[i] = val_el.fixed[i] & val_th.fixed[i] & ~(val_el.bits(i) ^ val_th.bits(i)); + return; + } + + if (e->get_family_id() != bv.get_fid()) + return; switch (e->get_decl_kind()) { case OP_BAND: { - auto& a = wval(e->get_arg(0)); - auto& b = wval(e->get_arg(1)); - // (a.fixed & b.fixed) | (a.fixed & ~a.bits) | (b.fixed & ~b.bits) - for (unsigned i = 0; i < a.nw; ++i) - v.fixed[i] = (a.fixed[i] & b.fixed[i]) | (a.fixed[i] & ~a.bits(i)) | (b.fixed[i] & ~b.bits(i)); + if (e->get_num_args() == 2) { + auto& a = ev.wval(e->get_arg(0)); + auto& b = ev.wval(e->get_arg(1)); + // (a.fixed & b.fixed) | (a.fixed & ~a.bits) | (b.fixed & ~b.bits) + for (unsigned i = 0; i < a.nw; ++i) + v.fixed[i] = (a.fixed[i] & b.fixed[i]) | (a.fixed[i] & ~a.bits(i)) | (b.fixed[i] & ~b.bits(i)); + } break; } case OP_BOR: { - auto& a = wval(e->get_arg(0)); - auto& b = wval(e->get_arg(1)); - // (a.fixed & b.fixed) | (a.fixed & a.bits) | (b.fixed & b.bits) - for (unsigned i = 0; i < a.nw; ++i) - v.fixed[i] = (a.fixed[i] & b.fixed[i]) | (a.fixed[i] & a.bits(i)) | (b.fixed[i] & b.bits(i)); + if (e->get_num_args() == 2) { + auto& a = ev.wval(e->get_arg(0)); + auto& b = ev.wval(e->get_arg(1)); + // (a.fixed & b.fixed) | (a.fixed & a.bits) | (b.fixed & b.bits) + for (unsigned i = 0; i < a.nw; ++i) + v.fixed[i] = (a.fixed[i] & b.fixed[i]) | (a.fixed[i] & a.bits(i)) | (b.fixed[i] & b.bits(i)); + } break; } case OP_BXOR: { - auto& a = wval(e->get_arg(0)); - auto& b = wval(e->get_arg(1)); - for (unsigned i = 0; i < a.nw; ++i) - v.fixed[i] = a.fixed[i] & b.fixed[i]; + if (e->get_num_args() == 2) { + auto& a = ev.wval(e->get_arg(0)); + auto& b = ev.wval(e->get_arg(1)); + for (unsigned i = 0; i < a.nw; ++i) + v.fixed[i] = a.fixed[i] & b.fixed[i]; + } break; } case OP_BNOT: { - auto& a = wval(e->get_arg(0)); + auto& a = ev.wval(e->get_arg(0)); for (unsigned i = 0; i < a.nw; ++i) v.fixed[i] = a.fixed[i]; break; } case OP_BADD: { - auto& a = wval(e->get_arg(0)); - auto& b = wval(e->get_arg(1)); bool pfixed = true; for (unsigned i = 0; i < v.bw; ++i) { - if (pfixed && a.fixed.get(i) && b.fixed.get(i)) - v.fixed.set(i, true); - else if (!pfixed && a.fixed.get(i) && b.fixed.get(i) && - !a.get_bit(i) && !b.get_bit(i)) { - pfixed = true; - v.fixed.set(i, false); - } - else { - pfixed = false; - v.fixed.set(i, false); + for (unsigned j = 0; pfixed && j < e->get_num_args(); ++j) { + auto& a = ev.wval(e->get_arg(j)); + pfixed &= a.fixed.get(i); } + v.fixed.set(i, pfixed); } break; } case OP_BMUL: { - auto& a = wval(e->get_arg(0)); - auto& b = wval(e->get_arg(1)); - unsigned j = 0, k = 0, zj = 0, zk = 0, hzj = 0, hzk = 0; - // i'th bit depends on bits j + k = i - // if the first j, resp k bits are 0, the bits j + k are 0 - for (; j < v.bw; ++j) - if (!a.fixed.get(j)) - break; - for (; k < v.bw; ++k) - if (!b.fixed.get(k)) - break; - for (; zj < v.bw; ++zj) - if (!a.fixed.get(zj) || a.get_bit(zj)) - break; - for (; zk < v.bw; ++zk) - if (!b.fixed.get(zk) || b.get_bit(zk)) - break; - for (; hzj < v.bw; ++hzj) - if (!a.fixed.get(v.bw - hzj - 1) || a.get_bit(v.bw - hzj - 1)) - break; - for (; hzk < v.bw; ++hzk) - if (!b.fixed.get(v.bw - hzk - 1) || b.get_bit(v.bw - hzk - 1)) - break; - - - if (j > 0 && k > 0) { - for (unsigned i = 0; i < std::min(k, j); ++i) { - SASSERT(!v.get_bit(i)); - v.fixed.set(i, true); + if (e->get_num_args() == 2) { + SASSERT(e->get_num_args() == 2); + auto& a = ev.wval(e->get_arg(0)); + auto& b = ev.wval(e->get_arg(1)); + unsigned j = 0, k = 0, zj = 0, zk = 0, hzj = 0, hzk = 0; + // i'th bit depends on bits j + k = i + // if the first j, resp k bits are 0, the bits j + k are 0 + for (; j < v.bw; ++j) + if (!a.fixed.get(j)) + break; + for (; k < v.bw; ++k) + if (!b.fixed.get(k)) + break; + for (; zj < v.bw; ++zj) + if (!a.fixed.get(zj) || a.get_bit(zj)) + break; + for (; zk < v.bw; ++zk) + if (!b.fixed.get(zk) || b.get_bit(zk)) + break; + for (; hzj < v.bw; ++hzj) + if (!a.fixed.get(v.bw - hzj - 1) || a.get_bit(v.bw - hzj - 1)) + break; + for (; hzk < v.bw; ++hzk) + if (!b.fixed.get(v.bw - hzk - 1) || b.get_bit(v.bw - hzk - 1)) + break; + + + if (j > 0 && k > 0) { + for (unsigned i = 0; i < std::min(k, j); ++i) { + SASSERT(!v.get_bit(i)); + v.fixed.set(i, true); + } } - } - // lower zj + jk bits are 0 - if (zk > 0 || zj > 0) { - for (unsigned i = 0; i < zk + zj; ++i) { - SASSERT(!v.get_bit(i)); - v.fixed.set(i, true); + // lower zj + jk bits are 0 + if (zk > 0 || zj > 0) { + for (unsigned i = 0; i < zk + zj; ++i) { + SASSERT(!v.get_bit(i)); + v.fixed.set(i, true); + } + } + // upper bits are 0, if enough high order bits of a, b are 0. + // TODO - buggy + if (false && hzj < v.bw && hzk < v.bw && hzj + hzk > v.bw) { + hzj = v.bw - hzj; + hzk = v.bw - hzk; + for (unsigned i = hzj + hzk - 1; i < v.bw; ++i) { + SASSERT(!v.get_bit(i)); + v.fixed.set(i, true); + } } } - // upper bits are 0, if enough high order bits of a, b are 0. - // TODO - buggy - if (false && hzj < v.bw && hzk < v.bw && hzj + hzk > v.bw) { - hzj = v.bw - hzj; - hzk = v.bw - hzk; - for (unsigned i = hzj + hzk - 1; i < v.bw; ++i) { - SASSERT(!v.get_bit(i)); - v.fixed.set(i, true); + else { + bool pfixed = true; + for (unsigned i = 0; i < v.bw; ++i) { + for (unsigned j = 0; pfixed && j < e->get_num_args(); ++j) { + auto& a = ev.wval(e->get_arg(j)); + pfixed &= a.fixed.get(i); + } + v.fixed.set(i, pfixed); } - } + } break; } case OP_CONCAT: { - auto& a = wval(e->get_arg(0)); - auto& b = wval(e->get_arg(1)); - for (unsigned i = 0; i < b.bw; ++i) - v.fixed.set(i, b.fixed.get(i)); - for (unsigned i = 0; i < a.bw; ++i) - v.fixed.set(i + b.bw, a.fixed.get(i)); + unsigned bw = 0; + for (unsigned i = e->get_num_args(); i-- > 0; ) { + auto& a = ev.wval(e->get_arg(i)); + for (unsigned j = 0; j < a.bw; ++j) + v.fixed.set(bw + j, a.fixed.get(j)); + bw += a.bw; + } break; } case OP_EXTRACT: { expr* child; unsigned lo, hi; VERIFY(bv.is_extract(e, lo, hi, child)); - auto& a = wval(child); + auto& a = ev.wval(child); for (unsigned i = lo; i <= hi; ++i) v.fixed.set(i - lo, a.fixed.get(i)); break; } case OP_BNEG: { - auto& a = wval(e->get_arg(0)); + auto& a = ev.wval(e->get_arg(0)); bool pfixed = true; for (unsigned i = 0; i < v.bw; ++i) { if (pfixed && a.fixed.get(i)) diff --git a/src/ast/sls/bv_sls_fixed.h b/src/ast/sls/sls_bv_fixed.h similarity index 57% rename from src/ast/sls/bv_sls_fixed.h rename to src/ast/sls/sls_bv_fixed.h index 2e88484c55d..d175e9d6a40 100644 --- a/src/ast/sls/bv_sls_fixed.h +++ b/src/ast/sls/sls_bv_fixed.h @@ -17,19 +17,23 @@ Module Name: #pragma once #include "ast/ast.h" -#include "ast/sls/sls_valuation.h" +#include "ast/sls/sls_bv_valuation.h" +#include "ast/sls/sls_context.h" #include "ast/bv_decl_plugin.h" -namespace bv { - class sls_eval; +namespace sls { + + class bv_terms; + class bv_eval; - class sls_fixed { - sls_eval& ev; - ast_manager& m; - bv_util& bv; + class bv_fixed { + bv_eval& ev; + bv_terms& terms; + ast_manager& m; + bv_util& bv; + sls::context& ctx; - void init_ranges(expr_ref_vector const& es); bool init_range(app* e, bool sign); void propagate_range_up(expr* e); bool init_range(expr* x, rational const& a, expr* y, rational const& b, bool sign); @@ -37,19 +41,11 @@ namespace bv { bool init_eq(expr* e, rational const& v, bool sign); bool add_range(expr* e, rational lo, rational hi, bool sign); - void init_fixed_basic(app* e); - void init_fixed_bv(app* e); - bool is_fixed1(app* e) const; - bool is_fixed1_basic(app* e) const; - void set_fixed_bw(app* e); - - sls_valuation& wval(expr* e); + void set_fixed(expr* e); public: - sls_fixed(sls_eval& ev); - - void init(expr_ref_vector const& es); - + bv_fixed(bv_eval& ev, bv_terms& terms, sls::context& ctx); + void init(); }; } diff --git a/src/ast/sls/sls_bv_plugin.cpp b/src/ast/sls/sls_bv_plugin.cpp new file mode 100644 index 00000000000..8b4b29be743 --- /dev/null +++ b/src/ast/sls/sls_bv_plugin.cpp @@ -0,0 +1,206 @@ +/*++ +Copyright (c) 2024 Microsoft Corporation + +Module Name: + + sls_bv_plugin.cpp + +Abstract: + + Theory plugin for bit-vector local search + +Author: + + Nikolaj Bjorner (nbjorner) 2024-07-06 + +--*/ +#include "ast/sls/sls_bv_plugin.h" +#include "ast/ast_ll_pp.h" +#include "ast/ast_pp.h" + +namespace sls { + + bv_plugin::bv_plugin(context& ctx): + plugin(ctx), + bv(m), + m_terms(ctx), + m_eval(m_terms, ctx) { + m_fid = bv.get_family_id(); + } + + void bv_plugin::register_term(expr* e) { + m_terms.register_term(e); + m_eval.register_term(e); + } + + expr_ref bv_plugin::get_value(expr* e) { + SASSERT(bv.is_bv(e)); + auto const & val = m_eval.wval(e); + return expr_ref(bv.mk_numeral(val.get_value(), e->get_sort()), m); + } + + bool bv_plugin::is_bv_predicate(expr* e) { + if (!e || !is_app(e)) + return false; + auto a = to_app(e); + if (a->get_family_id() == bv.get_family_id()) + return true; + if (m.is_eq(e) && bv.is_bv(a->get_arg(0))) + return true; + return false; + } + + void bv_plugin::propagate_literal(sat::literal lit) { + SASSERT(ctx.is_true(lit)); + auto e = ctx.atom(lit.var()); + if (!is_bv_predicate(e)) + return; + auto a = to_app(e); + + if (!m_eval.eval_is_correct(a)) { + IF_VERBOSE(20, verbose_stream() << "repair " << lit << " " << mk_bounded_pp(e, m) << "\n"); + ctx.new_value_eh(e); + } + } + + bool bv_plugin::propagate() { + auto& axioms = m_terms.axioms(); + if (!axioms.empty()) { + for (auto* e : axioms) + ctx.add_constraint(e); + axioms.reset(); + return true; + } + return false; + } + + void bv_plugin::initialize() { + if (!m_initialized) { + m_eval.init(); + m_initialized = true; + } + } + + void bv_plugin::init_bool_var_assignment(sat::bool_var v) { + auto a = ctx.atom(v); + if (!a || !is_app(a)) + return; + if (to_app(a)->get_family_id() != bv.get_family_id()) + return; + bool is_true = m_eval.bval1(to_app(a)); + + if (is_true != ctx.is_true(v)) + ctx.flip(v); + } + + bool bv_plugin::is_sat() { + bool is_sat = true; + for (auto t : ctx.subterms()) + if (is_app(t) && bv.is_bv(t) && to_app(t)->get_family_id() == bv.get_fid() && !m_eval.eval_is_correct(to_app(t))) { + ctx.new_value_eh(t); + is_sat = false; + } + return is_sat; + } + + std::ostream& bv_plugin::display(std::ostream& out) const { + return m_eval.display(out); + } + + bool bv_plugin::set_value(expr* e, expr* v) { + if (!bv.is_bv(e)) + return false; + rational val; + VERIFY(bv.is_numeral(v, val)); + auto& w = m_eval.eval(to_app(e)); + w.set_value(w.eval, val); + return w.commit_eval(); + } + + bool bv_plugin::repair_down(app* e) { + unsigned n = e->get_num_args(); + bool status = true; + if (n == 0 || m_eval.is_uninterpreted(e) || m_eval.eval_is_correct(e)) + goto done; + + if (n == 2) { + auto d1 = get_depth(e->get_arg(0)); + auto d2 = get_depth(e->get_arg(1)); + unsigned s = ctx.rand(d1 + d2 + 2); + if (s <= d1 && m_eval.repair_down(e, 0)) + goto done; + + if (m_eval.repair_down(e, 1)) + goto done; + + if (m_eval.repair_down(e, 0)) + goto done; + } + else { + unsigned s = ctx.rand(n); + for (unsigned i = 0; i < n; ++i) { + auto j = (i + s) % n; + if (m_eval.repair_down(e, j)) + goto done; + } + } + status = false; + + done: + log(e, false, status); + return status; + } + + void bv_plugin::repair_up(app* e) { + if (m_eval.repair_up(e)) { + if (!m_eval.eval_is_correct(e)) { + verbose_stream() << "Incorrect eval #" << e->get_id() << " " << mk_bounded_pp(e, m) << "\n"; + } + log(e, true, true); + SASSERT(m_eval.eval_is_correct(e)); + if (m.is_bool(e)) { + if (ctx.is_true(e) != m_eval.bval1(e)) + ctx.flip(ctx.atom2bool_var(e)); + } + } + else if (bv.is_bv(e)) { + log(e, true, false); + IF_VERBOSE(5, verbose_stream() << "repair-up "; trace_repair(true, e)); + auto& v = m_eval.wval(e); + m_eval.set_random(e); + ctx.new_value_eh(e); + } + else + log(e, true, false); + + } + + void bv_plugin::repair_literal(sat::literal lit) { + SASSERT(ctx.is_true(lit)); + auto e = ctx.atom(lit.var()); + if (!is_bv_predicate(e)) + return; + auto a = to_app(e); + if (!m_eval.eval_is_correct(a)) + ctx.flip(lit.var()); + } + + std::ostream& bv_plugin::trace_repair(bool down, expr* e) { + verbose_stream() << (down ? "d #" : "u #") + << e->get_id() << ": " + << mk_bounded_pp(e, m, 1) << " "; + return m_eval.display_value(verbose_stream(), e) << "\n"; + } + + void bv_plugin::trace() { + IF_VERBOSE(2, verbose_stream() + << "(bvsls :restarts " << m_stats.m_restarts << ")\n"); + } + + void bv_plugin::log(expr* e, bool up_down, bool success) { + IF_VERBOSE(11, verbose_stream() << mk_bounded_pp(e, m) << " " << (up_down?"u":"d") << " " << (success ? "S" : "F"); + if (bv.is_bv(e)) verbose_stream() << " " << m_eval.wval(e); + verbose_stream() << "\n"); + } + +} diff --git a/src/ast/sls/sls_bv_plugin.h b/src/ast/sls/sls_bv_plugin.h new file mode 100644 index 00000000000..7d9e338e7db --- /dev/null +++ b/src/ast/sls/sls_bv_plugin.h @@ -0,0 +1,62 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + sls_bv_plugin.h + +Abstract: + + Theory plugin for bit-vector local search + +Author: + + Nikolaj Bjorner (nbjorner) 2024-07-06 + +--*/ +#pragma once + +#include "ast/sls/sls_context.h" +#include "ast/bv_decl_plugin.h" +#include "ast/sls/sls_bv_terms.h" +#include "ast/sls/sls_bv_eval.h" + +namespace sls { + + class bv_plugin : public plugin { + bv_util bv; + bv_terms m_terms; + bv_eval m_eval; + bv::sls_stats m_stats; + bool m_initialized = false; + + void init_bool_var_assignment(sat::bool_var v); + std::ostream& trace_repair(bool down, expr* e); + void trace(); + bool can_propagate(); + bool is_bv_predicate(expr* e); + + void log(expr* e, bool up_down, bool success); + + public: + bv_plugin(context& ctx); + ~bv_plugin() override {} + void register_term(expr* e) override; + expr_ref get_value(expr* e) override; + void initialize() override; + void propagate_literal(sat::literal lit) override; + bool propagate() override; + bool repair_down(app* e) override; + void repair_up(app* e) override; + void repair_literal(sat::literal lit) override; + bool is_sat() override; + + void on_rescale() override {} + void on_restart() override {} + std::ostream& display(std::ostream& out) const override; + bool set_value(expr* e, expr* v) override; + void collect_statistics(statistics& st) const override {} + void reset_statistics() override {} + }; + +} diff --git a/src/ast/sls/sls_bv_terms.cpp b/src/ast/sls/sls_bv_terms.cpp new file mode 100644 index 00000000000..9004128b96d --- /dev/null +++ b/src/ast/sls/sls_bv_terms.cpp @@ -0,0 +1,143 @@ +/*++ +Copyright (c) 2024 Microsoft Corporation + +Module Name: + + bv_sls_terms.cpp + +Abstract: + + normalize bit-vector expressions to use only binary operators. + +Author: + + Nikolaj Bjorner (nbjorner) 2024-02-07 + +--*/ + +#include "ast/ast_ll_pp.h" +#include "ast/sls/sls_bv_terms.h" +#include "ast/rewriter/bool_rewriter.h" +#include "ast/rewriter/bv_rewriter.h" + +namespace sls { + + bv_terms::bv_terms(sls::context& ctx): + m(ctx.get_manager()), + bv(m), + m_axioms(m) {} + + void bv_terms::register_term(expr* e) { + auto r = ensure_binary(e); + if (r != e) + m_axioms.push_back(m.mk_eq(e, r)); + register_uninterp(e); + } + + expr_ref bv_terms::ensure_binary(expr* e) { + expr* x, * y; + expr_ref r(m); + if (bv.is_bv_sdiv(e, x, y) || bv.is_bv_sdiv0(e, x, y) || bv.is_bv_sdivi(e, x, y)) + r = mk_sdiv(x, y); + else if (bv.is_bv_smod(e, x, y) || bv.is_bv_smod0(e, x, y) || bv.is_bv_smodi(e, x, y)) + r = mk_smod(x, y); + else if (bv.is_bv_srem(e, x, y) || bv.is_bv_srem0(e, x, y) || bv.is_bv_sremi(e, x, y)) + r = mk_srem(x, y); + else + r = e; + return r; + } + + expr_ref bv_terms::mk_sdiv(expr* x, expr* y) { + // d = udiv(abs(x), abs(y)) + // y = 0, x >= 0 -> -1 + // y = 0, x < 0 -> 1 + // x = 0, y != 0 -> 0 + // x > 0, y < 0 -> -d + // x < 0, y > 0 -> -d + // x > 0, y > 0 -> d + // x < 0, y < 0 -> d + bool_rewriter br(m); + bv_rewriter bvr(m); + unsigned sz = bv.get_bv_size(x); + rational N = rational::power_of_two(sz); + expr_ref z(bv.mk_zero(sz), m); + expr_ref o(bv.mk_one(sz), m); + expr_ref n1(bv.mk_numeral(N - 1, sz), m); + expr_ref signx = bvr.mk_ule(bv.mk_numeral(N / 2, sz), x); + expr_ref signy = bvr.mk_ule(bv.mk_numeral(N / 2, sz), y); + expr_ref absx = br.mk_ite(signx, bvr.mk_bv_neg(x), x); + expr_ref absy = br.mk_ite(signy, bvr.mk_bv_neg(y), y); + expr_ref d = expr_ref(bv.mk_bv_udiv(absx, absy), m); + expr_ref r = br.mk_ite(br.mk_eq(signx, signy), d, bvr.mk_bv_neg(d)); + r = br.mk_ite(br.mk_eq(z, y), + br.mk_ite(signx, o, n1), + br.mk_ite(br.mk_eq(x, z), z, r)); + return r; + } + + expr_ref bv_terms::mk_smod(expr* x, expr* y) { + // u := umod(abs(x), abs(y)) + // u = 0 -> 0 + // y = 0 -> x + // x < 0, y < 0 -> -u + // x < 0, y >= 0 -> y - u + // x >= 0, y < 0 -> y + u + // x >= 0, y >= 0 -> u + bool_rewriter br(m); + bv_rewriter bvr(m); + unsigned sz = bv.get_bv_size(x); + expr_ref z(bv.mk_zero(sz), m); + expr_ref abs_x = br.mk_ite(bvr.mk_sle(z, x), x, bvr.mk_bv_neg(x)); + expr_ref abs_y = br.mk_ite(bvr.mk_sle(z, y), y, bvr.mk_bv_neg(y)); + expr_ref u = bvr.mk_bv_urem(abs_x, abs_y); + expr_ref r(m); + r = br.mk_ite(br.mk_eq(u, z), z, + br.mk_ite(br.mk_eq(y, z), x, + br.mk_ite(br.mk_and(bvr.mk_sle(z, x), bvr.mk_sle(z, x)), u, + br.mk_ite(bvr.mk_sle(z, x), bvr.mk_bv_add(y, u), + br.mk_ite(bv.mk_sle(z, y), bvr.mk_bv_sub(y, u), bvr.mk_bv_neg(u)))))); + return r; + } + + expr_ref bv_terms::mk_srem(expr* x, expr* y) { + // y = 0 -> x + // else x - sdiv(x, y) * y + expr_ref r(m); + bool_rewriter br(m); + bv_rewriter bvr(m); + expr_ref z(bv.mk_zero(bv.get_bv_size(x)), m); + r = br.mk_ite(br.mk_eq(y, z), x, bvr.mk_bv_sub(x, bvr.mk_bv_mul(y, mk_sdiv(x, y)))); + return r; + } + + void bv_terms::register_uninterp(expr* e) { + if (!m.is_bool(e)) + return; + expr* x, *y; + + if (m.is_eq(e, x, y) && bv.is_bv(x)) + ; + else if (is_app(e) && to_app(e)->get_family_id() == bv.get_fid()) + ; + else + return; + m_uninterp_occurs.reserve(e->get_id() + 1); + auto& occs = m_uninterp_occurs[e->get_id()]; + ptr_vector todo; + todo.append(to_app(e)->get_num_args(), to_app(e)->get_args()); + expr_mark marked; + for (unsigned i = 0; i < todo.size(); ++i) { + e = todo[i]; + if (marked.is_marked(e)) + continue; + marked.mark(e); + if (is_app(e) && to_app(e)->get_family_id() == bv.get_fid()) { + for (expr* arg : *to_app(e)) + todo.push_back(arg); + } + else if (bv.is_bv(e)) + occs.push_back(e); + } + } +} diff --git a/src/ast/sls/sls_bv_terms.h b/src/ast/sls/sls_bv_terms.h new file mode 100644 index 00000000000..effd74eebfb --- /dev/null +++ b/src/ast/sls/sls_bv_terms.h @@ -0,0 +1,54 @@ +/*++ +Copyright (c) 2024 Microsoft Corporation + +Module Name: + + bv_sls_terms.h + +Abstract: + + A Stochastic Local Search (SLS) engine + +Author: + + Nikolaj Bjorner (nbjorner) 2024-02-07 + +--*/ +#pragma once + +#include "util/lbool.h" +#include "util/scoped_ptr_vector.h" +#include "util/uint_set.h" +#include "ast/ast.h" +#include "ast/bv_decl_plugin.h" +#include "ast/sls/sls_stats.h" +#include "ast/sls/sls_powers.h" +#include "ast/sls/sls_bv_valuation.h" +#include "ast/sls/sls_context.h" + +namespace sls { + + class bv_terms { + ast_manager& m; + bv_util bv; + expr_ref_vector m_axioms; + vector> m_uninterp_occurs; + + expr_ref ensure_binary(expr* e); + + expr_ref mk_sdiv(expr* x, expr* y); + expr_ref mk_smod(expr* x, expr* y); + expr_ref mk_srem(expr* x, expr* y); + + void register_uninterp(expr* e); + + public: + bv_terms(sls::context& ctx); + + void register_term(expr* e); + + expr_ref_vector& axioms() { return m_axioms; } + + ptr_vector const& uninterp_occurs(expr* e) { m_uninterp_occurs.reserve(e->get_id() + 1); return m_uninterp_occurs[e->get_id()]; } + }; +} diff --git a/src/ast/sls/sls_tracker.h b/src/ast/sls/sls_bv_tracker.h similarity index 100% rename from src/ast/sls/sls_tracker.h rename to src/ast/sls/sls_bv_tracker.h diff --git a/src/ast/sls/sls_valuation.cpp b/src/ast/sls/sls_bv_valuation.cpp similarity index 70% rename from src/ast/sls/sls_valuation.cpp rename to src/ast/sls/sls_bv_valuation.cpp index b5f04a3a712..29fc6b51705 100644 --- a/src/ast/sls/sls_valuation.cpp +++ b/src/ast/sls/sls_bv_valuation.cpp @@ -18,9 +18,9 @@ Module Name: --*/ -#include "ast/sls/sls_valuation.h" +#include "ast/sls/sls_bv_valuation.h" -namespace bv { +namespace sls { void bvect::set_bw(unsigned bw) { this->bw = bw; @@ -138,6 +138,7 @@ namespace bv { set_bw(a.bw); SASSERT(a.bw == b.bw); unsigned shift = b.to_nat(b.bw); + if (shift == 0) a.copy_to(a.nw, *this); else if (shift >= a.bw) @@ -148,7 +149,7 @@ namespace bv { return *this; } - sls_valuation::sls_valuation(unsigned bw) { + bv_valuation::bv_valuation(unsigned bw) { set_bw(bw); m_lo.set_bw(bw); m_hi.set_bw(bw); @@ -162,7 +163,7 @@ namespace bv { fixed[nw - 1] = ~mask; } - void sls_valuation::set_bw(unsigned b) { + void bv_valuation::set_bw(unsigned b) { bw = b; nw = (bw + sizeof(digit_t) * 8 - 1) / (8 * sizeof(digit_t)); mask = (1 << (bw % (8 * sizeof(digit_t)))) - 1; @@ -170,7 +171,7 @@ namespace bv { mask = ~(digit_t)0; } - bool sls_valuation::commit_eval() { + bool bv_valuation::commit_eval() { for (unsigned i = 0; i < nw; ++i) if (0 != (fixed[i] & (m_bits[i] ^ eval[i]))) return false; @@ -180,11 +181,12 @@ namespace bv { for (unsigned i = 0; i < nw; ++i) m_bits[i] = eval[i]; + SASSERT(well_formed()); return true; } - bool sls_valuation::in_range(bvect const& bits) const { + bool bv_valuation::in_range(bvect const& bits) const { mpn_manager m; auto c = m.compare(m_lo.data(), nw, m_hi.data(), nw); SASSERT(!has_overflow(bits)); @@ -207,7 +209,7 @@ namespace bv { // largest dst <= src and dst is feasible // - bool sls_valuation::get_at_most(bvect const& src, bvect& dst) const { + bool bv_valuation::get_at_most(bvect const& src, bvect& dst) const { SASSERT(!has_overflow(src)); src.copy_to(nw, dst); sup_feasible(dst); @@ -227,7 +229,7 @@ namespace bv { // // smallest dst >= src and dst is feasible with respect to this. - bool sls_valuation::get_at_least(bvect const& src, bvect& dst) const { + bool bv_valuation::get_at_least(bvect const& src, bvect& dst) const { SASSERT(!has_overflow(src)); src.copy_to(nw, dst); dst.set_bw(bw); @@ -244,34 +246,38 @@ namespace bv { return true; } - bool sls_valuation::set_random_at_most(bvect const& src, random_gen& r) { + bool bv_valuation::set_random_at_most(bvect const& src, random_gen& r) { m_tmp.set_bw(bw); + //verbose_stream() << "set_random_at_most " << src << "\n"; if (!get_at_most(src, m_tmp)) return false; - if (is_zero(m_tmp) || (0 != r(10))) - return try_set(m_tmp); + if (is_zero(m_tmp) && (0 != r(2))) + return try_set(m_tmp) && m_tmp <= src; // random value below tmp set_random_below(m_tmp, r); + + //verbose_stream() << "can set " << m_tmp << " " << can_set(m_tmp) << "\n"; - return (can_set(m_tmp) || get_at_most(src, m_tmp)) && try_set(m_tmp); + return (can_set(m_tmp) || get_at_most(src, m_tmp)) && m_tmp <= src && try_set(m_tmp); } - bool sls_valuation::set_random_at_least(bvect const& src, random_gen& r) { + bool bv_valuation::set_random_at_least(bvect const& src, random_gen& r) { + m_tmp.set_bw(bw); if (!get_at_least(src, m_tmp)) return false; - if (is_ones(m_tmp) || (0 != r(10))) + if (is_ones(m_tmp) && (0 != r(10))) return try_set(m_tmp); // random value at least tmp set_random_above(m_tmp, r); - return (can_set(m_tmp) || get_at_least(src, m_tmp)) && try_set(m_tmp); + return (can_set(m_tmp) || get_at_least(src, m_tmp)) && src <= m_tmp && try_set(m_tmp); } - bool sls_valuation::set_random_in_range(bvect const& lo, bvect const& hi, random_gen& r) { + bool bv_valuation::set_random_in_range(bvect const& lo, bvect const& hi, random_gen& r) { bvect& tmp = m_tmp; if (0 == r(2)) { if (!get_at_least(lo, tmp)) @@ -279,14 +285,10 @@ namespace bv { SASSERT(in_range(tmp)); if (hi < tmp) return false; - - if (is_ones(tmp) || (0 == r() % 2)) - return try_set(tmp); set_random_above(tmp, r); round_down(tmp, [&](bvect const& t) { return hi >= t && in_range(t); }); - if (in_range(tmp) && lo <= tmp && hi >= tmp) - return try_set(tmp); - return get_at_least(lo, tmp) && hi >= tmp && try_set(tmp); + if (in_range(tmp) || get_at_least(lo, tmp)) + return lo <= tmp && tmp <= hi && try_set(tmp); } else { if (!get_at_most(hi, tmp)) @@ -294,37 +296,35 @@ namespace bv { SASSERT(in_range(tmp)); if (lo > tmp) return false; - if (is_zero(tmp) || (0 == r() % 2)) - return try_set(tmp); set_random_below(tmp, r); round_up(tmp, [&](bvect const& t) { return lo <= t && in_range(t); }); - if (in_range(tmp) && lo <= tmp && hi >= tmp) - return try_set(tmp); - return get_at_most(hi, tmp) && lo <= tmp && try_set(tmp); + if (in_range(tmp) || get_at_most(hi, tmp)) + return lo <= tmp && tmp <= hi && try_set(tmp); } + return false; } - void sls_valuation::round_down(bvect& dst, std::function const& is_feasible) { + void bv_valuation::round_down(bvect& dst, std::function const& is_feasible) { for (unsigned i = bw; !is_feasible(dst) && i-- > 0; ) if (!fixed.get(i) && dst.get(i)) dst.set(i, false); repair_sign_bits(dst); } - void sls_valuation::round_up(bvect& dst, std::function const& is_feasible) { + void bv_valuation::round_up(bvect& dst, std::function const& is_feasible) { for (unsigned i = 0; !is_feasible(dst) && i < bw; ++i) if (!fixed.get(i) && !dst.get(i)) dst.set(i, true); repair_sign_bits(dst); } - void sls_valuation::set_random_above(bvect& dst, random_gen& r) { + void bv_valuation::set_random_above(bvect& dst, random_gen& r) { for (unsigned i = 0; i < nw; ++i) dst[i] = dst[i] | (random_bits(r) & ~fixed[i]); repair_sign_bits(dst); } - void sls_valuation::set_random_below(bvect& dst, random_gen& r) { + void bv_valuation::set_random_below(bvect& dst, random_gen& r) { if (is_zero(dst)) return; unsigned n = 0, idx = UINT_MAX; @@ -341,7 +341,7 @@ namespace bv { repair_sign_bits(dst); } - bool sls_valuation::set_repair(bool try_down, bvect& dst) { + bool bv_valuation::set_repair(bool try_down, bvect& dst) { for (unsigned i = 0; i < nw; ++i) dst[i] = (~fixed[i] & dst[i]) | (fixed[i] & m_bits[i]); clear_overflow_bits(dst); @@ -358,7 +358,7 @@ namespace bv { dst.set(i, false); for (unsigned i = 0; i < bw && dst < m_lo && !in_range(dst); ++i) if (!fixed.get(i) && !dst.get(i)) - dst.set(i, true); + dst.set(i, true); } else { for (unsigned i = 0; !in_range(dst) && i < bw; ++i) @@ -377,7 +377,7 @@ namespace bv { return repaired; } - void sls_valuation::min_feasible(bvect& out) const { + void bv_valuation::min_feasible(bvect& out) const { if (m_lo < m_hi) m_lo.copy_to(nw, out); else { @@ -388,7 +388,7 @@ namespace bv { SASSERT(!has_overflow(out)); } - void sls_valuation::max_feasible(bvect& out) const { + void bv_valuation::max_feasible(bvect& out) const { if (m_lo < m_hi) { m_hi.copy_to(nw, out); sub1(out); @@ -401,7 +401,7 @@ namespace bv { SASSERT(!has_overflow(out)); } - unsigned sls_valuation::msb(bvect const& src) const { + unsigned bv_valuation::msb(bvect const& src) const { SASSERT(!has_overflow(src)); for (unsigned i = nw; i-- > 0; ) if (src[i] != 0) @@ -409,7 +409,7 @@ namespace bv { return bw; } - unsigned sls_valuation::clz(bvect const& src) const { + unsigned bv_valuation::clz(bvect const& src) const { SASSERT(!has_overflow(src)); unsigned i = bw; for (; i-- > 0; ) @@ -419,36 +419,64 @@ namespace bv { } - void sls_valuation::set_value(bvect& bits, rational const& n) { + void bv_valuation::set_value(bvect& bits, rational const& n) { for (unsigned i = 0; i < bw; ++i) bits.set(i, n.get_bit(i)); clear_overflow_bits(bits); } - void sls_valuation::get(bvect& dst) const { + void bv_valuation::get(bvect& dst) const { m_bits.copy_to(nw, dst); } - digit_t sls_valuation::random_bits(random_gen& rand) { + digit_t bv_valuation::random_bits(random_gen& rand) { digit_t r = 0; for (digit_t i = 0; i < sizeof(digit_t); ++i) r ^= rand() << (8 * i); return r; } - void sls_valuation::get_variant(bvect& dst, random_gen& r) const { + void bv_valuation::get_variant(bvect& dst, random_gen& r) const { for (unsigned i = 0; i < nw; ++i) dst[i] = (random_bits(r) & ~fixed[i]) | (fixed[i] & m_bits[i]); repair_sign_bits(dst); clear_overflow_bits(dst); } - bool sls_valuation::set_random(random_gen& r) { + bool bv_valuation::set_random(random_gen& r) { get_variant(m_tmp, r); - return set_repair(r(2) == 0, m_tmp); + repair_sign_bits(m_tmp); + if (in_range(m_tmp)) { + set(eval, m_tmp); + return true; + } + for (unsigned i = 0; i < nw; ++i) + m_tmp[i] = random_bits(r); + clear_overflow_bits(m_tmp); + // find a random offset within [lo, hi[ + SASSERT(m_lo != m_hi); + set_sub(eval, m_hi, m_lo); + for (unsigned i = bw; i-- > 0 && m_tmp >= eval; ) + m_tmp.set(i, false); + + // set eval back to m_bits. It was garbage. + set(eval, m_bits); + + // tmp := lo + tmp is within [lo, hi[ + set_add(m_tmp, m_tmp, m_lo); + // respect fixed bits + for (unsigned i = 0; i < bw; ++i) + if (fixed.get(i)) + m_tmp.set(i, m_bits.get(i)); + // decrease tmp until it is in range again + for (unsigned i = bw; i-- > 0 && !in_range(m_tmp); ) + if (!fixed.get(i)) + m_tmp.set(i, false); + repair_sign_bits(m_tmp); + return try_set(m_tmp); } - void sls_valuation::repair_sign_bits(bvect& dst) const { + void bv_valuation::repair_sign_bits(bvect& dst) const { if (m_signed_prefix == 0) return; bool sign = m_signed_prefix == bw ? dst.get(bw - 1) : dst.get(bw - m_signed_prefix - 1); @@ -474,7 +502,7 @@ namespace bv { // 0 = (new_bits ^ bits) & fixedf // also check that new_bits are in range // - bool sls_valuation::can_set(bvect const& new_bits) const { + bool bv_valuation::can_set(bvect const& new_bits) const { SASSERT(!has_overflow(new_bits)); for (unsigned i = 0; i < nw; ++i) if (0 != ((new_bits[i] ^ m_bits[i]) & fixed[i])) @@ -482,28 +510,28 @@ namespace bv { return in_range(new_bits); } - unsigned sls_valuation::to_nat(unsigned max_n) const { + unsigned bv_valuation::to_nat(unsigned max_n) const { bvect const& d = m_bits; SASSERT(!has_overflow(d)); return d.to_nat(max_n); } - void sls_valuation::shift_right(bvect& out, unsigned shift) const { + void bv_valuation::shift_right(bvect& out, unsigned shift) const { SASSERT(shift < bw); for (unsigned i = 0; i < bw; ++i) - out.set(i, i + shift < bw ? m_bits.get(i + shift) : false); + out.set(i, i + shift < bw ? out.get(i + shift) : false); SASSERT(well_formed()); } - void sls_valuation::add_range(rational l, rational h) { + void bv_valuation::add_range(rational l, rational h) { l = mod(l, rational::power_of_two(bw)); h = mod(h, rational::power_of_two(bw)); if (h == l) return; -// verbose_stream() << *this << " " << l << " " << h << " --> "; + //verbose_stream() << *this << " lo " << l << " hi " << h << " --> "; if (m_lo == m_hi) { set_value(m_lo, l); @@ -555,7 +583,7 @@ namespace bv { // update bits based on ranges // - unsigned sls_valuation::diff_index(bvect const& a) const { + unsigned bv_valuation::diff_index(bvect const& a) const { unsigned index = 0; for (unsigned i = nw; i-- > 0; ) { auto diff = fixed[i] & (m_bits[i] ^ a[i]); @@ -565,55 +593,87 @@ namespace bv { return index; } - void sls_valuation::inf_feasible(bvect& a) const { + // The least a' >= a, such that the fixed bits in bits agree with a'. + // 0 if there is no such a'. + void bv_valuation::inf_feasible(bvect& a) const { unsigned lo_index = diff_index(a); - if (lo_index != 0) { - lo_index--; - SASSERT(a.get(lo_index) != m_bits.get(lo_index)); - SASSERT(fixed.get(lo_index)); - for (unsigned i = 0; i <= lo_index; ++i) { - if (!fixed.get(i)) - a.set(i, false); - else if (fixed.get(i)) - a.set(i, m_bits.get(i)); - } - if (!a.get(lo_index)) { - for (unsigned i = lo_index + 1; i < bw; ++i) - if (!fixed.get(i) && !a.get(i)) { - a.set(i, true); - break; - } + if (lo_index == 0) + return; + --lo_index; + + // decrement a[lo_index:0] maximally + SASSERT(a.get(lo_index) != m_bits.get(lo_index)); + SASSERT(fixed.get(lo_index)); + for (unsigned i = 0; i <= lo_index; ++i) { + if (!fixed.get(i)) + a.set(i, false); + else if (fixed.get(i)) + a.set(i, m_bits.get(i)); + } + + // the previous value of a[lo_index] was 0. + // a[lo_index:0] was incremented, so no need to adjust bits a[:lo_index+1] + if (a.get(lo_index)) + return; + + // find the minimal increment within a[:lo_index+1] + for (unsigned i = lo_index + 1; i < bw; ++i) { + if (!fixed.get(i) && !a.get(i)) { + a.set(i, true); + return; } } + // there is no feasiable value a' >= a, so find the least + // feasiable value a' >= 0. + for (unsigned i = 0; i < bw; ++i) + if (!fixed.get(i)) + a.set(i, false); } - void sls_valuation::sup_feasible(bvect& a) const { + // The greatest a' <= a, such that the fixed bits in bits agree with a'. + // the greatest a' <= -1 if there is no such a'. + + void bv_valuation::sup_feasible(bvect& a) const { unsigned hi_index = diff_index(a); - if (hi_index != 0) { - hi_index--; - SASSERT(a.get(hi_index) != m_bits.get(hi_index)); - SASSERT(fixed.get(hi_index)); - for (unsigned i = 0; i <= hi_index; ++i) { - if (!fixed.get(i)) - a.set(i, true); - else if (fixed.get(i)) - a.set(i, m_bits.get(i)); - } - if (a.get(hi_index)) { - for (unsigned i = hi_index + 1; i < bw; ++i) - if (!fixed.get(i) && a.get(i)) { - a.set(i, false); - break; - } + if (hi_index == 0) + return; + --hi_index; + SASSERT(a.get(hi_index) != m_bits.get(hi_index)); + SASSERT(fixed.get(hi_index)); + + // increment a[hi_index:0] maximally + for (unsigned i = 0; i <= hi_index; ++i) { + if (!fixed.get(i)) + a.set(i, true); + else if (fixed.get(i)) + a.set(i, m_bits.get(i)); + } + + // If a[hi_index:0] was decremented, then no need to adjust bits a[:hi_index+1] + if (!a.get(hi_index)) + return; + + // find the minimal decrement within a[:hi_index+1] + for (unsigned i = hi_index + 1; i < bw; ++i) { + if (!fixed.get(i) && a.get(i)) { + a.set(i, false); + return; } } + + // a[hi_index:0] was incremented, but a[:hi_index+1] cannot be decremented. + // maximize a[:hi_index+1] to model wrap around behavior. + for (unsigned i = hi_index + 1; i < bw; ++i) + if (!fixed.get(i)) + a.set(i, true); } - void sls_valuation::tighten_range() { + void bv_valuation::tighten_range() { + // verbose_stream() << "tighten " << m_lo << " " << m_hi << " " << m_bits << "\n"; if (m_lo == m_hi) - return; + return; inf_feasible(m_lo); @@ -625,59 +685,8 @@ namespace bv { add1(hi1); hi1.copy_to(nw, m_hi); - /* - unsigned lo_index = 0, hi_index = 0; - for (unsigned i = nw; i-- > 0; ) { - auto lo_diff = (fixed[i] & (m_bits[i] ^ m_lo[i])); - if (lo_diff != 0 && lo_index == 0) - lo_index = 1 + i * 8 * sizeof(digit_t) + log2(lo_diff); - auto hi_diff = (fixed[i] & (m_bits[i] ^ hi1[i])); - if (hi_diff != 0 && hi_index == 0) - hi_index = 1 + i * 8 * sizeof(digit_t) + log2(hi_diff); - } - - if (lo_index != 0) { - lo_index--; - SASSERT(m_lo.get(lo_index) != m_bits.get(lo_index)); - SASSERT(fixed.get(lo_index)); - for (unsigned i = 0; i <= lo_index; ++i) { - if (!fixed.get(i)) - m_lo.set(i, false); - else if (fixed.get(i)) - m_lo.set(i, m_bits.get(i)); - } - if (!m_bits.get(lo_index)) { - for (unsigned i = lo_index + 1; i < bw; ++i) - if (!fixed.get(i) && !m_lo.get(i)) { - m_lo.set(i, true); - break; - } - } - } - if (hi_index != 0) { - hi_index--; - SASSERT(hi1.get(hi_index) != m_bits.get(hi_index)); - SASSERT(fixed.get(hi_index)); - for (unsigned i = 0; i <= hi_index; ++i) { - if (!fixed.get(i)) - hi1.set(i, true); - else if (fixed.get(i)) - hi1.set(i, m_bits.get(i)); - } - if (m_bits.get(hi_index)) { - for (unsigned i = hi_index + 1; i < bw; ++i) - if (!fixed.get(i) && hi1.get(i)) { - hi1.set(i, false); - break; - } - } - add1(hi1); - hi1.copy_to(nw, m_hi); - } - */ - if (has_range() && !in_range(m_bits)) - m_bits = m_lo; + m_lo.copy_to(nw, m_bits); if (mod(lo() + 1, rational::power_of_two(bw)) == hi()) for (unsigned i = 0; i < nw; ++i) @@ -687,16 +696,17 @@ namespace bv { if (hi() < rational::power_of_two(i)) fixed.set(i, true); + // verbose_stream() << "post tighten " << m_lo << " " << m_hi << " " << m_bits << "\n"; SASSERT(well_formed()); } - void sls_valuation::set_sub(bvect& out, bvect const& a, bvect const& b) const { + void bv_valuation::set_sub(bvect& out, bvect const& a, bvect const& b) const { digit_t c; mpn_manager().sub(a.data(), nw, b.data(), nw, out.data(), &c); clear_overflow_bits(out); } - bool sls_valuation::set_add(bvect& out, bvect const& a, bvect const& b) const { + bool bv_valuation::set_add(bvect& out, bvect const& a, bvect const& b) const { digit_t c; mpn_manager().add(a.data(), nw, b.data(), nw, out.data(), nw + 1, &c); bool ovfl = out[nw] != 0 || has_overflow(out); @@ -704,7 +714,9 @@ namespace bv { return ovfl; } - bool sls_valuation::set_mul(bvect& out, bvect const& a, bvect const& b, bool check_overflow) const { + bool bv_valuation::set_mul(bvect& out, bvect const& a, bvect const& b, bool check_overflow) const { + out.reserve(2 * nw); + SASSERT(out.size() >= 2 * nw); mpn_manager().mul(a.data(), nw, b.data(), nw, out.data()); bool ovfl = false; if (check_overflow) { @@ -716,7 +728,7 @@ namespace bv { return ovfl; } - bool sls_valuation::is_power_of2(bvect const& src) const { + bool bv_valuation::is_power_of2(bvect const& src) const { unsigned c = 0; for (unsigned i = 0; i < nw; ++i) c += get_num_1bits(src[i]); diff --git a/src/ast/sls/sls_valuation.h b/src/ast/sls/sls_bv_valuation.h similarity index 94% rename from src/ast/sls/sls_valuation.h rename to src/ast/sls/sls_bv_valuation.h index f10fd8f0bfc..9156c934fef 100644 --- a/src/ast/sls/sls_valuation.h +++ b/src/ast/sls/sls_bv_valuation.h @@ -3,7 +3,7 @@ Copyright (c) 2024 Microsoft Corporation Module Name: - sls_valuation.h + sls_bv_valuation.h Abstract: @@ -20,12 +20,10 @@ Module Name: #include "util/params.h" #include "util/scoped_ptr_vector.h" #include "util/uint_set.h" -#include "ast/ast.h" -#include "ast/sls/sls_stats.h" -#include "ast/sls/sls_powers.h" -#include "ast/bv_decl_plugin.h" +#include "util/mpz.h" +#include "util/rational.h" -namespace bv { +namespace sls { class bvect : public svector { public: @@ -106,7 +104,7 @@ namespace bv { inline bool operator!=(bvect const& a, bvect const& b) { return !(a == b); } std::ostream& operator<<(std::ostream& out, bvect const& v); - class sls_valuation { + class bv_valuation { protected: bvect m_bits; bvect m_lo, m_hi; // range assignment to bit-vector, as wrap-around interval @@ -124,8 +122,8 @@ namespace bv { bvect fixed; // bit assignment and don't care bit bvect eval; // current evaluation - - sls_valuation(unsigned bw); + + bv_valuation(unsigned bw); void set_bw(unsigned bw); void set_signed(unsigned prefix) { m_signed_prefix = prefix; } @@ -134,7 +132,9 @@ namespace bv { digit_t bits(unsigned i) const { return m_bits[i]; } bvect const& bits() const { return m_bits; } + bvect const& tmp_bits(bool use_current) const { return use_current ? m_bits : m_tmp; } bool commit_eval(); + bool is_fixed() const { for (unsigned i = bw; i-- > 0; ) if (!fixed.get(i)) return false; return true; } bool get_bit(unsigned i) const { return m_bits.get(i); } bool try_set_bit(unsigned i, bool b) { @@ -166,6 +166,9 @@ namespace bv { bool has_range() const { return m_lo != m_hi; } void tighten_range(); + void save_value() { m_bits.copy_to(nw, m_tmp); } + void restore_value() { m_tmp.copy_to(nw, m_bits); } + void clear_overflow_bits(bvect& bits) const { SASSERT(nw > 0); bits[nw - 1] &= mask; @@ -175,7 +178,7 @@ namespace bv { bool in_range(bvect const& bits) const; bool can_set(bvect const& bits) const; - bool eq(sls_valuation const& other) const { return eq(other.m_bits); } + bool eq(bv_valuation const& other) const { return eq(other.m_bits); } bool eq(bvect const& other) const { return other == m_bits; } bool is_zero() const { return is_zero(m_bits); } @@ -342,6 +345,6 @@ namespace bv { }; - inline std::ostream& operator<<(std::ostream& out, sls_valuation const& v) { return v.display(out); } + inline std::ostream& operator<<(std::ostream& out, bv_valuation const& v) { return v.display(out); } } diff --git a/src/ast/sls/sls_context.cpp b/src/ast/sls/sls_context.cpp new file mode 100644 index 00000000000..051ae98a7c9 --- /dev/null +++ b/src/ast/sls/sls_context.cpp @@ -0,0 +1,654 @@ +/*++ +Copyright (c) 2024 Microsoft Corporation + +Module Name: + + smt_sls.cpp + +Abstract: + + A Stochastic Local Search (SLS) Context. + +Author: + + Nikolaj Bjorner (nbjorner) 2024-06-24 + +--*/ + +#include "ast/sls/sls_context.h" +#include "ast/sls/sls_euf_plugin.h" +#include "ast/sls/sls_arith_plugin.h" +#include "ast/sls/sls_array_plugin.h" +#include "ast/sls/sls_bv_plugin.h" +#include "ast/sls/sls_basic_plugin.h" +#include "ast/sls/sls_datatype_plugin.h" +#include "ast/ast_ll_pp.h" +#include "ast/ast_pp.h" +#include "smt/params/smt_params_helper.hpp" + +namespace sls { + + plugin::plugin(context& c): + ctx(c), + m(c.get_manager()) { + } + + context::context(ast_manager& m, sat_solver_context& s) : + m(m), s(s), m_atoms(m), m_allterms(m), + m_gd(*this), + m_ld(*this), + m_repair_down(m.get_num_asts(), m_gd), + m_repair_up(m.get_num_asts(), m_ld), + m_constraint_trail(m), + m_todo(m) { + } + + void context::updt_params(params_ref const& p) { + smt_params_helper smtp(p); + m_rand.set_seed(smtp.random_seed()); + m_params.append(p); + } + + void context::register_plugin(plugin* p) { + m_plugins.reserve(p->fid() + 1); + m_plugins.set(p->fid(), p); + } + + void context::ensure_plugin(family_id fid) { + if (m_plugins.get(fid, nullptr)) + return; + else if (fid == arith_family_id) + register_plugin(alloc(arith_plugin, *this)); + else if (fid == user_sort_family_id) + register_plugin(alloc(euf_plugin, *this)); + else if (fid == basic_family_id) + register_plugin(alloc(basic_plugin, *this)); + else if (fid == bv_util(m).get_family_id()) + register_plugin(alloc(bv_plugin, *this)); + else if (fid == array_util(m).get_family_id()) + register_plugin(alloc(array_plugin, *this)); + else if (fid == datatype_util(m).get_family_id()) + register_plugin(alloc(datatype_plugin, *this)); + else if (fid == null_family_id) + ; + else + verbose_stream() << "did not find plugin for " << fid << "\n"; + } + + scoped_ptr& context::egraph() { + return euf().egraph(); + } + + euf_plugin& context::euf() { + auto fid = user_sort_family_id; + auto p = m_plugins.get(fid, nullptr); + if (!p) { + p = alloc(euf_plugin, *this); + register_plugin(p); + } + return *dynamic_cast(p); + } + + void context::ensure_plugin(expr* e) { + auto fid = get_fid(e); + ensure_plugin(fid); + fid = e->get_sort()->get_family_id(); + ensure_plugin(fid); + } + + + void context::register_atom(sat::bool_var v, expr* e) { + m_atoms.setx(v, e); + m_atom2bool_var.setx(e->get_id(), v, sat::null_bool_var); + } + + void context::on_restart() { + for (auto p : m_plugins) + if (p) + p->on_restart(); + } + + lbool context::check() { + // + // initialize data-structures if not done before. + // identify minimal feasible assignment to literals. + // sub-expressions within assignment are relevant. + // Use timestamps to make it incremental. + // + init(); + while (unsat().empty() && m.inc()) { + + propagate_boolean_assignment(); + + + // verbose_stream() << "propagate " << unsat().size() << " " << m_new_constraint << "\n"; + + if (m_new_constraint || !unsat().empty()) + return l_undef; + + if (all_of(m_plugins, [&](auto* p) { return !p || p->is_sat(); })) { + values2model(); + return l_true; + } + } + return l_undef; + } + + void context::values2model() { + model_ref mdl = alloc(model, m); + expr_ref_vector args(m); + for (expr* e : subterms()) + if (is_uninterp_const(e)) + mdl->register_decl(to_app(e)->get_decl(), get_value(e)); + + for (expr* e : subterms()) { + if (!is_app(e)) + continue; + auto f = to_app(e)->get_decl(); + if (!include_func_interp(f)) + continue; + auto v = get_value(e); + auto fi = mdl->get_func_interp(f); + if (!fi) { + fi = alloc(func_interp, m, f->get_arity()); + mdl->register_decl(f, fi); + } + args.reset(); + for (expr* arg : *to_app(e)) { + args.push_back(get_value(arg)); + SASSERT(args.back()); + } + SASSERT(f->get_arity() == args.size()); + if (!fi->get_entry(args.data())) + fi->insert_new_entry(args.data(), v); + } + + s.on_model(mdl); + // verbose_stream() << *mdl << "\n"; + TRACE("sls", display(tout)); + } + + void context::propagate_boolean_assignment() { + reinit_relevant(); + + for (auto p : m_plugins) + if (p) + p->start_propagation(); + + for (sat::literal lit : root_literals()) + propagate_literal(lit); + + if (m_new_constraint) + return; + + while (!m_new_constraint && m.inc() && (!m_repair_up.empty() || !m_repair_down.empty())) { + while (!m_repair_down.empty() && !m_new_constraint && m.inc()) { + auto id = m_repair_down.erase_min(); + expr* e = term(id); + TRACE("sls", tout << "repair down " << mk_bounded_pp(e, m) << "\n"); + if (is_app(e)) { + auto p = m_plugins.get(get_fid(e), nullptr); + ++m_stats.m_num_repair_down; + if (p && !p->repair_down(to_app(e)) && !m_repair_up.contains(e->get_id())) { + IF_VERBOSE(3, verbose_stream() << "revert repair: " << mk_bounded_pp(e, m) << "\n"); + m_repair_up.insert(e->get_id()); + } + } + } + while (!m_repair_up.empty() && !m_new_constraint && m.inc()) { + auto id = m_repair_up.erase_min(); + expr* e = term(id); + ++m_stats.m_num_repair_up; + TRACE("sls", tout << "repair up " << mk_bounded_pp(e, m) << "\n"); + if (is_app(e)) { + auto p = m_plugins.get(get_fid(e), nullptr); + if (p) + p->repair_up(to_app(e)); + } + } + } + + repair_literals(); + + // propagate "final checks" + bool propagated = true; + while (propagated && !m_new_constraint) { + propagated = false; + for (auto p : m_plugins) + propagated |= p && !m_new_constraint && p->propagate(); + } + + } + + void context::repair_literals() { + for (sat::bool_var v = 0; v < s.num_vars() && !m_new_constraint; ++v) { + auto a = atom(v); + if (!a) + continue; + sat::literal lit(v, !is_true(v)); + auto p = m_plugins.get(get_fid(a), nullptr); + if (p) + p->repair_literal(lit); + } + } + + family_id context::get_fid(expr* e) const { + if (!is_app(e)) + return user_sort_family_id; + family_id fid = to_app(e)->get_family_id(); + if (m.is_eq(e)) + fid = to_app(e)->get_arg(0)->get_sort()->get_family_id(); + if (m.is_distinct(e)) + fid = to_app(e)->get_arg(0)->get_sort()->get_family_id(); + if ((fid == null_family_id && to_app(e)->get_num_args() > 0) || fid == model_value_family_id) + fid = user_sort_family_id; + return fid; + } + + void context::propagate_literal(sat::literal lit) { + if (!is_true(lit)) + return; + auto a = atom(lit.var()); + if (!a) + return; + family_id fid = get_fid(a); + auto p = m_plugins.get(fid, nullptr); + if (p) + p->propagate_literal(lit); + if (!is_true(lit)) { + m_new_constraint = true; + } + } + + bool context::is_true(expr* e) { + SASSERT(m.is_bool(e)); + auto v = m_atom2bool_var.get(e->get_id(), sat::null_bool_var); + if (v != sat::null_bool_var) + return m.is_true(m_plugins[basic_family_id]->get_value(e)); + else + return is_true(v); + } + + bool context::is_fixed(expr* e) { + // is this a Boolean literal that is a unit? + return false; + } + + expr_ref context::get_value(expr* e) { + sort* s = e->get_sort(); + auto fid = s->get_family_id(); + auto p = m_plugins.get(fid, nullptr); + if (p) + return p->get_value(e); + verbose_stream() << fid << " " << m.get_family_name(fid) << " " << mk_pp(e, m) << "\n"; + UNREACHABLE(); + return expr_ref(e, m); + } + + bool context::set_value(expr * e, expr * v) { + return any_of(m_plugins, [&](auto p) { return p && p->set_value(e, v); }); + } + + bool context::is_relevant(expr* e) { + unsigned id = e->get_id(); + if (m_relevant.contains(id)) + return true; + if (m_visited.contains(id)) + return false; + m_visited.insert(id); + if (m_parents.size() <= id) + verbose_stream() << "not in map " << mk_bounded_pp(e, m) << "\n"; + for (auto p : m_parents[id]) { + if (is_relevant(p)) { + m_relevant.insert(id); + return true; + } + } + return false; + } + + void context::add_constraint(expr* e) { + if (m_constraint_ids.contains(e->get_id())) + return; + m_constraint_ids.insert(e->get_id()); + m_constraint_trail.push_back(e); + add_clause(e); + m_new_constraint = true; + ++m_stats.m_num_constraints; + } + + void context::add_clause(expr* f) { + expr_ref _e(f, m); + expr* g, * h, * k; + sat::literal_vector clause; + if (m.is_true(f)) + return; + if (m.is_not(f, g) && m.is_not(g, g)) { + add_clause(g); + return; + } + bool sign = m.is_not(f, f); + if (!sign && m.is_or(f)) { + clause.reset(); + for (auto arg : *to_app(f)) + clause.push_back(mk_literal(arg)); + s.add_clause(clause.size(), clause.data()); + } + else if (!sign && m.is_and(f)) { + for (auto arg : *to_app(f)) + add_clause(arg); + } + else if (sign && m.is_or(f)) { + for (auto arg : *to_app(f)) { + expr_ref fml(m.mk_not(arg), m); + add_clause(fml); + } + } + else if (!sign && m.is_implies(f, g, h)) { + clause.reset(); + clause.push_back(~mk_literal(g)); + clause.push_back(mk_literal(h)); + s.add_clause(clause.size(), clause.data()); + } + else if (sign && m.is_implies(f, g, h)) { + expr_ref fml(m.mk_not(h), m); + add_clause(fml); + add_clause(g); + } + else if (sign && m.is_and(f)) { + clause.reset(); + for (auto arg : *to_app(f)) + clause.push_back(~mk_literal(arg)); + s.add_clause(clause.size(), clause.data()); + } + else if (m.is_iff(f, g, h)) { + auto lit1 = mk_literal(g); + auto lit2 = mk_literal(h); + sat::literal cls1[2] = { sign ? lit1 : ~lit1, lit2 }; + sat::literal cls2[2] = { sign ? ~lit1 : lit1, ~lit2 }; + s.add_clause(2, cls1); + s.add_clause(2, cls2); + } + else if (m.is_ite(f, g, h, k)) { + auto lit1 = mk_literal(g); + auto lit2 = mk_literal(h); + auto lit3 = mk_literal(k); + // (g -> h) & (~g -> k) + // (g & h) | (~g & k) + // negated: (g -> ~h) & (g -> ~k) + sat::literal cls1[2] = { ~lit1, sign ? ~lit2 : lit2 }; + sat::literal cls2[2] = { lit1, sign ? ~lit3 : lit3 }; + s.add_clause(2, cls1); + s.add_clause(2, cls2); + } + else { + sat::literal lit = mk_literal(f); + if (sign) + lit.neg(); + s.add_clause(1, &lit); + } + } + + void context::add_clause(sat::literal_vector const& lits) { + s.add_clause(lits.size(), lits.data()); + m_new_constraint = true; + ++m_stats.m_num_constraints; + } + + sat::literal context::mk_literal() { + sat::bool_var v = s.add_var(); + return sat::literal(v, false); + } + + sat::literal context::mk_literal(expr* e) { + expr_ref _e(e, m); + sat::literal lit; + bool neg = false; + expr* a, * b, * c; + while (m.is_not(e, e)) + neg = !neg; + auto v = m_atom2bool_var.get(e->get_id(), sat::null_bool_var); + if (v != sat::null_bool_var) + return sat::literal(v, neg); + sat::literal_vector clause; + lit = mk_literal(); + register_atom(lit.var(), e); + if (m.is_true(e)) { + clause.push_back(lit); + s.add_clause(clause.size(), clause.data()); + } + else if (m.is_false(e)) { + clause.push_back(~lit); + s.add_clause(clause.size(), clause.data()); + } + else if (m.is_and(e)) { + for (expr* arg : *to_app(e)) { + auto lit2 = mk_literal(arg); + clause.push_back(~lit2); + sat::literal lits[2] = { ~lit, lit2 }; + s.add_clause(2, lits); + } + clause.push_back(lit); + s.add_clause(clause.size(), clause.data()); + } + else if (m.is_or(e)) { + for (expr* arg : *to_app(e)) { + auto lit2 = mk_literal(arg); + clause.push_back(lit2); + sat::literal lits[2] = { lit, ~lit2 }; + s.add_clause(2, lits); + } + clause.push_back(~lit); + s.add_clause(clause.size(), clause.data()); + } + else if (m.is_iff(e, a, b) || m.is_xor(e, a, b)) { + auto lit1 = mk_literal(a); + auto lit2 = mk_literal(b); + if (m.is_xor(e)) + lit2.neg(); + sat::literal cls1[3] = { ~lit, ~lit1, lit2 }; + sat::literal cls2[3] = { ~lit, lit1, ~lit2 }; + sat::literal cls3[3] = { lit, lit1, lit2 }; + sat::literal cls4[3] = { lit, ~lit1, ~lit2 }; + s.add_clause(3, cls1); + s.add_clause(3, cls2); + s.add_clause(3, cls3); + s.add_clause(3, cls4); + } + else if (m.is_ite(e, a, b, c)) { + auto lit1 = mk_literal(a); + auto lit2 = mk_literal(b); + auto lit3 = mk_literal(c); + sat::literal cls1[3] = { ~lit, ~lit1, lit2 }; + sat::literal cls2[3] = { ~lit, lit1, lit3 }; + sat::literal cls3[3] = { lit, ~lit1, ~lit2 }; + sat::literal cls4[3] = { lit, lit1, ~lit3 }; + s.add_clause(3, cls1); + s.add_clause(3, cls2); + s.add_clause(3, cls3); + s.add_clause(3, cls4); + } + else + register_terms(e); + + return neg ? ~lit : lit; + } + + + void context::init() { + m_new_constraint = false; + if (m_initialized) + return; + m_initialized = true; + m_unit_literals.reset(); + m_unit_indices.reset(); + for (auto const& clause : s.clauses()) + if (clause.m_clause.size() == 1) + m_unit_literals.push_back(clause.m_clause[0]); + for (sat::literal lit : m_unit_literals) + m_unit_indices.insert(lit.index()); + + IF_VERBOSE(3, verbose_stream() << "UNITS " << m_unit_literals << "\n"); + for (unsigned i = 0; i < m_atoms.size(); ++i) + if (m_atoms.get(i)) + register_terms(m_atoms.get(i)); + for (auto p : m_plugins) + if (p) + p->initialize(); + } + + void context::register_terms(expr* e) { + auto is_visited = [&](expr* e) { + return nullptr != m_allterms.get(e->get_id(), nullptr); + }; + + auto visit = [&](expr* e) { + m_allterms.setx(e->get_id(), e); + ensure_plugin(e); + register_term(e); + }; + if (is_visited(e)) + return; + m_subterms.reset(); + m_todo.push_back(e); + if (m_todo.size() > 1) + return; + while (!m_todo.empty()) { + expr* e = m_todo.back(); + if (is_visited(e)) + m_todo.pop_back(); + else if (is_app(e)) { + if (all_of(*to_app(e), [&](expr* arg) { return is_visited(arg); })) { + expr_ref _e(e, m); + m_todo.pop_back(); + m_parents.reserve(to_app(e)->get_id() + 1); + for (expr* arg : *to_app(e)) { + m_parents.reserve(arg->get_id() + 1); + m_parents[arg->get_id()].push_back(e); + } + if (m.is_bool(e)) + mk_literal(e); + visit(e); + } + else { + for (expr* arg : *to_app(e)) + m_todo.push_back(arg); + } + } + else { + expr_ref _e(e, m); + m_todo.pop_back(); + visit(e); + } + } + } + + void context::new_value_eh(expr* e) { + DEBUG_CODE( + if (m.is_bool(e)) { + auto v = m_atom2bool_var.get(e->get_id(), sat::null_bool_var); + if (v != sat::null_bool_var) { + SASSERT(m.is_true(get_value(e)) == is_true(v)); + } + } + ); + + m_repair_down.reserve(e->get_id() + 1); + m_repair_up.reserve(e->get_id() + 1); + if (!term(e->get_id())) + verbose_stream() << "no term " << mk_bounded_pp(e, m) << "\n"; + SASSERT(e == term(e->get_id())); + if (!m_repair_down.contains(e->get_id())) + m_repair_down.insert(e->get_id()); + for (auto p : parents(e)) { + auto pid = p->get_id(); + m_repair_up.reserve(pid + 1); + m_repair_down.reserve(pid + 1); + if (!m_repair_up.contains(pid)) + m_repair_up.insert(pid); + } + } + + void context::register_term(expr* e) { + for (auto p : m_plugins) + if (p) + p->register_term(e); + } + + ptr_vector const& context::subterms() { + if (!m_subterms.empty()) + return m_subterms; + for (auto e : m_allterms) + if (e) + m_subterms.push_back(e); + std::stable_sort(m_subterms.begin(), m_subterms.end(), + [](expr* a, expr* b) { return get_depth(a) < get_depth(b); }); + return m_subterms; + } + + void context::reinit_relevant() { + m_relevant.reset(); + m_visited.reset(); + m_root_literals.reset(); + + + for (auto const& clause : s.clauses()) { + bool has_relevant = false; + unsigned n = 0; + sat::literal selected_lit = sat::null_literal; + for (auto lit : clause) { + auto atm = m_atoms.get(lit.var(), nullptr); + if (!atm) + continue; + auto a = atm->get_id(); + if (!is_true(lit)) + continue; + if (m_relevant.contains(a)) { + has_relevant = true; + break; + } + if (m_rand() % ++n == 0) + selected_lit = lit; + } + if (!has_relevant && selected_lit != sat::null_literal) { + m_relevant.insert(m_atoms[selected_lit.var()]->get_id()); + m_root_literals.push_back(selected_lit); + } + } + shuffle(m_root_literals.size(), m_root_literals.data(), m_rand); + } + + std::ostream& context::display(std::ostream& out) const { + for (auto id : m_repair_down) + out << "d " << mk_bounded_pp(term(id), m) << "\n"; + for (auto id : m_repair_up) + out << "u " << mk_bounded_pp(term(id), m) << "\n"; + for (unsigned v = 0; v < m_atoms.size(); ++v) { + auto e = m_atoms[v]; + if (e) + out << v << ": " << mk_bounded_pp(e, m) << " := " << (is_true(v)?"T":"F") << "\n"; + + } + for (auto p : m_plugins) + if (p) + p->display(out); + + return out; + } + + void context::collect_statistics(statistics& st) const { + for (auto p : m_plugins) + if (p) + p->collect_statistics(st); + st.update("sls-repair-down", m_stats.m_num_repair_down); + st.update("sls-repair-up", m_stats.m_num_repair_up); + st.update("sls-constraints", m_stats.m_num_constraints); + } + + void context::reset_statistics() { + for (auto p : m_plugins) + if (p) + p->reset_statistics(); + m_stats.reset(); + } +} diff --git a/src/ast/sls/sls_context.h b/src/ast/sls/sls_context.h new file mode 100644 index 00000000000..95f9142922a --- /dev/null +++ b/src/ast/sls/sls_context.h @@ -0,0 +1,212 @@ +/*++ +Copyright (c) 2024 Microsoft Corporation + +Module Name: + + sls_context.h + +Abstract: + + A Stochastic Local Search (SLS) Context. + +Author: + + Nikolaj Bjorner (nbjorner) 2024-06-24 + +--*/ +#pragma once + +#include "util/sat_literal.h" +#include "util/sat_sls.h" +#include "util/statistics.h" +#include "ast/ast.h" +#include "ast/euf/euf_egraph.h" +#include "model/model.h" +#include "util/scoped_ptr_vector.h" +#include "util/obj_hashtable.h" +#include "util/heap.h" + +namespace sls { + + class context; + class euf_plugin; + + class plugin { + protected: + context& ctx; + ast_manager& m; + family_id m_fid; + public: + plugin(context& c); + virtual ~plugin() {} + virtual family_id fid() { return m_fid; } + virtual void register_term(expr* e) = 0; + virtual expr_ref get_value(expr* e) = 0; + virtual void initialize() = 0; + virtual void start_propagation() {}; + virtual bool propagate() = 0; + virtual void propagate_literal(sat::literal lit) = 0; + virtual void repair_literal(sat::literal lit) = 0; + virtual bool repair_down(app* e) = 0; + virtual void repair_up(app* e) = 0; + virtual bool is_sat() = 0; + virtual void on_rescale() {}; + virtual void on_restart() {}; + virtual std::ostream& display(std::ostream& out) const = 0; + virtual bool set_value(expr* e, expr* v) = 0; + virtual void collect_statistics(statistics& st) const = 0; + virtual void reset_statistics() = 0; + virtual bool include_func_interp(func_decl* f) const { return false; } + }; + + using clause = ptr_iterator; + + class sat_solver_context { + public: + virtual ~sat_solver_context() {} + virtual vector const& clauses() const = 0; + virtual sat::clause_info const& get_clause(unsigned idx) const = 0; + virtual ptr_iterator get_use_list(sat::literal lit) = 0; + virtual void flip(sat::bool_var v) = 0; + virtual double reward(sat::bool_var v) = 0; + virtual double get_weigth(unsigned clause_idx) = 0; + virtual bool is_true(sat::literal lit) = 0; + virtual unsigned num_vars() const = 0; + virtual indexed_uint_set const& unsat() const = 0; + virtual void on_model(model_ref& mdl) = 0; + virtual sat::bool_var add_var() = 0; + virtual void add_clause(unsigned n, sat::literal const* lits) = 0; + virtual void force_restart() = 0; + virtual std::ostream& display(std::ostream& out) = 0; + }; + + class context { + struct greater_depth { + context& c; + greater_depth(context& c) : c(c) {} + bool operator()(unsigned x, unsigned y) const { + return get_depth(c.term(x)) > get_depth(c.term(y)); + } + }; + + struct less_depth { + context& c; + less_depth(context& c) : c(c) {} + bool operator()(unsigned x, unsigned y) const { + return get_depth(c.term(x)) < get_depth(c.term(y)); + } + }; + + struct stats { + unsigned m_num_repair_down = 0; + unsigned m_num_repair_up = 0; + unsigned m_num_constraints = 0; + void reset() { memset(this, 0, sizeof(*this)); } + }; + + ast_manager& m; + sat_solver_context& s; + scoped_ptr_vector m_plugins; + indexed_uint_set m_relevant, m_visited; + expr_ref_vector m_atoms; + unsigned_vector m_atom2bool_var; + params_ref m_params; + vector> m_parents; + sat::literal_vector m_root_literals, m_unit_literals; + indexed_uint_set m_unit_indices; + random_gen m_rand; + bool m_initialized = false; + bool m_new_constraint = false; + bool m_dirty = false; + expr_ref_vector m_allterms; + ptr_vector m_subterms; + greater_depth m_gd; + less_depth m_ld; + heap m_repair_down; + heap m_repair_up; + uint_set m_constraint_ids; + expr_ref_vector m_constraint_trail; + stats m_stats; + + void register_plugin(plugin* p); + + void init(); + expr_ref_vector m_todo; + void register_terms(expr* e); + void register_term(expr* e); + + void propagate_boolean_assignment(); + void propagate_literal(sat::literal lit); + void repair_literals(); + + void values2model(); + + void ensure_plugin(expr* e); + void ensure_plugin(family_id fid); + family_id get_fid(expr* e) const; + + + sat::literal mk_literal(); + + public: + context(ast_manager& m, sat_solver_context& s); + + // Between SAT/SMT solver and context. + void register_atom(sat::bool_var v, expr* e); + lbool check(); + + void on_restart(); + void updt_params(params_ref const& p); + params_ref const& get_params() const { return m_params; } + + // expose sat_solver to plugins + vector const& clauses() const { return s.clauses(); } + sat::clause_info const& get_clause(unsigned idx) const { return s.get_clause(idx); } + ptr_iterator get_use_list(sat::literal lit) { return s.get_use_list(lit); } + double get_weight(unsigned clause_idx) { return s.get_weigth(clause_idx); } + unsigned num_bool_vars() const { return s.num_vars(); } + bool is_true(sat::literal lit) { return s.is_true(lit); } + bool is_true(sat::bool_var v) const { return s.is_true(sat::literal(v, false)); } + expr* atom(sat::bool_var v) { return m_atoms.get(v, nullptr); } + expr* term(unsigned id) const { return m_allterms.get(id); } + sat::bool_var atom2bool_var(expr* e) const { return m_atom2bool_var.get(e->get_id(), sat::null_bool_var); } + sat::literal mk_literal(expr* e); + void add_clause(expr* f); + void add_clause(sat::literal_vector const& lits); + void flip(sat::bool_var v) { s.flip(v); } + double reward(sat::bool_var v) { return s.reward(v); } + indexed_uint_set const& unsat() const { return s.unsat(); } + unsigned rand() { return m_rand(); } + unsigned rand(unsigned n) { return m_rand(n); } + sat::literal_vector const& root_literals() const { return m_root_literals; } + sat::literal_vector const& unit_literals() const { return m_unit_literals; } + bool is_unit(sat::literal lit) const { return m_unit_indices.contains(lit.index()); } + void reinit_relevant(); + void force_restart() { s.force_restart(); } + bool include_func_interp(func_decl* f) const { return any_of(m_plugins, [&](plugin* p) { return p && p->include_func_interp(f); }); } + + ptr_vector const& parents(expr* e) { + m_parents.reserve(e->get_id() + 1); + return m_parents[e->get_id()]; + } + + // Between plugin solvers + expr_ref get_value(expr* e); + bool set_value(expr* e, expr* v); + void new_value_eh(expr* e); + bool is_true(expr* e); + bool is_fixed(expr* e); + bool is_relevant(expr* e); + void add_constraint(expr* e); + ptr_vector const& subterms(); + ast_manager& get_manager() { return m; } + std::ostream& display(std::ostream& out) const; + std::ostream& display_all(std::ostream& out) const { return s.display(out); } + scoped_ptr& egraph(); + euf_plugin& euf(); + + void collect_statistics(statistics& st) const; + void reset_statistics(); + + }; +} diff --git a/src/ast/sls/sls_datatype_plugin.cpp b/src/ast/sls/sls_datatype_plugin.cpp new file mode 100644 index 00000000000..d4789d00691 --- /dev/null +++ b/src/ast/sls/sls_datatype_plugin.cpp @@ -0,0 +1,956 @@ +/*++ +Copyright (c) 2024 Microsoft Corporation + +Module Name: + + sls_datatype_plugin.cpp + +Abstract: + + Algebraic Datatypes for SLS + +Author: + + Nikolaj Bjorner (nbjorner) 2024-10-14 + +Notes: + +Eager reduction to EUF: + is-c(c(t)) for each c(t) in T + acc_i(c(t_i)) = t_i for each c(..t_i..) in T + is-c(t) => t = c(...acc_j(t)..) for each acc_j(t) in T + + sum_i is-c_i(t) = 1 + is-c(t) <=> c = t for each 0-ary constructor c + + is-c(t) <=> t = c(acc_1(t)..acc_n(t)) + + s = acc(...(acc(t)) => s != t if t is recursive + + or_i t = t_i if t is a finite sort with terms t_i + + + s := acc(t) => s < t in P + a := s = acc(t), a is a unit => s < t in P + a := s = acc(t), a in Atoms => (a => s < t) in P + + s << t if there is a path P with conditions L. + L => s != t + + This disregards if acc is applied to non-matching constructor. + In this case we rely on that the interpretation of acc can be + forced. + If this is incorrect, include is-c(t) assumptions in path axioms. + + Is P sufficient? Should we just consider all possible paths of depth at most k to be safe? + Example: + C(acc(t)) == C(s) + triggers equation acc(t) = s, but the equation is implicit, so acc(t) and s are not directly + connected. + Even, the axioms extracted from P don't consider transitivity of =. + So the can-be-equal alias approximation is too strong. + We therefore add an occurs check during propagation and lazily add missed axioms. + + +Model-repair based: + +1. Initialize uninterpreted datatype nodes to hold arbitrary values. +2. Initialize datatype nodes by induced evaluation. +3. Atomic constraints are of the form for datatype terms + x = y, x = t, x != y, x != t; s = t, s != t + + violated x = y: x <- eval(y), y <- eval(x) or x, y <- fresh + violated x = t: x <- eval(t), repair t using the shape of x + violated x != y: x <- fresh, y <- fresh + violated x != t: x <- fresh, subterm y of t: y <- fresh + + acc(x) = t: eval(x) = c(u, v) acc(c(u,v)) = u -> repair(u = t) + acc(x) = t: eval(x) does not match acc -> acc(x) + has a fixed interpretation, so repair over t instead, or update interpretation of x + + uses: + model::get_fresh_value(s) + model::get_some_value(s) + +--*/ + +#include "ast/sls/sls_datatype_plugin.h" +#include "ast/sls/sls_euf_plugin.h" +#include "ast/ast_pp.h" +#include "params/sls_params.hpp" + +namespace sls { + + datatype_plugin::datatype_plugin(context& c): + plugin(c), + euf(c.euf()), + g(c.egraph()), + dt(m), + m_axioms(m), + m_values(m), + m_eval(m) { + m_fid = dt.get_family_id(); + } + + datatype_plugin::~datatype_plugin() {} + + void datatype_plugin::collect_path_axioms() { + expr* t = nullptr, *z = nullptr; + for (auto s : ctx.subterms()) { + if (dt.is_accessor(s, t) && dt.is_recursive(t) && dt.is_recursive(s)) + add_edge(s, t, m.mk_app(dt.get_constructor_is(dt.get_accessor_constructor(to_app(s)->get_decl())), t)); + if (dt.is_constructor(s) && dt.is_recursive(s)) { + for (auto arg : *to_app(s)) + add_edge(arg, s, nullptr); + } + } + expr* x = nullptr, *y = nullptr; + for (sat::bool_var v = 0; v < ctx.num_bool_vars(); ++v) { + expr* e = ctx.atom(v); + if (!e) + continue; + if (!m.is_eq(e, x, y)) + continue; + if (!dt.is_recursive(x)) + continue; + sat::literal lp(v, false), ln(v, true); + if (dt.is_accessor(x, z) && dt.is_recursive(z)) { + if (ctx.is_unit(lp)) + add_edge(y, z, nullptr); + else if (ctx.is_unit(ln)) + ; + else + add_edge(y, z, e); + } + if (dt.is_accessor(y, z) && dt.is_recursive(z)) { + if (ctx.is_unit(lp)) + add_edge(x, z, m.mk_app(dt.get_constructor_is(dt.get_accessor_constructor(to_app(y)->get_decl())), z)); + else if (ctx.is_unit(ln)) + ; + else + add_edge(x, z, e); + } + } + add_path_axioms(); + } + + void datatype_plugin::add_edge(expr* child, expr* parent, expr* cond) { + m_parents.insert_if_not_there(child, vector()).push_back({parent, expr_ref(cond, m)}); + TRACE("dt", tout << mk_bounded_pp(child, m) << " <- " << mk_bounded_pp(parent, m) << " " << mk_bounded_pp(cond, m) << "\n"); + } + + void datatype_plugin::add_path_axioms() { + ptr_vector path; + sat::literal_vector lits; + for (auto [child, parents] : m_parents) { + path.reset(); + lits.reset(); + path.push_back(child); + add_path_axioms(path, lits, parents); + } + } + + void datatype_plugin::add_path_axioms(ptr_vector& children, sat::literal_vector& lits, vector const& parents) { + for (auto const& [parent, cond] : parents) { + if (cond) + lits.push_back(~ctx.mk_literal(cond)); + if (children.contains(parent)) { + // only assert loop clauses for proper loops + if (parent == children[0]) + ctx.add_clause(lits); + if (cond) + lits.pop_back(); + continue; + } + if (children[0]->get_sort() == parent->get_sort()) { + lits.push_back(~ctx.mk_literal(m.mk_eq(children[0], parent))); + TRACE("dt", for (auto lit : lits) tout << (lit.sign() ? "~": "") << mk_pp(ctx.atom(lit.var()), m) << "\n";); + ctx.add_clause(lits); + lits.pop_back(); + } + auto child = children.back(); + if (m_parents.contains(child)) { + children.push_back(parent); + auto& parents2 = m_parents[child]; + add_path_axioms(children, lits, parents2); + children.pop_back(); + } + if (cond) + lits.pop_back(); + } + } + + void datatype_plugin::add_axioms() { + expr_ref_vector axioms(m); + for (auto t : ctx.subterms()) { + auto s = t->get_sort(); + if (dt.is_datatype(s)) + m_dts.insert_if_not_there(s, ptr_vector()).push_back(t); + + if (!is_app(t)) + continue; + auto ta = to_app(t); + auto f = ta->get_decl(); + + if (dt.is_constructor(t)) { + auto r = dt.get_constructor_is(f); + m_axioms.push_back(m.mk_app(r, t)); + auto& acc = *dt.get_constructor_accessors(f); + for (unsigned i = 0; i < ta->get_num_args(); ++i) { + auto ti = ta->get_arg(i); + m_axioms.push_back(m.mk_eq(ti, m.mk_app(acc[i], t))); + } + auto& cns = *dt.get_datatype_constructors(s); + for (auto c : cns) { + if (c != f) { + auto r2 = dt.get_constructor_is(c); + m_axioms.push_back(m.mk_not(m.mk_app(r2, t))); + } + } + continue; + } + + if (dt.is_recognizer0(f)) { + auto u = ta->get_arg(0); + auto c = dt.get_recognizer_constructor(f); + m_axioms.push_back(m.mk_iff(t, m.mk_app(dt.get_constructor_is(c), u))); + } + + if (dt.is_update_field(t)) { + NOT_IMPLEMENTED_YET(); + } + + if (dt.is_datatype(s)) { + auto& cns = *dt.get_datatype_constructors(s); + expr_ref_vector ors(m); + for (auto c : cns) { + auto r = dt.get_constructor_is(c); + ors.push_back(m.mk_app(r, t)); + } + m_axioms.push_back(m.mk_or(ors)); +#if 0 + // expanded lazily + // EUF already handles enumeration datatype case. + for (unsigned i = 0; i < cns.size(); ++i) { + auto r1 = dt.get_constructor_is(cns[i]); + for (unsigned j = i + 1; j < cns.size(); ++j) { + auto r2 = dt.get_constructor_is(cns[j]); + m_axioms.push_back(m.mk_or(m.mk_not(m.mk_app(r1, t)), m.mk_not(m.mk_app(r2, t)))); + } + } +#endif + for (auto c : cns) { + auto r = dt.get_constructor_is(c); + auto& acc = *dt.get_constructor_accessors(c); + expr_ref_vector args(m); + for (auto a : acc) + args.push_back(m.mk_app(a, t)); + m_axioms.push_back(m.mk_iff(m.mk_app(r, t), m.mk_eq(t, m.mk_app(c, args)))); + } + } + } + //collect_path_axioms(); + + TRACE("dt", for (auto a : m_axioms) tout << mk_pp(a, m) << "\n";); + + for (auto a : m_axioms) + ctx.add_constraint(a); + } + + void datatype_plugin::initialize() { + sls_params sp(ctx.get_params()); + m_axiomatic_mode = sp.dt_axiomatic(); + if (m_axiomatic_mode) + add_axioms(); + } + + expr_ref datatype_plugin::get_value(expr* e) { + if (!dt.is_datatype(e)) + return expr_ref(m); + if (m_axiomatic_mode) { + init_values(); + return expr_ref(m_values.get(g->find(e)->get_root_id()), m); + } + return expr_ref(m_eval.get(e->get_id()), m); + } + + void datatype_plugin::init_values() { + if (!m_values.empty()) + return; + TRACE("dt", g->display(tout)); + m_model = alloc(model, m); + // retrieve e-graph from sls_euf_solver: add bridge in sls_context to share e-graph + SASSERT(g); + // build top_sort similar to dt_solver.cpp + top_sort deps; + for (auto* n : g->nodes()) + if (n->is_root()) + add_dep(n, deps); + + auto trace_assignment = [&](std::ostream& out, euf::enode* n) { + for (auto sib : euf::enode_class(n)) + out << g->bpp(sib) << " "; + out << " <- " << mk_bounded_pp(m_values.get(n->get_id()), m) << "\n"; + }; + deps.topological_sort(); + expr_ref_vector args(m); + euf::enode_vector leaves, worklist; + obj_map leaf2root; + // walk topological sort in order of leaves to roots, attaching values to nodes. + for (euf::enode* n : deps.top_sorted()) { + SASSERT(n->is_root()); + unsigned id = n->get_id(); + if (m_values.get(id, nullptr)) + continue; + expr* e = n->get_expr(); + m_values.reserve(id + 1); + if (!dt.is_datatype(e)) + continue; + euf::enode* con = get_constructor(n); + if (!con) { + leaves.push_back(n); + continue; + } + auto f = con->get_decl(); + args.reset(); + bool has_null = false; + for (auto arg : euf::enode_args(con)) { + if (dt.is_datatype(arg->get_sort())) { + auto val_arg = m_values.get(arg->get_root_id()); + if (!val_arg) + has_null = true; + leaf2root.insert_if_not_there(arg->get_root(), euf::enode_vector()).push_back(n); + args.push_back(val_arg); + } + else + args.push_back(ctx.get_value(arg->get_expr())); + } + if (!has_null) { + m_values.setx(id, m.mk_app(f, args)); + m_model->register_value(m_values.get(id)); + TRACE("dt", tout << "Set interpretation "; trace_assignment(tout, n);); + } + } + + TRACE("dt", + for (euf::enode* n : deps.top_sorted()) { + tout << g->bpp(n) << ": "; + tout << g->bpp(get_constructor(n)) << " :: "; + auto s = deps.get_dep(n); + if (s) { + tout << " -> "; + for (auto t : *s) + tout << g->bpp(t) << " "; + } + tout << "\n"; + } + ); + + auto process_workitem = [&](euf::enode* n) { + if (!leaf2root.contains(n)) + return true; + bool all_processed = true; + for (auto p : leaf2root[n]) { + if (m_values.get(p->get_id(), nullptr)) + continue; + auto con = get_constructor(p); + SASSERT(con); + auto f = con->get_decl(); + args.reset(); + bool has_missing = false; + for (auto arg : euf::enode_args(con)) { + if (dt.is_datatype(arg->get_sort())) { + auto arg_val = m_values.get(arg->get_root_id()); + if (!arg_val) + has_missing = true; + args.push_back(arg_val); + } + else + args.push_back(ctx.get_value(arg->get_expr())); + } + if (has_missing) { + all_processed = false; + continue; + } + worklist.push_back(p); + SASSERT(all_of(args, [&](expr* e) { return e != nullptr; })); + m_values.setx(p->get_id(), m.mk_app(f, args)); + TRACE("dt", tout << "Patched interpretation "; trace_assignment(tout, p);); + m_model->register_value(m_values.get(p->get_id())); + } + return all_processed; + }; + + auto process_worklist = [&](euf::enode_vector& worklist) { + unsigned j = 0, sz = worklist.size(); + for (unsigned i = 0; i < worklist.size(); ++i) + if (!process_workitem(worklist[i])) + worklist[j++] = worklist[i]; + worklist.shrink(j); + return j < sz; + }; + + // attach fresh values to each leaf, walk up parents to assign them values. + while (!leaves.empty()) { + auto n = leaves.back(); + leaves.pop_back(); + SASSERT(!get_constructor(n)); + auto v = m_model->get_fresh_value(n->get_sort()); + if (!v) + v = m_model->get_some_value(n->get_sort()); + SASSERT(v); + unsigned id = n->get_id(); + m_values.setx(id, v); + TRACE("dt", tout << "Fresh interpretation "; trace_assignment(tout, n);); + worklist.reset(); + worklist.push_back(n); + while (process_worklist(worklist)) + ; + } + } + + void datatype_plugin::add_dep(euf::enode* n, top_sort& dep) { + if (!dt.is_datatype(n->get_expr())) + return; + euf::enode* con = get_constructor(n); + TRACE("dt", tout << g->bpp(n) << " con: " << g->bpp(con) << "\n";); + if (!con) + dep.insert(n, nullptr); + else if (con->num_args() == 0) + dep.insert(n, nullptr); + else + for (euf::enode* arg : euf::enode_args(con)) + dep.add(n, arg->get_root()); + } + + + void datatype_plugin::start_propagation() { + m_values.reset(); + m_model = nullptr; + } + + euf::enode* datatype_plugin::get_constructor(euf::enode* n) const { + for (auto sib : euf::enode_class(n)) + if (dt.is_constructor(sib->get_expr())) + return sib; + return nullptr; + } + + bool datatype_plugin::propagate() { + enum color_t { white, grey, black }; + svector color; + ptr_vector stack; + obj_map> sorts; + + auto set_conflict = [&](euf::enode* n) { + expr_ref_vector diseqs(m); + while (true) { + auto n2 = stack.back(); + auto con2 = get_constructor(n2); + if (n2 != con2) + diseqs.push_back(m.mk_not(m.mk_eq(n2->get_expr(), con2->get_expr()))); + if (n2->get_root() == n->get_root()) { + if (n != n2) + diseqs.push_back(m.mk_not(m.mk_eq(n->get_expr(), n2->get_expr()))); + break; + } + stack.pop_back(); + } + IF_VERBOSE(1, verbose_stream() << "cycle\n"; for (auto e : diseqs) verbose_stream() << mk_pp(e, m) << "\n";); + ctx.add_constraint(m.mk_or(diseqs)); + ++m_stats.m_num_occurs; + }; + + for (auto n : g->nodes()) { + if (!n->is_root()) + continue; + euf::enode* con = nullptr; + for (auto sib : euf::enode_class(n)) { + if (dt.is_constructor(sib->get_expr())) { + if (!con) + con = sib; + if (con && con->get_decl() != sib->get_decl()) { + ctx.add_constraint(m.mk_not(m.mk_eq(con->get_expr(), sib->get_expr()))); + ++m_stats.m_num_occurs; + } + } + } + } + + for (auto n : g->nodes()) { + if (!n->is_root()) + continue; + expr* e = n->get_expr(); + if (!dt.is_datatype(e)) + continue; + if (!ctx.is_relevant(e)) + continue; + sort* s = e->get_sort(); + sorts.insert_if_not_there(s, ptr_vector()).push_back(e); + + auto c = color.get(e->get_id(), white); + SASSERT(c != grey); + if (c == black) + continue; + + // dfs traversal of enodes, starting with n, + // with outgoing edges the arguments of con, where con + // is a node in the same congruence class as n that is a constructor. + // For every cycle accumulate a conflict. + + stack.push_back(n); + while (!stack.empty()) { + n = stack.back(); + unsigned id = n->get_root_id(); + c = color.get(id, white); + euf::enode* con; + + switch (c) { + case black: + stack.pop_back(); + break; + case grey: + case white: + color.setx(id, grey, white); + con = get_constructor(n); + if (!con) + goto done_with_node; + for (auto child : euf::enode_args(con)) { + auto c2 = color.get(child->get_root_id(), white); + switch (c2) { + case black: + break; + case grey: + set_conflict(child); + return true; + case white: + stack.push_back(child); + goto node_pushed; + } + } + done_with_node: + color[id] = black; + stack.pop_back(); + node_pushed: + break; + } + } + } + + + for (auto const& [s, elems] : sorts) { + auto sz = s->get_num_elements(); + + if (!sz.is_finite() || sz.size() >= elems.size()) + continue; + ctx.add_constraint(m.mk_not(m.mk_distinct((unsigned)sz.size() + 1, elems.data()))); + } + + return false; + } + + bool datatype_plugin::include_func_interp(func_decl* f) const { + if (!dt.is_accessor(f)) + return false; + func_decl* con_decl = dt.get_accessor_constructor(f); + for (euf::enode* app : g->enodes_of(f)) { + euf::enode* con = get_constructor(app->get_arg(0)); + if (con && con->get_decl() != con_decl) + return true; + } + return false; + } + + std::ostream& datatype_plugin::display(std::ostream& out) const { + for (auto a : m_axioms) + out << mk_bounded_pp(a, m, 3) << "\n"; + return out; + } + + void datatype_plugin::propagate_literal(sat::literal lit) { + if (m_axiomatic_mode) + euf.propagate_literal(lit); + else + propagate_literal_model_building(lit); + } + + void datatype_plugin::propagate_literal_model_building(sat::literal lit) { + if (!ctx.is_true(lit)) + return; + auto a = ctx.atom(lit.var()); + if (!a || !is_app(a)) + return; + repair_down(to_app(a)); + } + + bool datatype_plugin::is_sat() { return true; } + + void datatype_plugin::register_term(expr* e) { + expr* t = nullptr; + if (dt.is_accessor(e, t)) { + auto f = to_app(e)->get_decl(); + m_occurs.insert_if_not_there(f, expr_set()).insert(e); + m_eval_accessor.insert_if_not_there(f, obj_map()); + } + } + + + bool datatype_plugin::repair_down(app* e) { + expr* t, * s; + auto v0 = eval0(e); + auto v1 = eval1(e); + if (v0 == v1) + return true; + IF_VERBOSE(2, verbose_stream() << "dt-repair-down " << mk_bounded_pp(e, m) << " " << v0 << " <- " << v1 << "\n"); + if (dt.is_constructor(e)) + repair_down_constructor(e, v0, v1); + else if (dt.is_accessor(e, t)) + repair_down_accessor(e, t, v0); + else if (dt.is_recognizer(e, t)) + repair_down_recognizer(e, t); + else if (m.is_eq(e, s, t)) + repair_down_eq(e, s, t); + else if (m.is_distinct(e)) + repair_down_distinct(e); + else { + UNREACHABLE(); + } + return false; + } + + // + // C(t) <- C(s) then repair t <- s + // C(t) <- D(s) then fail the repair. + // + void datatype_plugin::repair_down_constructor(app* e, expr* v0, expr* v1) { + SASSERT(dt.is_constructor(v0)); + SASSERT(dt.is_constructor(v1)); + SASSERT(e->get_decl() == to_app(v1)->get_decl()); + if (e->get_decl() == to_app(v0)->get_decl()) { + for (unsigned i = 0; i < e->get_num_args(); ++i) { + auto w0 = to_app(v0)->get_arg(i); + auto w1 = to_app(v1)->get_arg(i); + if (w0 == w1) + continue; + expr* arg = e->get_arg(i); + set_eval0(arg, w0); + ctx.new_value_eh(arg); + } + } + } + + // + // A_D(t) <- s, val(t) = D(..s'..) then update val(t) to agree with s + // A_D(t) <- s, val(t) = C(..) then set t to D(...s...) + // , eval(val(A_D(t))) = s' then update eval(val(A_D,(t))) to s' + void datatype_plugin::repair_down_accessor(app* e, expr* t, expr* v0) { + auto f = e->get_decl(); + auto c = dt.get_accessor_constructor(f); + auto val_t = eval0(t); + SASSERT(dt.is_constructor(val_t)); + expr_ref_vector args(m); + auto const& accs = *dt.get_constructor_accessors(c); + unsigned i; + for (i = 0; i < accs.size(); ++i) { + if (accs[i] == f) + break; + } + SASSERT(i < accs.size()); + if (to_app(val_t)->get_decl() == c) { + if (to_app(val_t)->get_arg(i) == v0) + return; + args.append(accs.size(), to_app(val_t)->get_args()); + args[i] = v0; + expr* new_val_t = m.mk_app(c, args); + set_eval0(t, new_val_t); + ctx.new_value_eh(t); + return; + } + if (ctx.rand(5) != 0) { + update_eval_accessor(e, val_t, v0); + return; + } + for (unsigned j = 0; j < accs.size(); ++j) { + if (i == j) + args[i] = v0; + else + args[j] = m_model->get_some_value(accs[j]->get_range()); + } + expr* new_val_t = m.mk_app(c, args); + set_eval0(t, new_val_t); + ctx.new_value_eh(t); + } + + void datatype_plugin::repair_down_recognizer(app* e, expr* t) { + auto bv = ctx.atom2bool_var(e); + auto is_true = ctx.is_true(bv); + auto c = dt.get_recognizer_constructor(e->get_decl()); + auto val_t = eval0(t); + auto const& cons = *dt.get_datatype_constructors(t->get_sort()); + + auto set_to_instance = [&](func_decl* c) { + auto const& accs = *dt.get_constructor_accessors(c); + expr_ref_vector args(m); + for (auto a : accs) + args.push_back(m_model->get_some_value(a->get_range())); + set_eval0(t, m.mk_app(c, args)); + ctx.new_value_eh(t); + }; + auto different_constructor = [&](func_decl* c) { + unsigned i = 0; + func_decl* c_new = nullptr; + for (auto c2 : cons) + if (c2 != c && ctx.rand(++i) == 0) + c_new = c2; + return c_new; + }; + + SASSERT(dt.is_constructor(val_t)); + if (c == to_app(val_t)->get_decl() && is_true) + return; + if (c != to_app(val_t)->get_decl() && !is_true) + return; + if (ctx.rand(10) == 0) + ctx.flip(bv); + else if (is_true) + set_to_instance(c); + else if (cons.size() == 1) + ctx.flip(bv); + else + set_to_instance(different_constructor(c)); + } + + void datatype_plugin::repair_down_eq(app* e, expr* s, expr* t) { + auto bv = ctx.atom2bool_var(e); + auto is_true = ctx.is_true(bv); + auto vs = eval0(s); + auto vt = eval0(t); + if (is_true && vs == vt) + return; + if (!is_true && vs != vt) + return; + + if (is_true) { + auto coin = ctx.rand(5); + if (coin <= 1) { + set_eval0(s, vt); + ctx.new_value_eh(s); + return; + } + if (coin <= 3) { + set_eval0(t, vs); + ctx.new_value_eh(t); + } + if (true) { + auto new_v = m_model->get_some_value(s->get_sort()); + set_eval0(s, new_v); + set_eval0(t, new_v); + ctx.new_value_eh(s); + ctx.new_value_eh(t); + return; + } + } + auto coin = ctx.rand(10); + if (coin <= 4) { + auto new_v = m_model->get_some_value(s->get_sort()); + set_eval0(s, new_v); + ctx.new_value_eh(s); + return; + } + if (coin <= 9) { + auto new_v = m_model->get_some_value(s->get_sort()); + set_eval0(t, new_v); + ctx.new_value_eh(t); + return; + } + } + + void datatype_plugin::repair_down_distinct(app* e) { + auto bv = ctx.atom2bool_var(e); + auto is_true = ctx.is_true(bv); + unsigned sz = e->get_num_args(); + for (unsigned i = 0; i < sz; ++i) { + auto val1 = eval0(e->get_arg(i)); + for (unsigned j = i + 1; j < sz; ++j) { + auto val2 = eval0(e->get_arg(j)); + if (val1 != val2) + continue; + if (!is_true) + return; + if (ctx.rand(2) == 0) + std::swap(i, j); + auto new_v = m_model->get_some_value(e->get_arg(i)->get_sort()); + set_eval0(e->get_arg(i), new_v); + ctx.new_value_eh(e->get_arg(i)); + return; + } + } + if (is_true) + return; + if (sz == 1) { + ctx.flip(bv); + return; + } + unsigned i = ctx.rand(sz); + unsigned j = ctx.rand(sz-1); + if (j == i) + ++j; + if (ctx.rand(2) == 0) + std::swap(i, j); + set_eval0(e->get_arg(i), eval0(e->get_arg(j))); + } + + void datatype_plugin::repair_up(app* e) { + IF_VERBOSE(2, verbose_stream() << "dt-repair-up " << mk_bounded_pp(e, m) << "\n"); + expr* t; + auto v0 = eval0(e); + auto v1 = eval1(e); + if (v0 == v1) + return; + if (dt.is_constructor(e)) + set_eval0(e, v1); + else if (m.is_bool(e)) + ctx.flip(ctx.atom2bool_var(e)); + else if (dt.is_accessor(e, t)) + repair_up_accessor(e, t, v1); + else { + UNREACHABLE(); + } + } + + void datatype_plugin::repair_up_accessor(app* e, expr* t, expr* v1) { + auto v_t = eval0(t); + auto f = e->get_decl(); + SASSERT(dt.is_constructor(v_t)); + auto c = dt.get_accessor_constructor(f); + if (to_app(v_t)->get_decl() != c) + update_eval_accessor(e, v_t, v1); + + set_eval0(e, v1); + } + + expr_ref datatype_plugin::eval1(expr* e) { + expr* s = nullptr, * t = nullptr; + if (m.is_eq(e, s, t)) + return expr_ref(m.mk_bool_val(eval0rec(s) == eval0rec(t)), m); + if (m.is_distinct(e)) { + expr_ref_vector args(m); + for (auto arg : *to_app(e)) + args.push_back(eval0(arg)); + bool d = true; + for (unsigned i = 0; i < args.size(); ++i) + for (unsigned j = i + 1; i < args.size(); ++j) + d &= args.get(i) != args.get(j); + return expr_ref(m.mk_bool_val(d), m); + } + if (dt.is_accessor(e, t)) { + auto f = to_app(e)->get_decl(); + auto v = eval0rec(t); + return eval_accessor(f, v); + } + if (dt.is_constructor(e)) { + expr_ref_vector args(m); + for (auto arg : *to_app(e)) + args.push_back(eval0rec(arg)); + return expr_ref(m.mk_app(to_app(e)->get_decl(), args), m); + } + if (dt.is_recognizer(e, t)) { + auto v = eval0rec(t); + SASSERT(dt.is_constructor(v)); + auto c = dt.get_recognizer_constructor(to_app(e)->get_decl()); + return expr_ref(m.mk_bool_val(c == to_app(v)->get_decl()), m); + } + return eval0(e); + } + + expr_ref datatype_plugin::eval0rec(expr* e) { + auto v = m_eval.get(e->get_id(), nullptr); + if (v) + return expr_ref(v, m); + if (!is_app(e) || to_app(e)->get_family_id() != m_fid) + return ctx.get_value(e); + auto w = eval1(e); + m_eval.set(e->get_id(), w); + return w; + } + + expr_ref datatype_plugin::eval_accessor(func_decl* f, expr* t) { + auto& t2val = m_eval_accessor[f]; + if (!t2val.contains(t)) { + auto val = m_model->get_some_value(f->get_range()); + m.inc_ref(t); + m.inc_ref(val); + } + return expr_ref(t2val[t], m); + } + + void datatype_plugin::update_eval_accessor(app* e, expr* t, expr* value) { + func_decl* f = e->get_decl(); + + auto& t2val = m_eval_accessor[f]; + expr* old_value = nullptr; + t2val.find(t, old_value); + if (old_value == value) + ; + else if (old_value) { + t2val[t] = value; + m.inc_ref(value); + m.dec_ref(old_value); + } + else { + m.inc_ref(t); + m.inc_ref(value); + t2val.insert(t, value); + } + + for (expr* b : m_occurs[f]) { + if (b == e) + continue; + expr* a; + VERIFY(dt.is_accessor(b, a)); + auto v_a = eval0(a); + if (v_a.get() == t) { + set_eval0(b, value); + ctx.new_value_eh(b); + } + } + } + + void datatype_plugin::del_eval_accessor() { + ptr_vector kv; + for (auto& [f, t2val] : m_eval_accessor) + for (auto& [k, val] : t2val) + kv.push_back(k), kv.push_back(val); + for (auto k : kv) + m.dec_ref(k); + } + + expr_ref datatype_plugin::eval0(expr* n) { + if (!dt.is_datatype(n->get_sort())) + return ctx.get_value(n); + auto v = m_eval.get(n->get_id(), nullptr); + if (v) + return expr_ref(v, m); + set_eval0(n, m_model->get_some_value(n->get_sort())); + return expr_ref(m_eval.get(n->get_id()), m); + } + + void datatype_plugin::set_eval0(expr* e, expr* value) { + if (dt.is_datatype(e->get_sort())) + m_eval[e->get_id()] = value; + else + ctx.set_value(e, value); + } + + expr_ref datatype_plugin::eval0(euf::enode* n) { + return eval0(n->get_root()->get_expr()); + } + + void datatype_plugin::collect_statistics(statistics& st) const { + st.update("sls-dt-axioms", m_axioms.size()); + st.update("sls-dt-occurs-conflicts", m_stats.m_num_occurs); + } + + void datatype_plugin::reset_statistics() {} + +} diff --git a/src/ast/sls/sls_datatype_plugin.h b/src/ast/sls/sls_datatype_plugin.h new file mode 100644 index 00000000000..f06e24963bc --- /dev/null +++ b/src/ast/sls/sls_datatype_plugin.h @@ -0,0 +1,107 @@ +/*++ +Copyright (c) 2024 Microsoft Corporation + +Module Name: + + sls_datatype_plugin.h + +Abstract: + + Algebraic Datatypes for SLS + +Author: + + Nikolaj Bjorner (nbjorner) 2024-10-14 + +--*/ +#pragma once + +#include "ast/sls/sls_context.h" +#include "ast/datatype_decl_plugin.h" +#include "util/top_sort.h" + +namespace sls { + + class euf_plugin; + + class datatype_plugin : public plugin { + struct stats { + unsigned m_num_occurs = 0; + void reset() { memset(this, 0, sizeof(*this)); } + }; + struct parent_t { + expr* parent; + expr_ref condition; + }; + euf_plugin& euf; + scoped_ptr& g; + obj_map> m_dts; + obj_map> m_parents; + + bool m_axiomatic_mode = true; + mutable datatype_util dt; + expr_ref_vector m_axioms, m_values, m_eval; + model_ref m_model; + stats m_stats; + + void collect_path_axioms(); + void add_edge(expr* child, expr* parent, expr* cond); + void add_path_axioms(); + void add_path_axioms(ptr_vector& children, sat::literal_vector& lits, vector const& parents); + void add_axioms(); + + void init_values(); + void add_dep(euf::enode* n, top_sort& dep); + + euf::enode* get_constructor(euf::enode* n) const; + + // f -> v_t -> val + // e = A(t) + // val(t) <- val + // + typedef obj_hashtable expr_set; + obj_map> m_eval_accessor; + obj_map m_occurs; + expr_ref eval1(expr* e); + expr_ref eval0(euf::enode* n); + expr_ref eval0(expr* n); + expr_ref eval0rec(expr* n); + expr_ref eval_accessor(func_decl* f, expr* t); + void update_eval_accessor(app* e, expr* t, expr* value); + void del_eval_accessor(); + void set_eval0(expr* e, expr* val); + + void repair_down_constructor(app* e, expr* v0, expr* v1); + void repair_down_accessor(app* e, expr* t, expr* v1); + void repair_down_recognizer(app* e, expr* t); + void repair_down_eq(app* e, expr* s, expr* t); + void repair_down_distinct(app* e); + void repair_up_accessor(app* e, expr* t, expr* v0); + void propagate_literal_model_building(sat::literal lit); + + public: + datatype_plugin(context& c); + ~datatype_plugin() override; + family_id fid() override { return m_fid; } + expr_ref get_value(expr* e) override; + void initialize() override; + void start_propagation() override; + void propagate_literal(sat::literal lit) override; + bool propagate() override; + bool is_sat() override; + void register_term(expr* e) override; + + bool set_value(expr* e, expr* v) override { return false; } + void repair_literal(sat::literal lit) override {} + bool include_func_interp(func_decl* f) const override; + + bool repair_down(app* e) override; + void repair_up(app* e) override; + + std::ostream& display(std::ostream& out) const override; + void collect_statistics(statistics& st) const override; + void reset_statistics() override; + + }; + +} diff --git a/src/ast/sls/sls_euf_plugin.cpp b/src/ast/sls/sls_euf_plugin.cpp new file mode 100644 index 00000000000..48ab3cb9684 --- /dev/null +++ b/src/ast/sls/sls_euf_plugin.cpp @@ -0,0 +1,489 @@ +/*++ +Copyright (c) 2024 Microsoft Corporation + +Module Name: + + sls_euf_plugin.cpp + +Abstract: + + Congruence Closure for SLS + +Author: + + Nikolaj Bjorner (nbjorner) 2024-06-24 + +Todo: + +- try incremental CC with backtracking for changing assignments +- try determining plateau moves. +- try generally a model rotation move. + +--*/ + +#include "ast/sls/sls_euf_plugin.h" +#include "ast/ast_ll_pp.h" +#include "ast/ast_pp.h" +#include "params/sls_params.hpp" + + +namespace sls { + + euf_plugin::euf_plugin(context& c): + plugin(c), + m_values(8U, value_hash(*this), value_eq(*this)) { + m_fid = user_sort_family_id; + } + + euf_plugin::~euf_plugin() {} + + void euf_plugin::initialize() { + sls_params sp(ctx.get_params()); + m_incremental_mode = sp.euf_incremental(); + m_incremental = 1 == m_incremental_mode; + IF_VERBOSE(2, verbose_stream() << "sls.euf: incremental " << m_incremental_mode << "\n"); + } + + void euf_plugin::start_propagation() { + if (m_incremental_mode == 2) + m_incremental = !m_incremental; + m_g = alloc(euf::egraph, m); + std::function dj = [&](std::ostream& out, void* j) { + out << "lit " << to_literal(reinterpret_cast(j)); + }; + m_g->set_display_justification(dj); + init_egraph(*m_g, !m_incremental); + } + + void euf_plugin::register_term(expr* e) { + if (!is_app(e)) + return; + if (!is_uninterp(e)) + return; + app* a = to_app(e); + if (a->get_num_args() == 0) + return; + auto f = a->get_decl(); + if (!m_app.contains(f)) + m_app.insert(f, ptr_vector()); + m_app[f].push_back(a); + } + + unsigned euf_plugin::value_hash::operator()(app* t) const { + unsigned r = 0; + for (auto arg : *t) + r *= 3, r += cc.ctx.get_value(arg)->hash(); + return r; + } + + bool euf_plugin::value_eq::operator()(app* a, app* b) const { + SASSERT(a->get_num_args() == b->get_num_args()); + for (unsigned i = a->get_num_args(); i-- > 0; ) + if (cc.ctx.get_value(a->get_arg(i)) != cc.ctx.get_value(b->get_arg(i))) + return false; + return true; + } + + void euf_plugin::propagate_literal_incremental(sat::literal lit) { + m_replay_stack.push_back(lit); + replay(); + } + + sat::literal euf_plugin::resolve_conflict() { + auto& g = *m_g; + SASSERT(g.inconsistent()); + ++m_stats.m_num_conflicts; + unsigned n = 0; + sat::literal_vector lits; + sat::literal flit = sat::null_literal; + ptr_vector explain; + g.begin_explain(); + g.explain(explain, nullptr); + g.end_explain(); + double reward = -1; + TRACE("enf", + for (auto p : explain) { + sat::literal l = to_literal(p); + tout << l << " " << mk_pp(ctx.atom(l.var()), m) << " " << ctx.is_unit(l) << "\n"; + }); + for (auto p : explain) { + sat::literal l = to_literal(p); + CTRACE("euf", !ctx.is_true(l), tout << "not true " << l << "\n"; ctx.display(tout);); + SASSERT(ctx.is_true(l)); + + if (ctx.is_unit(l)) + continue; + if (!lits.contains(~l)) + lits.push_back(~l); + + if (ctx.reward(l.var()) > reward) + n = 0, reward = ctx.reward(l.var()); + + if (ctx.rand(++n) == 0) + flit = l; + } + // flip the last literal on the replay stack + IF_VERBOSE(10, verbose_stream() << "sls.euf - flip " << flit << "\n"); + ctx.add_clause(lits); + return flit; + } + + void euf_plugin::resolve() { + auto& g = *m_g; + if (!g.inconsistent()) + return; + + auto flit = resolve_conflict(); + sat::literal slit; + if (flit == sat::null_literal) + return; + do { + slit = m_stack.back(); + g.pop(1); + m_replay_stack.push_back(slit); + m_stack.pop_back(); + } + while (slit != flit); + ctx.flip(flit.var()); + m_replay_stack.back().neg(); + + } + + void euf_plugin::replay() { + while (!m_replay_stack.empty()) { + auto l = m_replay_stack.back(); + m_replay_stack.pop_back(); + propagate_literal_incremental_step(l); + if (m_g->inconsistent()) + resolve(); + } + } + + + void euf_plugin::propagate_literal_incremental_step(sat::literal lit) { + SASSERT(ctx.is_true(lit)); + auto e = ctx.atom(lit.var()); + expr* x, * y; + auto& g = *m_g; + + if (!e) + return; + + TRACE("euf", tout << "propagate " << lit << "\n"); + m_stack.push_back(lit); + g.push(); + if (m.is_eq(e, x, y)) { + if (lit.sign()) + g.new_diseq(g.find(e), to_ptr(lit)); + else + g.merge(g.find(x), g.find(y), to_ptr(lit)); + g.merge(g.find(e), g.find(m.mk_bool_val(!lit.sign())), to_ptr(lit)); + } + else if (!lit.sign() && m.is_distinct(e)) { + auto n = to_app(e)->get_num_args(); + for (unsigned i = 0; i < n; ++i) { + expr* a = to_app(e)->get_arg(i); + for (unsigned j = i + 1; j < n; ++j) { + auto b = to_app(e)->get_arg(j); + expr_ref eq(m.mk_eq(a, b), m); + auto c = g.find(eq); + if (!c) { + euf::enode* args[2] = { g.find(a), g.find(b) }; + c = g.mk(eq, 0, 2, args); + } + g.new_diseq(c, to_ptr(lit)); + g.merge(c, g.find(m.mk_false()), to_ptr(lit)); + } + } + } +// else if (m.is_bool(e) && is_app(e) && to_app(e)->get_family_id() == basic_family_id) +// ; + else { + auto a = g.find(e); + auto b = g.find(m.mk_bool_val(!lit.sign())); + g.merge(a, b, to_ptr(lit)); + } + g.propagate(); + } + + void euf_plugin::propagate_literal(sat::literal lit) { + if (m_incremental) + propagate_literal_incremental(lit); + else + propagate_literal_non_incremental(lit); + } + + void euf_plugin::propagate_literal_non_incremental(sat::literal lit) { + SASSERT(ctx.is_true(lit)); + auto e = ctx.atom(lit.var()); + expr* x, * y; + + if (!e) + return; + + auto block = [&](euf::enode* a, euf::enode* b) { + TRACE("euf", tout << "block " << m_g->bpp(a) << " != " << m_g->bpp(b) << "\n"); + if (a->get_root() != b->get_root()) + return; + ptr_vector explain; + m_g->explain_eq(explain, nullptr, a, b); + m_g->end_explain(); + unsigned n = 1; + sat::literal_vector lits; + sat::literal flit = sat::null_literal; + if (!ctx.is_unit(lit)) { + flit = lit; + lits.push_back(~lit); + } + for (auto p : explain) { + sat::literal l = to_literal(p); + if (!ctx.is_true(l)) + return; + if (ctx.is_unit(l)) + continue; + lits.push_back(~l); + if (ctx.rand(++n) == 0) + flit = l; + } + ctx.add_clause(lits); + ++m_stats.m_num_conflicts; + if (flit != sat::null_literal) + ctx.flip(flit.var()); + }; + + if (lit.sign() && m.is_eq(e, x, y)) + block(m_g->find(x), m_g->find(y)); + else if (!lit.sign() && m.is_distinct(e)) { + auto n = to_app(e)->get_num_args(); + for (unsigned i = 0; i < n; ++i) { + auto a = m_g->find(to_app(e)->get_arg(i)); + for (unsigned j = i + 1; j < n; ++j) { + auto b = m_g->find(to_app(e)->get_arg(j)); + block(a, b); + } + } + } + else if (lit.sign()) { + auto a = m_g->find(e); + auto b = m_g->find(m.mk_true()); + block(a, b); + } + } + + void euf_plugin::init_egraph(euf::egraph& g, bool merge_eqs) { + ptr_vector args; + m_stack.reset(); + for (auto t : ctx.subterms()) { + args.reset(); + if (is_app(t)) + for (auto* arg : *to_app(t)) + args.push_back(g.find(arg)); + g.mk(t, 0, args.size(), args.data()); + } + if (!g.find(m.mk_true())) + g.mk(m.mk_true(), 0, 0, nullptr); + if (!g.find(m.mk_false())) + g.mk(m.mk_false(), 0, 0, nullptr); + + // merge all equalities + // check for conflict with disequalities during propagation + if (merge_eqs) { + TRACE("euf", tout << "root literals " << ctx.root_literals() << "\n"); + for (auto lit : ctx.root_literals()) { + if (!ctx.is_true(lit)) + lit.neg(); + auto e = ctx.atom(lit.var()); + expr* x, * y; + if (e && m.is_eq(e, x, y) && !lit.sign()) + g.merge(g.find(x), g.find(y), to_ptr(lit)); + else if (!lit.sign()) + g.merge(g.find(e), g.find(m.mk_true()), to_ptr(lit)); + } + g.propagate(); + + if (g.inconsistent()) + resolve_conflict(); + } + + typedef obj_map map1; + typedef obj_map map2; + + m_num_elems = alloc(map1); + m_root2value = alloc(map2); + m_pinned = alloc(expr_ref_vector, m); + + for (auto n : g.nodes()) { + if (n->is_root() && is_user_sort(n->get_sort())) { + // verbose_stream() << "init root " << g.pp(n) << "\n"; + unsigned num = 0; + m_num_elems->find(n->get_sort(), num); + expr* v = m.mk_model_value(num, n->get_sort()); + m_pinned->push_back(v); + m_root2value->insert(n, v); + m_num_elems->insert(n->get_sort(), num + 1); + } + } + } + + expr_ref euf_plugin::get_value(expr* e) { + if (m.is_model_value(e)) + return expr_ref(e, m); + + if (!m_g) { + m_g = alloc(euf::egraph, m); + init_egraph(*m_g, true); + } + auto n = m_g->find(e)->get_root(); + VERIFY(m_root2value->find(n, e)); + return expr_ref(e, m); + } + + bool euf_plugin::include_func_interp(func_decl* f) const { + return is_uninterp(f) && f->get_arity() > 0; + } + + bool euf_plugin::is_sat() { + for (auto& [f, ts] : m_app) { + if (ts.size() <= 1) + continue; + m_values.reset(); + for (auto* t : ts) { + app* u; + if (!ctx.is_relevant(t)) + continue; + if (m_values.find(t, u)) { + if (ctx.get_value(t) != ctx.get_value(u)) + return false; + } + else + m_values.insert(t); + } + } + // validate_model(); + return true; + } + + void euf_plugin::validate_model() { + auto& g = *m_g; + for (auto lit : ctx.root_literals()) { + euf::enode* a, * b; + if (!ctx.is_true(lit)) + continue; + auto e = ctx.atom(lit.var()); + if (!e) + continue; + if (!ctx.is_relevant(e)) + continue; + if (m.is_distinct(e)) + continue; + + if (m.is_eq(e)) { + a = g.find(to_app(e)->get_arg(0)); + b = g.find(to_app(e)->get_arg(1)); + } + if (lit.sign() && m.is_eq(e)) { + if (a->get_root() == b->get_root()) { + IF_VERBOSE(0, verbose_stream() << "not disequal " << lit << " " << mk_pp(e, m) << "\n"); + ctx.display(verbose_stream()); + UNREACHABLE(); + } + } + else if (!lit.sign() && m.is_eq(e)) { + if (a->get_root() != b->get_root()) { + IF_VERBOSE(0, verbose_stream() << "not equal " << lit << " " << mk_pp(e, m) << "\n"); + //UNREACHABLE(); + } + } + else if (to_app(e)->get_family_id() != basic_family_id && lit.sign() && g.find(e)->get_root() != g.find(m.mk_false())->get_root()) { + IF_VERBOSE(0, verbose_stream() << "not alse " << lit << " " << mk_pp(e, m) << "\n"); + //UNREACHABLE(); + } + else if (to_app(e)->get_family_id() != basic_family_id && !lit.sign() && g.find(e)->get_root() != g.find(m.mk_true())->get_root()) { + IF_VERBOSE(0, verbose_stream() << "not true " << lit << " " << mk_pp(e, m) << "\n"); + //UNREACHABLE(); + } + + } + } + + bool euf_plugin::propagate() { + bool new_constraint = false; + for (auto & [f, ts] : m_app) { + if (ts.size() <= 1) + continue; + m_values.reset(); + for (auto * t : ts) { + app* u; + if (!ctx.is_relevant(t)) + continue; + if (m_values.find(t, u)) { + if (ctx.get_value(t) == ctx.get_value(u)) + continue; + expr_ref_vector ors(m); + for (unsigned i = t->get_num_args(); i-- > 0; ) + ors.push_back(m.mk_not(m.mk_eq(t->get_arg(i), u->get_arg(i)))); + ors.push_back(m.mk_eq(t, u)); +#if 0 + verbose_stream() << "conflict: " << mk_bounded_pp(t, m) << " != " << mk_bounded_pp(u, m) << "\n"; + verbose_stream() << "value " << ctx.get_value(t) << " != " << ctx.get_value(u) << "\n"; + for (unsigned i = t->get_num_args(); i-- > 0; ) + verbose_stream() << ctx.get_value(t->get_arg(i)) << " == " << ctx.get_value(u->get_arg(i)) << "\n"; +#endif + expr_ref fml(m.mk_or(ors), m); + ctx.add_constraint(fml); + new_constraint = true; + + } + else + m_values.insert(t); + } + } + + for (auto lit : ctx.root_literals()) { + if (!ctx.is_true(lit)) + continue; + auto e = ctx.atom(lit.var()); + if (lit.sign() && e && m.is_distinct(e)) { + auto n = to_app(e)->get_num_args(); + expr_ref_vector eqs(m); + for (unsigned i = 0; i < n; ++i) { + auto a = m_g->find(to_app(e)->get_arg(i)); + for (unsigned j = i + 1; j < n; ++j) { + auto b = m_g->find(to_app(e)->get_arg(j)); + if (a->get_root() == b->get_root()) + goto done_distinct; + eqs.push_back(m.mk_eq(a->get_expr(), b->get_expr())); + } + } + // distinct(a, b, c) or a = b or a = c or b = c + eqs.push_back(e); + ctx.add_constraint(m.mk_or(eqs)); + new_constraint = true; + done_distinct: + ; + } + } + + return new_constraint; + } + + std::ostream& euf_plugin::display(std::ostream& out) const { + if (m_g) + m_g->display(out); + + for (auto& [f, ts] : m_app) { + for (auto* t : ts) + out << mk_bounded_pp(t, m) << "\n"; + out << "\n"; + } + return out; + } + + void euf_plugin::collect_statistics(statistics& st) const { + st.update("sls-euf-conflict", m_stats.m_num_conflicts); + } + + void euf_plugin::reset_statistics() { + m_stats.reset(); + } +} diff --git a/src/ast/sls/sls_euf_plugin.h b/src/ast/sls/sls_euf_plugin.h new file mode 100644 index 00000000000..a9e5032d25a --- /dev/null +++ b/src/ast/sls/sls_euf_plugin.h @@ -0,0 +1,96 @@ +/*++ +Copyright (c) 2024 Microsoft Corporation + +Module Name: + + sls_euf_plugin.h + +Abstract: + + Congruence Closure for SLS + +Author: + + Nikolaj Bjorner (nbjorner) 2024-06-24 + +--*/ +#pragma once + +#include "util/hashtable.h" +#include "ast/sls/sls_context.h" +#include "ast/euf/euf_egraph.h" + +namespace sls { + + class euf_plugin : public plugin { + struct stats { + unsigned m_num_conflicts = 0; + void reset() { memset(this, 0, sizeof(*this)); } + }; + obj_map> m_app; + struct value_hash { + euf_plugin& cc; + value_hash(euf_plugin& cc) : cc(cc) {} + unsigned operator()(app* t) const; + }; + struct value_eq { + euf_plugin& cc; + value_eq(euf_plugin& cc) : cc(cc) {} + bool operator()(app* a, app* b) const; + }; + hashtable m_values; + + + + bool m_incremental = false; + unsigned m_incremental_mode = 0; + stats m_stats; + + scoped_ptr m_g; + scoped_ptr> m_num_elems; + scoped_ptr> m_root2value; + scoped_ptr m_pinned; + + void init_egraph(euf::egraph& g, bool merge_eqs); + sat::literal_vector m_stack, m_replay_stack; + void propagate_literal_incremental(sat::literal lit); + void propagate_literal_incremental_step(sat::literal lit); + void resolve(); + sat::literal resolve_conflict(); + void replay(); + + void propagate_literal_non_incremental(sat::literal lit); + bool is_user_sort(sort* s) { return s->get_family_id() == user_sort_family_id; } + + size_t* to_ptr(sat::literal l) { return reinterpret_cast((size_t)(l.index() << 4)); }; + sat::literal to_literal(size_t* p) { return sat::to_literal(static_cast(reinterpret_cast(p) >> 4)); }; + + void validate_model(); + + public: + euf_plugin(context& c); + ~euf_plugin() override; + expr_ref get_value(expr* e) override; + void initialize() override; + void start_propagation() override; + void propagate_literal(sat::literal lit) override; + bool propagate() override; + bool is_sat() override; + void register_term(expr* e) override; + std::ostream& display(std::ostream& out) const override; + bool set_value(expr* e, expr* v) override { return false; } + bool include_func_interp(func_decl* f) const override; + + void repair_up(app* e) override {} + bool repair_down(app* e) override { return false; } + void repair_literal(sat::literal lit) override {} + + void collect_statistics(statistics& st) const override; + void reset_statistics() override; + + + + scoped_ptr& egraph() { return m_g; } + }; + +} diff --git a/src/ast/sls/sls_smt_plugin.cpp b/src/ast/sls/sls_smt_plugin.cpp new file mode 100644 index 00000000000..2b704fa3f15 --- /dev/null +++ b/src/ast/sls/sls_smt_plugin.cpp @@ -0,0 +1,315 @@ +/*++ +Copyright (c) 2024 Microsoft Corporation + +Module Name: + + sls_smt_plugin.cpp + +Abstract: + + A Stochastic Local Search (SLS) Plugin. + +Author: + + Nikolaj Bjorner (nbjorner) 2024-07-10 + +--*/ + + +#include "ast/sls/sls_smt_plugin.h" +#include "ast/for_each_expr.h" +#include "ast/bv_decl_plugin.h" + +namespace sls { + + smt_plugin::smt_plugin(smt_context& ctx) : + ctx(ctx), + m(ctx.get_manager()), + m_sls(), + m_sync(), + m_smt2sync_tr(m, m_sync), + m_smt2sls_tr(m, m_sls), + m_sync_uninterp(m_sync), + m_sls_uninterp(m_sls), + m_sync_values(m_sync), + m_context(m_sls, *this) + { + } + + smt_plugin::~smt_plugin() { + SASSERT(!m_ddfw); + } + + void smt_plugin::check(expr_ref_vector const& fmls, vector const& clauses) { + SASSERT(!m_ddfw); + // set up state for local search theory_sls here + m_result = l_undef; + m_completed = false; + m_units.reset(); + m_has_units = false; + m_sls_model = nullptr; + m_ddfw = alloc(sat::ddfw); + m_ddfw->set_plugin(this); + m_ddfw->updt_params(ctx.get_params()); + + for (auto const& clause : clauses) { + m_ddfw->add(clause.size(), clause.data()); + for (auto lit : clause) + add_shared_var(lit.var(), lit.var()); + } + + for (auto v : m_shared_bool_vars) { + expr* e = ctx.bool_var2expr(v); + if (!e) + continue; + m_context.register_atom(v, m_smt2sls_tr(e)); + for (auto t : subterms::all(expr_ref(e, m))) + add_shared_term(t); + } + + for (auto fml : fmls) + m_context.add_constraint(m_smt2sls_tr(fml)); + + for (unsigned v = 0; v < ctx.get_num_bool_vars(); ++v) { + expr* e = ctx.bool_var2expr(v); + if (!e) + continue; + + expr_ref sls_e(m_sls); + sls_e = m_smt2sls_tr(e); + auto w = m_context.atom2bool_var(sls_e); + if (w == sat::null_bool_var) + continue; + add_shared_var(v, w); + for (auto t : subterms::all(expr_ref(e, m))) + add_shared_term(t); + } + + m_thread = std::thread([this]() { run(); }); + } + + void smt_plugin::run() { + if (!m_ddfw) + return; + m_result = m_ddfw->check(0, nullptr); + m_ddfw->collect_statistics(m_st); + IF_VERBOSE(1, verbose_stream() << "sls-result " << m_result << "\n"); + m_completed = true; + } + + void smt_plugin::finalize(model_ref& mdl, ::statistics& st) { + auto* d = m_ddfw; + if (!d) + return; + bool canceled = !m_completed; + IF_VERBOSE(3, verbose_stream() << "finalize\n"); + if (!m_completed) + d->rlimit().cancel(); + if (m_thread.joinable()) + m_thread.join(); + SASSERT(m_completed); + st.copy(m_st); + mdl = nullptr; + if (m_result == l_true && m_sls_model) { + ast_translation tr(m_sls, m); + mdl = m_sls_model->translate(tr); + TRACE("sls", tout << "model: " << *m_sls_model << "\n";); + if (!canceled) + ctx.set_finished(); + } + m_ddfw = nullptr; + // m_ddfw owns the pointer to smt_plugin and destructs it. + dealloc(d); + } + + std::ostream& smt_plugin::display(std::ostream& out) { + m_ddfw->display(out); + m_context.display(out); + return out; + } + + bool smt_plugin::is_shared(sat::literal lit) { + auto w = m_smt_bool_var2sls_bool_var.get(lit.var(), sat::null_bool_var); + if (w != sat::null_bool_var) + return true; + auto e = ctx.bool_var2expr(lit.var()); + expr* t = nullptr; + if (!e) + return false; + bv_util bv(m); + if (bv.is_bit2bool(e, t) && m_shared_terms.contains(t->get_id())) { + verbose_stream() << "shared bit2bool " << mk_bounded_pp(e, ctx.get_manager()) << "\n"; + return true; + } + + // if arith.is_le(e, s, t) && t is a numeral, s is shared-term.... + return false; + } + + void smt_plugin::add_shared_var(sat::bool_var v, sat::bool_var w) { + m_smt_bool_var2sls_bool_var.setx(v, w, sat::null_bool_var); + m_sls_bool_var2smt_bool_var.setx(w, v, sat::null_bool_var); + m_sls_phase.reserve(v + 1); + m_sat_phase.reserve(v + 1); + m_rewards.reserve(v + 1); + m_shared_bool_vars.insert(v); + } + + void smt_plugin::add_unit(sat::literal lit) { + if (!is_shared(lit)) + return; + std::lock_guard lock(m_mutex); + m_units.push_back(lit); + m_has_units = true; + } + + void smt_plugin::import_phase_from_smt() { + if (m_has_new_sat_phase) + return; + m_has_new_sat_phase = true; + IF_VERBOSE(3, verbose_stream() << "new SMT -> SLS phase\n"); + ctx.set_has_new_best_phase(false); + std::lock_guard lock(m_mutex); + for (auto v : m_shared_bool_vars) + m_sat_phase[v] = ctx.get_best_phase(v); + } + + bool smt_plugin::export_to_sls() { + bool updated = false; + if (export_units_to_sls()) + updated = true; + if (export_phase_to_sls()) + updated = true; + return updated; + } + + bool smt_plugin::export_phase_to_sls() { + if (!m_has_new_sat_phase) + return false; + std::lock_guard lock(m_mutex); + IF_VERBOSE(3, verbose_stream() << "SMT -> SLS phase\n"); + for (auto v : m_shared_bool_vars) { + auto w = m_smt_bool_var2sls_bool_var[v]; + if (m_sat_phase[v] != is_true(sat::literal(w, false))) + flip(w); + m_ddfw->bias(w) = m_sat_phase[v] ? 1 : -1; + } + m_has_new_sat_phase = false; + return true; + } + + bool smt_plugin::export_units_to_sls() { + if (!m_has_units) + return false; + std::lock_guard lock(m_mutex); + IF_VERBOSE(2, verbose_stream() << "SMT -> SLS units " << m_units << "\n"); + for (auto lit : m_units) { + auto v = lit.var(); + if (m_shared_bool_vars.contains(v)) { + auto w = m_smt_bool_var2sls_bool_var[v]; + sat::literal sls_lit(w, lit.sign()); + IF_VERBOSE(10, verbose_stream() << "unit " << sls_lit << "\n"); + m_ddfw->add(1, &sls_lit); + } + else { + IF_VERBOSE(0, verbose_stream() << "value restriction " << lit << " " + << mk_bounded_pp(ctx.bool_var2expr(lit.var()), m) << "\n"); + } + } + m_has_units = false; + m_units.reset(); + return true; + } + + void smt_plugin::export_from_sls() { + if (unsat().size() > m_min_unsat_size) + return; + m_min_unsat_size = unsat().size(); + std::lock_guard lock(m_mutex); + for (auto v : m_shared_bool_vars) { + auto w = m_smt_bool_var2sls_bool_var[v]; + m_rewards[v] = m_ddfw->get_reward_avg(w); + //verbose_stream() << v << " " << w << "\n"; + VERIFY(m_ddfw->get_model().size() > w); + VERIFY(m_sls_phase.size() > v); + m_sls_phase[v] = l_true == m_ddfw->get_model()[w]; + m_has_new_sls_phase = true; + } + // export_values_from_sls(); + } + + void smt_plugin::export_values_from_sls() { + IF_VERBOSE(3, verbose_stream() << "import values from sls\n"); + std::lock_guard lock(m_mutex); + for (auto const& [t, t_sync] : m_sls2sync_uninterp) { + expr_ref val_t = m_context.get_value(t_sync); + m_sync_values.set(t_sync->get_id(), val_t.get()); + } + m_has_new_sls_values = true; + } + + void smt_plugin::import_from_sls() { + export_activity_to_smt(); + export_values_to_smt(); + export_phase_to_smt(); + } + + void smt_plugin::export_activity_to_smt() { + + } + + void smt_plugin::export_values_to_smt() { + if (!m_has_new_sls_values) + return; + IF_VERBOSE(3, verbose_stream() << "SLS -> SMT values\n"); + std::lock_guard lock(m_mutex); + ast_translation tr(m_sync, m); + for (auto const& [t, t_sync] : m_smt2sync_uninterp) { + expr* sync_val = m_sync_values.get(t_sync->get_id(), nullptr); + if (!sync_val) + continue; + expr_ref val(tr(sync_val), m); + ctx.initialize_value(t, val); + } + m_has_new_sls_values = false; + } + + void smt_plugin::export_phase_to_smt() { + if (!m_has_new_sls_phase) + return; + std::lock_guard lock(m_mutex); + IF_VERBOSE(3, verbose_stream() << "SLS -> SMT phase\n"); + for (auto v : m_shared_bool_vars) { + auto w = m_smt_bool_var2sls_bool_var[v]; + ctx.force_phase(sat::literal(w, m_sls_phase[v])); + } + m_has_new_sls_phase = false; + } + + void smt_plugin::add_shared_term(expr* t) { + m_shared_terms.insert(t->get_id()); + if (is_uninterp(t)) + add_uninterp(t); + } + + void smt_plugin::add_uninterp(expr* smt_t) { + auto sync_t = m_smt2sync_tr(smt_t); + auto sls_t = m_smt2sls_tr(smt_t); + m_sync_uninterp.push_back(sync_t); + m_sls_uninterp.push_back(sls_t); + m_smt2sync_uninterp.insert(smt_t, sync_t); + m_sls2sync_uninterp.insert(sls_t, sync_t); + } + + void smt_plugin::on_save_model() { + TRACE("sls", display(tout)); + while (unsat().empty()) { + m_context.check(); + if (!m_new_clause_added) + break; + m_ddfw->reinit(); + m_new_clause_added = false; + } + // export_from_sls(); + } +} diff --git a/src/ast/sls/sls_smt_plugin.h b/src/ast/sls/sls_smt_plugin.h new file mode 100644 index 00000000000..0ad28d6dc08 --- /dev/null +++ b/src/ast/sls/sls_smt_plugin.h @@ -0,0 +1,158 @@ +/*++ +Copyright (c) 2024 Microsoft Corporation + +Module Name: + + sls_smt_plugin.h + +Abstract: + + A Stochastic Local Search (SLS) Plugin. + +Author: + + Nikolaj Bjorner (nbjorner) 2024-07-10 + +--*/ + +#pragma once + +#include "ast/sls/sls_context.h" +#include "ast/sls/sat_ddfw.h" +#include "util/statistics.h" +#include +#include + +namespace sls { + + class smt_context { + public: + virtual ~smt_context() {} + virtual ast_manager& get_manager() = 0; + virtual params_ref get_params() = 0; + virtual void initialize_value(expr* t, expr* v) = 0; + virtual void force_phase(sat::literal lit) = 0; + virtual void set_has_new_best_phase(bool b) = 0; + virtual bool get_best_phase(sat::bool_var v) = 0; + virtual expr* bool_var2expr(sat::bool_var v) = 0; + virtual void set_finished() = 0; + virtual unsigned get_num_bool_vars() const = 0; + }; + + + // + // m is accessed by the main thread + // m_sls is accessed by the sls thread + // m_sync is accessed by both + // + class smt_plugin : public sat::local_search_plugin, public sat_solver_context { + smt_context& ctx; + ast_manager& m; + ast_manager m_sls; + ast_manager m_sync; + ast_translation m_smt2sync_tr, m_smt2sls_tr; + expr_ref_vector m_sync_uninterp, m_sls_uninterp; + expr_ref_vector m_sync_values; + sat::ddfw* m_ddfw = nullptr; + sls::context m_context; + std::atomic m_result; + std::atomic m_completed, m_has_units; + std::thread m_thread; + std::mutex m_mutex; + + sat::literal_vector m_units; + model_ref m_sls_model; + ::statistics m_st; + bool m_new_clause_added = false; + unsigned m_min_unsat_size = UINT_MAX; + obj_map m_sls2sync_uninterp; // hashtable from sls-uninterp to sync uninterp + obj_map m_smt2sync_uninterp; // hashtable from external uninterp to sync uninterp + std::atomic m_has_new_sls_values = false; + uint_set m_shared_bool_vars, m_shared_terms; + svector m_sat_phase; + std::atomic m_has_new_sat_phase = false; + std::atomic m_has_new_sls_phase = false; + svector m_sls_phase; + svector m_rewards; + svector m_smt_bool_var2sls_bool_var, m_sls_bool_var2smt_bool_var; + + bool is_shared(sat::literal lit); + void run(); + void add_shared_term(expr* t); + void add_uninterp(expr* smt_t); + void add_shared_var(sat::bool_var v, sat::bool_var w); + + void import_phase_from_smt(); + void import_values_from_sls(); + void export_values_from_sls(); + void import_activity_from_sls(); + bool export_phase_to_sls(); + bool export_units_to_sls(); + void export_values_to_smt(); + void export_activity_to_smt(); + void export_phase_to_smt(); + + void export_from_sls(); + + friend class sat::ddfw; + ~smt_plugin(); + + public: + smt_plugin(smt_context& ctx); + + // interface to calling solver: + void check(expr_ref_vector const& fmls, vector const& clauses); + void finalize(model_ref& md, ::statistics& st); + void updt_params(params_ref& p) {} + std::ostream& display(std::ostream& out) override; + + bool export_to_sls(); + void import_from_sls(); + bool completed() { return m_completed; } + void add_unit(sat::literal lit); + + // local_search_plugin: + void on_restart() override { + if (export_to_sls()) + m_ddfw->reinit(); + } + + void on_save_model() override; + + void on_model(model_ref& mdl) override { + IF_VERBOSE(3, verbose_stream() << "on-model " << "\n"); + m_sls_model = mdl; + } + + void init_search() override {} + + void finish_search() override {} + + void on_rescale() override {} + + + + // sat_solver_context: + vector const& clauses() const override { return m_ddfw->clauses(); } + sat::clause_info const& get_clause(unsigned idx) const override { return m_ddfw->get_clause_info(idx); } + ptr_iterator get_use_list(sat::literal lit) override { return m_ddfw->use_list(lit); } + void flip(sat::bool_var v) override { + m_ddfw->flip(v); + } + double reward(sat::bool_var v) override { return m_ddfw->get_reward(v); } + double get_weigth(unsigned clause_idx) override { return m_ddfw->get_clause_info(clause_idx).m_weight; } + bool is_true(sat::literal lit) override { + return m_ddfw->get_value(lit.var()) != lit.sign(); + } + unsigned num_vars() const override { return m_ddfw->num_vars(); } + indexed_uint_set const& unsat() const override { return m_ddfw->unsat_set(); } + sat::bool_var add_var() override { + return m_ddfw->add_var(); + } + void add_clause(unsigned n, sat::literal const* lits) override { + m_ddfw->add(n, lits); + m_new_clause_added = true; + } + void force_restart() override { m_ddfw->force_restart(); } + }; +} diff --git a/src/ast/sls/sls_smt_solver.cpp b/src/ast/sls/sls_smt_solver.cpp new file mode 100644 index 00000000000..09efea1998c --- /dev/null +++ b/src/ast/sls/sls_smt_solver.cpp @@ -0,0 +1,171 @@ +/*++ +Copyright (c) 2024 Microsoft Corporation + +Module Name: + + sls_smt_solver.cpp + +Abstract: + + A Stochastic Local Search (SLS) Solver. + +Author: + + Nikolaj Bjorner (nbjorner) 2024-07-10 + +--*/ + +#include "ast/sls/sls_context.h" +#include "ast/sls/sat_ddfw.h" +#include "ast/sls/sls_smt_solver.h" +#include "ast/ast_ll_pp.h" + + +namespace sls { + + class smt_solver::solver_ctx : public sat::local_search_plugin, public sls::sat_solver_context { + ast_manager& m; + sat::ddfw& m_ddfw; + context m_context; + bool m_dirty = false; + bool m_new_constraint = false; + model_ref m_model; + obj_map m_expr2lit; + public: + solver_ctx(ast_manager& m, sat::ddfw& d) : + m(m), m_ddfw(d), m_context(m, *this) { + m_ddfw.set_plugin(this); + m.limit().push_child(&m_ddfw.rlimit()); + } + + ~solver_ctx() override { + m.limit().pop_child(&m_ddfw.rlimit()); + } + + void init_search() override {} + + void finish_search() override {} + + void on_rescale() override {} + + void on_restart() override { + m_context.on_restart(); + } + + bool m_on_save_model = false; + void on_save_model() override { + if (m_on_save_model) + return; + flet _on_save_model(m_on_save_model, true); + CTRACE("sls", unsat().empty(), display(tout)); + while (unsat().empty()) { + m_context.check(); + if (!m_new_constraint) + break; + TRACE("sls", display(tout)); + //m_ddfw.simplify(); + m_ddfw.reinit(); + m_new_constraint = false; + } + } + + void on_model(model_ref& mdl) override { + IF_VERBOSE(1, verbose_stream() << "on-model " << "\n"); + m_model = mdl; + } + + void register_atom(sat::bool_var v, expr* e) { + m_context.register_atom(v, e); + } + + std::ostream& display(std::ostream& out) override { + m_ddfw.display(out); + m_context.display(out); + return out; + } + + vector const& clauses() const override { return m_ddfw.clauses(); } + sat::clause_info const& get_clause(unsigned idx) const override { return m_ddfw.get_clause_info(idx); } + ptr_iterator get_use_list(sat::literal lit) override { return m_ddfw.use_list(lit); } + void flip(sat::bool_var v) override { if (m_dirty) m_ddfw.reinit(), m_dirty = false; m_ddfw.flip(v); } + double reward(sat::bool_var v) override { return m_ddfw.get_reward(v); } + double get_weigth(unsigned clause_idx) override { return m_ddfw.get_clause_info(clause_idx).m_weight; } + bool is_true(sat::literal lit) override { return m_ddfw.get_value(lit.var()) != lit.sign(); } + unsigned num_vars() const override { return m_ddfw.num_vars(); } + indexed_uint_set const& unsat() const override { return m_ddfw.unsat_set(); } + sat::bool_var add_var() override { m_dirty = true; return m_ddfw.add_var(); } + void add_clause(expr* f) { m_context.add_clause(f); } + + void force_restart() override { m_ddfw.force_restart(); } + + void add_clause(unsigned n, sat::literal const* lits) override { + m_ddfw.add(n, lits); + m_new_constraint = true; + } + + sat::literal mk_literal() { + sat::bool_var v = add_var(); + return sat::literal(v, false); + } + + model_ref get_model() { return m_model; } + + void collect_statistics(statistics& st) { + m_ddfw.collect_statistics(st); + m_context.collect_statistics(st); + } + + void reset_statistics() { + m_ddfw.reset_statistics(); + m_context.reset_statistics(); + } + + void updt_params(params_ref const& p) { + m_ddfw.updt_params(p); + m_context.updt_params(p); + } + }; + + smt_solver::smt_solver(ast_manager& m, params_ref const& p): + m(m), + m_solver_ctx(alloc(solver_ctx, m, m_ddfw)), + m_assertions(m) { + + m_solver_ctx->updt_params(p); + } + + smt_solver::~smt_solver() { + } + + void smt_solver::assert_expr(expr* e) { + if (m.is_and(e)) { + for (expr* arg : *to_app(e)) + assert_expr(arg); + } + else + m_assertions.push_back(e); + } + + lbool smt_solver::check() { + for (auto f : m_assertions) + m_solver_ctx->add_clause(f); + IF_VERBOSE(10, m_solver_ctx->display(verbose_stream())); + return m_ddfw.check(0, nullptr); + } + + model_ref smt_solver::get_model() { + return m_solver_ctx->get_model(); + } + + std::ostream& smt_solver::display(std::ostream& out) { + return m_solver_ctx->display(out); + } + + void smt_solver::collect_statistics(statistics& st) { + m_solver_ctx->collect_statistics(st); + } + + void smt_solver::reset_statistics() { + m_solver_ctx->reset_statistics(); + } +} diff --git a/src/ast/sls/sls_smt_solver.h b/src/ast/sls/sls_smt_solver.h new file mode 100644 index 00000000000..914397fc1fb --- /dev/null +++ b/src/ast/sls/sls_smt_solver.h @@ -0,0 +1,44 @@ +/*++ +Copyright (c) 2024 Microsoft Corporation + +Module Name: + + sls_smt_solver.h + +Abstract: + + A Stochastic Local Search (SLS) Solver. + +Author: + + Nikolaj Bjorner (nbjorner) 2024-07-10 + +--*/ + +#pragma once +#include "ast/sls/sls_context.h" +#include "ast/sls/sat_ddfw.h" + + +namespace sls { + + class smt_solver { + ast_manager& m; + class solver_ctx; + sat::ddfw m_ddfw; + solver_ctx* m_solver_ctx = nullptr; + expr_ref_vector m_assertions; + statistics m_st; + + public: + smt_solver(ast_manager& m, params_ref const& p); + ~smt_solver(); + void assert_expr(expr* e); + lbool check(); + model_ref get_model(); + void updt_params(params_ref& p) {} + void collect_statistics(statistics& st); + std::ostream& display(std::ostream& out); + void reset_statistics(); + }; +} diff --git a/src/math/hilbert/hilbert_basis.h b/src/math/hilbert/hilbert_basis.h index f743f32ca01..0f089139ef5 100644 --- a/src/math/hilbert/hilbert_basis.h +++ b/src/math/hilbert/hilbert_basis.h @@ -43,7 +43,7 @@ class hilbert_basis { typedef vector num_vector; static checked_int64 to_numeral(rational const& r) { if (!r.is_int64()) { - throw checked_int64::overflow_exception(); + throw overflow_exception(); } return checked_int64(r.get_int64()); } diff --git a/src/model/model.cpp b/src/model/model.cpp index 362907e466b..ce2f3402e57 100644 --- a/src/model/model.cpp +++ b/src/model/model.cpp @@ -39,6 +39,7 @@ Revision History: #include "model/datatype_factory.h" #include "model/numeral_factory.h" #include "model/fpa_factory.h" +#include "model/char_factory.h" model::model(ast_manager & m): @@ -103,12 +104,14 @@ value_factory* model::get_factory(sort* s) { if (m_factories.plugins().empty()) { seq_util su(m); fpa_util fu(m); + m_factories.register_plugin(alloc(basic_factory, m, 0)); m_factories.register_plugin(alloc(array_factory, m, *this)); m_factories.register_plugin(alloc(datatype_factory, m, *this)); m_factories.register_plugin(alloc(bv_factory, m)); m_factories.register_plugin(alloc(arith_factory, m)); m_factories.register_plugin(alloc(seq_factory, m, su.get_family_id(), *this)); m_factories.register_plugin(alloc(fpa_value_factory, m, fu.get_family_id())); + //m_factories.register_plugin(alloc(char_factory, m, char_decl_plugin(m).get_family_id()); } family_id fid = s->get_family_id(); return m_factories.get_plugin(fid); diff --git a/src/opt/opt_context.cpp b/src/opt/opt_context.cpp index f3b4c78b7fb..81a2b80b57e 100644 --- a/src/opt/opt_context.cpp +++ b/src/opt/opt_context.cpp @@ -42,7 +42,7 @@ Module Name: #include "ast/converters/generic_model_converter.h" #include "ackermannization/ackermannize_bv_tactic.h" #include "sat/sat_solver/inc_sat_solver.h" -#include "sat/sat_params.hpp" +#include "params/sat_params.hpp" #include "opt/opt_context.h" #include "opt/opt_solver.h" #include "opt/opt_params.hpp" diff --git a/src/opt/opt_lns.cpp b/src/opt/opt_lns.cpp index ae9bff21552..3a519565889 100644 --- a/src/opt/opt_lns.cpp +++ b/src/opt/opt_lns.cpp @@ -20,7 +20,7 @@ Module Name: #include "ast/pb_decl_plugin.h" #include "opt/maxsmt.h" #include "opt/opt_lns.h" -#include "sat/sat_params.hpp" +#include "params/sat_params.hpp" #include namespace opt { diff --git a/src/params/CMakeLists.txt b/src/params/CMakeLists.txt index 763702caf56..3ecf00dd413 100644 --- a/src/params/CMakeLists.txt +++ b/src/params/CMakeLists.txt @@ -14,6 +14,7 @@ z3_add_component(params pattern_inference_params_helper.pyg poly_rewriter_params.pyg rewriter_params.pyg + sat_params.pyg seq_rewriter_params.pyg sls_params.pyg solver_params.pyg diff --git a/src/sat/sat_params.pyg b/src/params/sat_params.pyg similarity index 100% rename from src/sat/sat_params.pyg rename to src/params/sat_params.pyg diff --git a/src/params/sls_params.pyg b/src/params/sls_params.pyg index 18b8d337146..b1337600493 100644 --- a/src/params/sls_params.pyg +++ b/src/params/sls_params.pyg @@ -22,6 +22,8 @@ def_module_params('sls', ('early_prune', BOOL, 1, 'use early pruning for score prediction'), ('random_offset', BOOL, 1, 'use random offset for candidate evaluation'), ('rescore', BOOL, 1, 'rescore/normalize top-level score every base restart interval'), + ('euf_incremental', UINT, 0, '0 non-incremental, 1 incremental, 2 alternating EUF resolver'), + ('dt_axiomatic', BOOL, True, 'use axiomatic mode or model reduction for datatype solver'), ('track_unsat', BOOL, 0, 'keep a list of unsat assertions as done in SAT - currently disabled internally'), ('random_seed', UINT, 0, 'random seed') )) diff --git a/src/qe/qe_mbp.cpp b/src/qe/qe_mbp.cpp index 96ae5e85aab..5f575218756 100644 --- a/src/qe/qe_mbp.cpp +++ b/src/qe/qe_mbp.cpp @@ -42,7 +42,8 @@ Revision History: using namespace qe; -namespace { +namespace qembp { + // rewrite select(store(a, i, k), j) into k if m \models i = j and select(a, j) if m \models i != j struct rd_over_wr_rewriter : public default_rewriter_cfg { ast_manager &m; @@ -124,19 +125,19 @@ namespace { }; } -template class rewriter_tpl; -template class rewriter_tpl; +template class rewriter_tpl; +template class rewriter_tpl; void rewrite_as_const_arr(expr* in, model& mdl, expr_ref& out) { - app_const_arr_rewriter cfg(out.m(), mdl); - rewriter_tpl rw(out.m(), false, cfg); + qembp::app_const_arr_rewriter cfg(out.m(), mdl); + rewriter_tpl rw(out.m(), false, cfg); rw(in, out); } void rewrite_read_over_write(expr *in, model &mdl, expr_ref &out) { - rd_over_wr_rewriter cfg(out.m(), mdl); - rewriter_tpl rw(out.m(), false, cfg); + qembp::rd_over_wr_rewriter cfg(out.m(), mdl); + rewriter_tpl rw(out.m(), false, cfg); rw(in, out); if (cfg.m_sc.empty()) return; expr_ref_vector sc(out.m()); diff --git a/src/sat/CMakeLists.txt b/src/sat/CMakeLists.txt index b6f6a6f9454..48e6959b3c3 100644 --- a/src/sat/CMakeLists.txt +++ b/src/sat/CMakeLists.txt @@ -15,7 +15,7 @@ z3_add_component(sat sat_config.cpp sat_cut_simplifier.cpp sat_cutset.cpp - sat_ddfw.cpp + sat_ddfw_wrapper.cpp sat_drat.cpp sat_elim_eqs.cpp sat_elim_vars.cpp @@ -43,7 +43,6 @@ z3_add_component(sat params PYG_FILES sat_asymm_branch_params.pyg - sat_params.pyg sat_scc_params.pyg sat_simplifier_params.pyg ) diff --git a/src/sat/sat_config.cpp b/src/sat/sat_config.cpp index 73516f66d12..7ee747b5b9e 100644 --- a/src/sat/sat_config.cpp +++ b/src/sat/sat_config.cpp @@ -16,9 +16,9 @@ Module Name: Revision History: --*/ +#include "params/sat_params.hpp" #include "sat/sat_config.h" #include "sat/sat_types.h" -#include "sat/sat_params.hpp" #include "sat/sat_simplifier_params.hpp" #include "params/solver_params.hpp" diff --git a/src/sat/sat_ddfw_wrapper.cpp b/src/sat/sat_ddfw_wrapper.cpp new file mode 100644 index 00000000000..2fba213de94 --- /dev/null +++ b/src/sat/sat_ddfw_wrapper.cpp @@ -0,0 +1,85 @@ +/*++ + Copyright (c) 2019 Microsoft Corporation + + Module Name: + + sat_ddfw_wrapper.cpp + +*/ + +#include "sat/sat_ddfw_wrapper.h" +#include "sat/sat_solver.h" +#include "sat/sat_parallel.h" + +namespace sat { + + lbool ddfw_wrapper::check(unsigned sz, literal const* assumptions, parallel* p) { + flet _p(m_par, p); + m_ddfw.m_parallel_sync = nullptr; + if (m_par) { + m_ddfw.m_parallel_sync = [&]() -> bool { + if (should_parallel_sync()) { + do_parallel_sync(); + return true; + } + else + return false; + }; + } + return m_ddfw.check(sz, assumptions); + } + + bool ddfw_wrapper::should_parallel_sync() { + return m_par != nullptr && m_ddfw.m_flips >= m_parsync_next; + } + + void ddfw_wrapper::do_parallel_sync() { + if (m_par->from_solver(*this)) + m_par->to_solver(*this); + + ++m_parsync_count; + m_parsync_next *= 3; + m_parsync_next /= 2; + } + + + void ddfw_wrapper::reinit(solver& s, bool_vector const& phase) { + add(s); + m_ddfw.add_assumptions(); + for (unsigned v = 0; v < phase.size(); ++v) { + m_ddfw.value(v) = phase[v]; + m_ddfw.reward(v) = 0; + m_ddfw.make_count(v) = 0; + } + m_ddfw.init_clause_data(); + m_ddfw.flatten_use_list(); + } + + void ddfw_wrapper::add(solver const& s) { + m_ddfw.set_seed(s.get_config().m_random_seed); + m_ddfw.m_clauses.reset(); + m_ddfw.m_use_list.reset(); + m_ddfw.m_num_non_binary_clauses = 0; + + unsigned trail_sz = s.init_trail_size(); + for (unsigned i = 0; i < trail_sz; ++i) { + m_ddfw.add(1, s.m_trail.data() + i); + } + unsigned sz = s.m_watches.size(); + for (unsigned l_idx = 0; l_idx < sz; ++l_idx) { + literal l1 = ~to_literal(l_idx); + watch_list const & wlist = s.m_watches[l_idx]; + for (watched const& w : wlist) { + if (!w.is_binary_non_learned_clause()) + continue; + literal l2 = w.get_literal(); + if (l1.index() > l2.index()) + continue; + literal ls[2] = { l1, l2 }; + m_ddfw.add(2, ls); + } + } + for (clause* c : s.m_clauses) + m_ddfw.add(c->size(), c->begin()); + } +} diff --git a/src/sat/sat_ddfw_wrapper.h b/src/sat/sat_ddfw_wrapper.h new file mode 100644 index 00000000000..6c87c72bd38 --- /dev/null +++ b/src/sat/sat_ddfw_wrapper.h @@ -0,0 +1,89 @@ +/*++ + Copyright (c) 2019 Microsoft Corporation + + Module Name: + + sat_ddfw_wrapper.h + + + --*/ +#pragma once + +#include "util/uint_set.h" +#include "util/rlimit.h" +#include "util/params.h" +#include "util/ema.h" +#include "util/sat_sls.h" +#include "util/map.h" +#include "ast/sls/sat_ddfw.h" +#include "sat/sat_types.h" + +namespace sat { + class solver; + class parallel; + + + class ddfw_wrapper : public i_local_search { + protected: + ddfw m_ddfw; + parallel* m_par = nullptr; + unsigned m_parsync_count = 0; + uint64_t m_parsync_next = 0; + + void do_parallel_sync(); + bool should_parallel_sync(); + + public: + + ddfw_wrapper() {} + + ~ddfw_wrapper() override {} + + void set_plugin(local_search_plugin* p) { m_ddfw.set_plugin(p); } + + lbool check(unsigned sz, literal const* assumptions, parallel* p) override; + + void updt_params(params_ref const& p) override { m_ddfw.updt_params(p); } + + model const& get_model() const override { return m_ddfw.get_model(); } + + reslimit& rlimit() override { return m_ddfw.rlimit(); } + + void set_seed(unsigned n) override { m_ddfw.set_seed(n); } + + void add(solver const& s) override; + + bool get_value(bool_var v) const override { return m_ddfw.get_value(v); } + + std::ostream& display(std::ostream& out) const { return m_ddfw.display(out); } + + // for parallel integration + unsigned num_non_binary_clauses() const override { return m_ddfw.num_non_binary_clauses(); } + + void reinit(solver& s, bool_vector const& phase) override; + + void collect_statistics(statistics& st) const override {} + + double get_priority(bool_var v) const override { return m_ddfw.get_priority(v); } + + // access clause information and state of Boolean search + indexed_uint_set& unsat_set() { return m_ddfw.unsat_set(); } + + vector const& clauses() const { return m_ddfw.clauses(); } + + clause_info& get_clause_info(unsigned idx) { return m_ddfw.get_clause_info(idx); } + + void remove_assumptions() { m_ddfw.remove_assumptions(); } + + void flip(bool_var v) { m_ddfw.flip(v); } + + inline double get_reward(bool_var v) const { return m_ddfw.get_reward(v); } + + void add(unsigned sz, literal const* c) { m_ddfw.add(sz, c); } + + void reinit() { m_ddfw.reinit(); } + + + }; +} + diff --git a/src/sat/sat_local_search.cpp b/src/sat/sat_local_search.cpp index 92b4f7f5c2e..5cb983baa0f 100644 --- a/src/sat/sat_local_search.cpp +++ b/src/sat/sat_local_search.cpp @@ -19,7 +19,7 @@ Module Name: #include "sat/sat_local_search.h" #include "sat/sat_solver.h" -#include "sat/sat_params.hpp" +#include "params/sat_params.hpp" #include "util/timer.h" namespace sat { diff --git a/src/sat/sat_solver.cpp b/src/sat/sat_solver.cpp index 5fe5c5e4e77..69f3c0f377b 100644 --- a/src/sat/sat_solver.cpp +++ b/src/sat/sat_solver.cpp @@ -29,7 +29,7 @@ Revision History: #include "sat/sat_solver.h" #include "sat/sat_integrity_checker.h" #include "sat/sat_lookahead.h" -#include "sat/sat_ddfw.h" +#include "sat/sat_ddfw_wrapper.h" #include "sat/sat_prob.h" #include "sat/sat_anf_simplifier.h" #include "sat/sat_cut_simplifier.h" @@ -1365,7 +1365,7 @@ namespace sat { } literal_vector _lits; scoped_limits scoped_rl(rlimit()); - m_local_search = alloc(ddfw); + m_local_search = alloc(ddfw_wrapper); scoped_ls _ls(*this); SASSERT(m_local_search); m_local_search->add(*this); @@ -1442,7 +1442,7 @@ namespace sat { lbool solver::do_ddfw_search(unsigned num_lits, literal const* lits) { if (m_ext) return l_undef; SASSERT(!m_local_search); - m_local_search = alloc(ddfw); + m_local_search = alloc(ddfw_wrapper); return invoke_local_search(num_lits, lits); } @@ -1485,7 +1485,7 @@ namespace sat { // set up ddfw search for (int i = 0; i < num_ddfw; ++i) { - ddfw* d = alloc(ddfw); + ddfw_wrapper* d = alloc(ddfw_wrapper); d->updt_params(m_params); d->set_seed(m_config.m_random_seed + i); d->add(*this); @@ -2932,6 +2932,7 @@ namespace sat { bool_var v = m_trail[i].var(); m_best_phase[v] = m_phase[v]; } + set_has_new_best_phase(true); } } diff --git a/src/sat/sat_solver.h b/src/sat/sat_solver.h index 57477f686cf..657c92178bd 100644 --- a/src/sat/sat_solver.h +++ b/src/sat/sat_solver.h @@ -152,6 +152,7 @@ namespace sat { bool_vector m_phase; bool_vector m_best_phase; bool_vector m_prev_phase; + bool m_new_best_phase = false; svector m_assigned_since_gc; search_state m_search_state; unsigned m_search_unsat_conflicts; @@ -228,7 +229,7 @@ namespace sat { friend class parallel; friend class lookahead; friend class local_search; - friend class ddfw; + friend class ddfw_wrapper; friend class prob; friend class unit_walk; friend struct mk_stat; @@ -380,6 +381,9 @@ namespace sat { bool was_eliminated(literal l) const { return was_eliminated(l.var()); } void set_phase(literal l) override { if (l.var() < num_vars()) m_best_phase[l.var()] = m_phase[l.var()] = !l.sign(); } bool get_phase(bool_var b) { return m_phase.get(b, false); } + bool get_best_phase(bool_var b) { return m_best_phase.get(b, false); } + void set_has_new_best_phase(bool b) { m_new_best_phase = b; } + bool has_new_best_phase() const { return m_new_best_phase; } void move_to_front(bool_var b); unsigned scope_lvl() const { return m_scope_lvl; } unsigned search_lvl() const { return m_search_lvl; } diff --git a/src/sat/sat_solver/inc_sat_solver.cpp b/src/sat/sat_solver/inc_sat_solver.cpp index 522638c1d3d..9b77b648276 100644 --- a/src/sat/sat_solver/inc_sat_solver.cpp +++ b/src/sat/sat_solver/inc_sat_solver.cpp @@ -39,7 +39,7 @@ Module Name: #include "model/model_v2_pp.h" #include "model/model_evaluator.h" #include "sat/sat_solver.h" -#include "sat/sat_params.hpp" +#include "params/sat_params.hpp" #include "sat/smt/euf_solver.h" #include "sat/tactic/goal2sat.h" #include "sat/tactic/sat2goal.h" diff --git a/src/sat/sat_solver/sat_smt_solver.cpp b/src/sat/sat_solver/sat_smt_solver.cpp index 1c141a80115..e21b0996ff8 100644 --- a/src/sat/sat_solver/sat_smt_solver.cpp +++ b/src/sat/sat_solver/sat_smt_solver.cpp @@ -33,7 +33,7 @@ Module Name: #include "model/model_evaluator.h" #include "sat/sat_solver.h" #include "solver/simplifier_solver.h" -#include "sat/sat_params.hpp" +#include "params/sat_params.hpp" #include "sat/smt/euf_solver.h" #include "sat/tactic/goal2sat.h" #include "sat/tactic/sat2goal.h" @@ -586,8 +586,9 @@ class sat_smt_solver : public solver { void add_assumption(expr* a) { init_goal2sat(); - m_dep.insert(a, m_goal2sat.internalize(a)); - get_euf()->add_assertion(a); + auto lit = m_goal2sat.internalize(a); + m_dep.insert(a, lit); + get_euf()->add_clause(1, &lit); } void internalize_assumptions(expr_ref_vector const& asms) { diff --git a/src/sat/smt/CMakeLists.txt b/src/sat/smt/CMakeLists.txt index 7747b65cb90..1e21935016f 100644 --- a/src/sat/smt/CMakeLists.txt +++ b/src/sat/smt/CMakeLists.txt @@ -3,7 +3,6 @@ z3_add_component(sat_smt arith_axioms.cpp arith_diagnostics.cpp arith_internalize.cpp - arith_sls.cpp arith_solver.cpp arith_value.cpp array_axioms.cpp @@ -22,7 +21,6 @@ z3_add_component(sat_smt euf_ackerman.cpp euf_internalize.cpp euf_invariant.cpp - euf_local_search.cpp euf_model.cpp euf_proof.cpp euf_proof_checker.cpp diff --git a/src/sat/smt/arith_axioms.cpp b/src/sat/smt/arith_axioms.cpp index 59529658b75..d9bb3aa90e5 100644 --- a/src/sat/smt/arith_axioms.cpp +++ b/src/sat/smt/arith_axioms.cpp @@ -547,6 +547,7 @@ namespace arith { } void solver::new_diseq_eh(euf::th_eq const& e) { + TRACE("artih", tout << mk_bounded_pp(e.eq(), m) << "\n"); ensure_column(e.v1()); ensure_column(e.v2()); m_delayed_eqs.push_back(std::make_pair(e, false)); diff --git a/src/sat/smt/arith_sls.cpp b/src/sat/smt/arith_sls.cpp deleted file mode 100644 index 2168299803b..00000000000 --- a/src/sat/smt/arith_sls.cpp +++ /dev/null @@ -1,642 +0,0 @@ -/*++ -Copyright (c) 2023 Microsoft Corporation - -Module Name: - - arith_local_search.cpp - -Abstract: - - Local search dispatch for SMT - -Author: - - Nikolaj Bjorner (nbjorner) 2023-02-07 - ---*/ -#include "sat/sat_solver.h" -#include "sat/smt/arith_solver.h" - - -namespace arith { - - sls::sls(solver& s): - s(s), m(s.m) {} - - void sls::reset() { - m_bool_vars.reset(); - m_vars.reset(); - m_terms.reset(); - } - - void sls::save_best_values() { - for (unsigned v = 0; v < s.get_num_vars(); ++v) - m_vars[v].m_best_value = m_vars[v].m_value; - check_ineqs(); - if (unsat().size() == 1) { - auto idx = *unsat().begin(); - verbose_stream() << idx << "\n"; - auto const& c = *m_bool_search->m_clauses[idx].m_clause; - verbose_stream() << c << "\n"; - for (auto lit : c) { - bool_var bv = lit.var(); - ineq* i = atom(bv); - if (i) - verbose_stream() << lit << ": " << *i << "\n"; - } - verbose_stream() << "\n"; - } - } - - void sls::store_best_values() { - // first compute assignment to terms - // then update non-basic variables in tableau. - - if (!unsat().empty()) - return; - - for (auto const& [t,v] : m_terms) { - int64_t val = 0; - lp::lar_term const& term = s.lp().get_term(t); - for (lp::lar_term::ival const& arg : term) { - auto t2 = arg.j(); - auto w = s.lp().local_to_external(t2); - val += to_numeral(arg.coeff()) * m_vars[w].m_best_value; - } - m_vars[v].m_best_value = val; - } - - for (unsigned v = 0; v < s.get_num_vars(); ++v) { - if (s.is_bool(v)) - continue; - if (!s.lp().external_is_used(v)) - continue; - int64_t new_value = m_vars[v].m_best_value; - s.ensure_column(v); - lp::lpvar vj = s.lp().external_to_local(v); - SASSERT(vj != lp::null_lpvar); - if (!s.lp().is_base(vj)) { - rational new_value_(new_value, rational::i64()); - lp::impq val(new_value_, rational::zero()); - s.lp().set_value_for_nbasic_column(vj, val); - } - } - - lbool r = s.make_feasible(); - VERIFY (!unsat().empty() || r == l_true); -#if 0 - if (unsat().empty()) - s.m_num_conflicts = s.get_config().m_arith_propagation_threshold; -#endif - - auto check_bool_var = [&](sat::bool_var bv) { - auto* ineq = m_bool_vars.get(bv, nullptr); - if (!ineq) - return; - api_bound* b = nullptr; - s.m_bool_var2bound.find(bv, b); - if (!b) - return; - auto bound = b->get_value(); - theory_var v = b->get_var(); - if (s.get_phase(bv) == m_bool_search->get_model()[bv]) - return; - switch (b->get_bound_kind()) { - case lp_api::lower_t: - verbose_stream() << "v" << v << " " << bound << " <= " << s.get_value(v) << " " << m_vars[v].m_best_value << "\n"; - break; - case lp_api::upper_t: - verbose_stream() << "v" << v << " " << bound << " >= " << s.get_value(v) << " " << m_vars[v].m_best_value << "\n"; - break; - } - int64_t value = 0; - for (auto const& [coeff, v] : ineq->m_args) { - value += coeff * m_vars[v].m_best_value; - } - ineq->m_args_value = value; - verbose_stream() << *ineq << " dtt " << dtt(false, *ineq) << " phase " << s.get_phase(bv) << " model " << m_bool_search->get_model()[bv] << "\n"; - for (auto const& [coeff, v] : ineq->m_args) - verbose_stream() << "v" << v << " := " << m_vars[v].m_best_value << "\n"; - s.display(verbose_stream()); - display(verbose_stream()); - UNREACHABLE(); - exit(0); - }; - - if (unsat().empty()) { - for (bool_var v = 0; v < s.s().num_vars(); ++v) - check_bool_var(v); - } - } - - void sls::set(sat::ddfw* d) { - m_bool_search = d; - reset(); - m_bool_vars.reserve(s.s().num_vars()); - add_vars(); - for (unsigned i = 0; i < d->num_clauses(); ++i) - for (sat::literal lit : *d->get_clause_info(i).m_clause) - init_bool_var(lit.var()); - for (unsigned v = 0; v < s.s().num_vars(); ++v) - init_bool_var_assignment(v); - - d->set(this); - } - - // distance to true - int64_t sls::dtt(bool sign, int64_t args, ineq const& ineq) const { - switch (ineq.m_op) { - case ineq_kind::LE: - if (sign) { - if (args <= ineq.m_bound) - return ineq.m_bound - args + 1; - return 0; - } - if (args <= ineq.m_bound) - return 0; - return args - ineq.m_bound; - case ineq_kind::EQ: - if (sign) { - if (args == ineq.m_bound) - return 1; - return 0; - } - if (args == ineq.m_bound) - return 0; - return 1; - case ineq_kind::NE: - if (sign) { - if (args == ineq.m_bound) - return 0; - return 1; - } - if (args == ineq.m_bound) - return 1; - return 0; - case ineq_kind::LT: - if (sign) { - if (args < ineq.m_bound) - return ineq.m_bound - args; - return 0; - } - if (args < ineq.m_bound) - return 0; - return args - ineq.m_bound + 1; - default: - UNREACHABLE(); - return 0; - } - } - - // - // dtt is high overhead. It walks ineq.m_args - // m_vars[w].m_value can be computed outside and shared among calls - // different data-structures for storing coefficients - // - int64_t sls::dtt(bool sign, ineq const& ineq, var_t v, int64_t new_value) const { - for (auto const& [coeff, w] : ineq.m_args) - if (w == v) - return dtt(sign, ineq.m_args_value + coeff * (new_value - m_vars[v].m_value), ineq); - return 1; - } - - int64_t sls::dtt(bool sign, ineq const& ineq, int64_t coeff, int64_t old_value, int64_t new_value) const { - return dtt(sign, ineq.m_args_value + coeff * (new_value - old_value), ineq); - } - - bool sls::cm(bool old_sign, ineq const& ineq, var_t v, int64_t& new_value) { - for (auto const& [coeff, w] : ineq.m_args) - if (w == v) - return cm(old_sign, ineq, v, coeff, new_value); - return false; - } - - bool sls::cm(bool old_sign, ineq const& ineq, var_t v, int64_t coeff, int64_t& new_value) { - SASSERT(ineq.is_true() != old_sign); - VERIFY(ineq.is_true() != old_sign); - auto bound = ineq.m_bound; - auto argsv = ineq.m_args_value; - bool solved = false; - int64_t delta = argsv - bound; - auto make_eq = [&]() { - SASSERT(delta != 0); - if (delta < 0) - new_value = value(v) + (abs(delta) + abs(coeff) - 1) / coeff; - else - new_value = value(v) - (delta + abs(coeff) - 1) / coeff; - solved = argsv + coeff * (new_value - value(v)) == bound; - if (!solved && abs(coeff) == 1) { - verbose_stream() << "did not solve equality " << ineq << " for " << v << "\n"; - verbose_stream() << new_value << " " << value(v) << " delta " << delta << " lhs " << (argsv + coeff * (new_value - value(v))) << " bound " << bound << "\n"; - UNREACHABLE(); - } - return solved; - }; - - auto make_diseq = [&]() { - if (delta >= 0) - delta++; - else - delta--; - new_value = value(v) + (abs(delta) + abs(coeff) - 1) / coeff; - VERIFY(argsv + coeff * (new_value - value(v)) != bound); - return true; - }; - - if (!old_sign) { - switch (ineq.m_op) { - case ineq_kind::LE: - // args <= bound -> args > bound - SASSERT(argsv <= bound); - SASSERT(delta <= 0); - --delta; - new_value = value(v) + (abs(delta) + abs(coeff) - 1) / coeff; - VERIFY(argsv + coeff * (new_value - value(v)) > bound); - return true; - case ineq_kind::LT: - // args < bound -> args >= bound - SASSERT(argsv <= ineq.m_bound); - SASSERT(delta <= 0); - new_value = value(v) + (abs(delta) + abs(coeff) - 1) / coeff; - VERIFY(argsv + coeff * (new_value - value(v)) >= bound); - return true; - case ineq_kind::EQ: - return make_diseq(); - case ineq_kind::NE: - return make_eq(); - default: - UNREACHABLE(); - break; - } - } - else { - switch (ineq.m_op) { - case ineq_kind::LE: - SASSERT(argsv > ineq.m_bound); - SASSERT(delta > 0); - new_value = value(v) - (delta + abs(coeff) - 1) / coeff; - VERIFY(argsv + coeff * (new_value - value(v)) <= bound); - return true; - case ineq_kind::LT: - SASSERT(argsv >= ineq.m_bound); - SASSERT(delta >= 0); - ++delta; - new_value = value(v) - (abs(delta) + abs(coeff) - 1) / coeff; - VERIFY(argsv + coeff * (new_value - value(v)) < bound); - return true; - case ineq_kind::NE: - return make_diseq(); - case ineq_kind::EQ: - return make_eq(); - default: - UNREACHABLE(); - break; - } - } - return false; - } - - // flip on the first positive score - // it could be changed to flip on maximal positive score - // or flip on maximal non-negative score - // or flip on first non-negative score - bool sls::flip(bool sign, ineq const& ineq) { - int64_t new_value; - auto v = ineq.m_var_to_flip; - if (v == UINT_MAX) { - IF_VERBOSE(1, verbose_stream() << "no var to flip\n"); - return false; - } - if (!cm(sign, ineq, v, new_value)) { - verbose_stream() << "no critical move for " << v << "\n"; - return false; - } - update(v, new_value); - return true; - } - - // - // dscore(op) = sum_c (dts(c,alpha) - dts(c,alpha_after)) * weight(c) - // TODO - use cached dts instead of computed dts - // cached dts has to be updated when the score of literals are updated. - // - double sls::dscore(var_t v, int64_t new_value) const { - double score = 0; - auto const& vi = m_vars[v]; - for (auto const& [coeff, bv] : vi.m_bool_vars) { - sat::literal lit(bv, false); - for (auto cl : m_bool_search->get_use_list(lit)) - score += (compute_dts(cl) - dts(cl, v, new_value)) * m_bool_search->get_weight(cl); - for (auto cl : m_bool_search->get_use_list(~lit)) - score += (compute_dts(cl) - dts(cl, v, new_value)) * m_bool_search->get_weight(cl); - } - return score; - } - - // - // cm_score is costly. It involves several cache misses. - // Note that - // - m_bool_search->get_use_list(lit).size() is "often" 1 or 2 - // - dtt_old can be saved - // - int sls::cm_score(var_t v, int64_t new_value) { - int score = 0; - auto& vi = m_vars[v]; - int64_t old_value = vi.m_value; - for (auto const& [coeff, bv] : vi.m_bool_vars) { - auto const& ineq = *atom(bv); - bool old_sign = sign(bv); - int64_t dtt_old = dtt(old_sign, ineq); - int64_t dtt_new = dtt(old_sign, ineq, coeff, old_value, new_value); - if ((dtt_old == 0) == (dtt_new == 0)) - continue; - sat::literal lit(bv, old_sign); - if (dtt_old == 0) - // flip from true to false - lit.neg(); - - // lit flips form false to true: - for (auto cl : m_bool_search->get_use_list(lit)) { - auto const& clause = get_clause_info(cl); - if (!clause.is_true()) - ++score; - } - // ignore the situation where clause contains multiple literals using v - for (auto cl : m_bool_search->get_use_list(~lit)) { - auto const& clause = get_clause_info(cl); - if (clause.m_num_trues == 1) - --score; - } - } - return score; - } - - int64_t sls::compute_dts(unsigned cl) const { - int64_t d(1), d2; - bool first = true; - for (auto a : get_clause(cl)) { - auto const* ineq = atom(a.var()); - if (!ineq) - continue; - d2 = dtt(a.sign(), *ineq); - if (first) - d = d2, first = false; - else - d = std::min(d, d2); - if (d == 0) - break; - } - return d; - } - - int64_t sls::dts(unsigned cl, var_t v, int64_t new_value) const { - int64_t d(1), d2; - bool first = true; - for (auto lit : get_clause(cl)) { - auto const* ineq = atom(lit.var()); - if (!ineq) - continue; - d2 = dtt(lit.sign(), *ineq, v, new_value); - if (first) - d = d2, first = false; - else - d = std::min(d, d2); - if (d == 0) - break; - } - return d; - } - - void sls::update(var_t v, int64_t new_value) { - auto& vi = m_vars[v]; - auto old_value = vi.m_value; - for (auto const& [coeff, bv] : vi.m_bool_vars) { - auto& ineq = *atom(bv); - bool old_sign = sign(bv); - sat::literal lit(bv, old_sign); - SASSERT(is_true(lit)); - ineq.m_args_value += coeff * (new_value - old_value); - int64_t dtt_new = dtt(old_sign, ineq); - if (dtt_new != 0) - m_bool_search->flip(bv); - SASSERT(dtt(sign(bv), ineq) == 0); - } - vi.m_value = new_value; - } - - void sls::add_vars() { - SASSERT(m_vars.empty()); - for (unsigned v = 0; v < s.get_num_vars(); ++v) { - int64_t value = s.is_registered_var(v) ? to_numeral(s.get_ivalue(v).x) : 0; - auto k = s.is_int(v) ? sls::var_kind::INT : sls::var_kind::REAL; - m_vars.push_back({ value, value, k, {} }); - } - } - - sls::ineq& sls::new_ineq(ineq_kind op, int64_t const& bound) { - auto* i = alloc(ineq); - i->m_bound = bound; - i->m_op = op; - return *i; - } - - void sls::add_arg(sat::bool_var bv, ineq& ineq, int64_t const& c, var_t v) { - ineq.m_args.push_back({ c, v }); - ineq.m_args_value += c * value(v); - m_vars[v].m_bool_vars.push_back({ c, bv}); - } - - int64_t sls::to_numeral(rational const& r) { - if (r.is_int64()) - return r.get_int64(); - return 0; - } - - void sls::add_args(sat::bool_var bv, ineq& ineq, lp::lpvar t, theory_var v, int64_t sign) { - if (s.lp().column_has_term(t)) { - lp::lar_term const& term = s.lp().get_term(t); - m_terms.push_back({t,v}); - for (lp::lar_term::ival arg : term) { - auto t2 = arg.j(); - auto w = s.lp().local_to_external(t2); - add_arg(bv, ineq, sign * to_numeral(arg.coeff()), w); - } - } - else - add_arg(bv, ineq, sign, s.lp().local_to_external(t)); - } - - void sls::init_bool_var(sat::bool_var bv) { - if (m_bool_vars.get(bv, nullptr)) - return; - api_bound* b = nullptr; - s.m_bool_var2bound.find(bv, b); - if (b) { - auto t = b->column_index(); - rational bound = b->get_value(); - bool should_minus = false; - sls::ineq_kind op; - should_minus = b->get_bound_kind() == lp_api::bound_kind::lower_t; - op = sls::ineq_kind::LE; - if (should_minus) - bound.neg(); - - auto& ineq = new_ineq(op, to_numeral(bound)); - - - add_args(bv, ineq, t, b->get_var(), should_minus ? -1 : 1); - m_bool_vars.set(bv, &ineq); - m_bool_search->set_external(bv); - return; - } - - expr* e = s.bool_var2expr(bv); - expr* l = nullptr, * r = nullptr; - if (e && m.is_eq(e, l, r) && s.a.is_int_real(l)) { - theory_var u = s.get_th_var(l); - theory_var v = s.get_th_var(r); - lp::lpvar tu = s.get_column(u); - lp::lpvar tv = s.get_column(v); - auto& ineq = new_ineq(sls::ineq_kind::EQ, 0); - add_args(bv, ineq, tu, u, 1); - add_args(bv, ineq, tv, v, -1); - m_bool_vars.set(bv, &ineq); - m_bool_search->set_external(bv); - return; - } - } - - void sls::init_bool_var_assignment(sat::bool_var v) { - auto* ineq = m_bool_vars.get(v, nullptr); - if (ineq && is_true(sat::literal(v, false)) != (dtt(false, *ineq) == 0)) - m_bool_search->flip(v); - } - - void sls::init_search() { - on_restart(); - } - - void sls::finish_search() { - store_best_values(); - } - - void sls::flip(sat::bool_var v) { - sat::literal lit(v, !sign(v)); - SASSERT(!is_true(lit)); - auto const* ineq = atom(v); - if (!ineq) - IF_VERBOSE(0, verbose_stream() << "no inequality for variable " << v << "\n"); - if (!ineq) - return; - SASSERT(ineq->is_true() == lit.sign()); - flip(sign(v), *ineq); - } - - double sls::reward(sat::bool_var v) { - if (m_dscore_mode) - return dscore_reward(v); - else - return dtt_reward(v); - } - - double sls::dtt_reward(sat::bool_var bv0) { - bool sign0 = sign(bv0); - auto* ineq = atom(bv0); - if (!ineq) - return -1; - int64_t new_value; - double max_result = -1; - for (auto const & [coeff, x] : ineq->m_args) { - if (!cm(sign0, *ineq, x, coeff, new_value)) - continue; - double result = 0; - auto old_value = m_vars[x].m_value; - for (auto const& [coeff, bv] : m_vars[x].m_bool_vars) { - result += m_bool_search->reward(bv); - continue; - bool old_sign = sign(bv); - auto dtt_old = dtt(old_sign, *atom(bv)); - auto dtt_new = dtt(old_sign, *atom(bv), coeff, old_value, new_value); - if ((dtt_new == 0) != (dtt_old == 0)) - result += m_bool_search->reward(bv); - } - if (result > max_result) { - max_result = result; - ineq->m_var_to_flip = x; - } - } - return max_result; - } - - double sls::dscore_reward(sat::bool_var bv) { - m_dscore_mode = false; - bool old_sign = sign(bv); - sat::literal litv(bv, old_sign); - auto* ineq = atom(bv); - if (!ineq) - return 0; - SASSERT(ineq->is_true() != old_sign); - int64_t new_value; - - for (auto const& [coeff, v] : ineq->m_args) { - double result = 0; - if (cm(old_sign, *ineq, v, coeff, new_value)) - result = dscore(v, new_value); - // just pick first positive, or pick a max? - if (result > 0) { - ineq->m_var_to_flip = v; - return result; - } - } - return 0; - } - - // switch to dscore mode - void sls::on_rescale() { - m_dscore_mode = true; - } - - void sls::on_save_model() { - save_best_values(); - } - - void sls::on_restart() { - for (unsigned v = 0; v < s.s().num_vars(); ++v) - init_bool_var_assignment(v); - - check_ineqs(); - } - - void sls::check_ineqs() { - - auto check_bool_var = [&](sat::bool_var bv) { - auto const* ineq = atom(bv); - if (!ineq) - return; - int64_t d = dtt(sign(bv), *ineq); - sat::literal lit(bv, sign(bv)); - if (is_true(lit) != (d == 0)) { - verbose_stream() << "invalid assignment " << bv << " " << *ineq << "\n"; - } - VERIFY(is_true(lit) == (d == 0)); - }; - for (unsigned v = 0; v < s.get_num_vars(); ++v) - check_bool_var(v); - } - - std::ostream& sls::display(std::ostream& out) const { - for (bool_var bv = 0; bv < s.s().num_vars(); ++bv) { - auto const* ineq = atom(bv); - if (!ineq) - continue; - out << bv << " " << *ineq << "\n"; - } - for (unsigned v = 0; v < s.get_num_vars(); ++v) { - if (s.is_bool(v)) - continue; - out << "v" << v << " := " << m_vars[v].m_value << " " << m_vars[v].m_best_value << "\n"; - } - return out; - } - -} diff --git a/src/sat/smt/arith_sls.h b/src/sat/smt/arith_sls.h deleted file mode 100644 index a65ca686d70..00000000000 --- a/src/sat/smt/arith_sls.h +++ /dev/null @@ -1,169 +0,0 @@ -/*++ -Copyright (c) 2020 Microsoft Corporation - -Module Name: - - arith_local_search.h - -Abstract: - - Theory plugin for arithmetic local search - -Author: - - Nikolaj Bjorner (nbjorner) 2020-09-08 - ---*/ -#pragma once - -#include "util/obj_pair_set.h" -#include "ast/ast_trail.h" -#include "ast/arith_decl_plugin.h" -#include "math/lp/indexed_value.h" -#include "math/lp/lar_solver.h" -#include "math/lp/nla_solver.h" -#include "math/lp/lp_types.h" -#include "math/lp/lp_api.h" -#include "math/polynomial/algebraic_numbers.h" -#include "math/polynomial/polynomial.h" -#include "sat/smt/sat_th.h" -#include "sat/sat_ddfw.h" - -namespace arith { - - class solver; - - // local search portion for arithmetic - class sls : public sat::local_search_plugin { - enum class ineq_kind { EQ, LE, LT, NE }; - enum class var_kind { INT, REAL }; - typedef unsigned var_t; - typedef unsigned atom_t; - - struct config { - double cb = 0.0; - unsigned L = 20; - unsigned t = 45; - unsigned max_no_improve = 500000; - double sp = 0.0003; - }; - - struct stats { - unsigned m_num_flips = 0; - }; - - public: - // encode args <= bound, args = bound, args < bound - struct ineq { - vector> m_args; - ineq_kind m_op = ineq_kind::LE; - int64_t m_bound; - int64_t m_args_value; - unsigned m_var_to_flip = UINT_MAX; - - bool is_true() const { - switch (m_op) { - case ineq_kind::LE: - return m_args_value <= m_bound; - case ineq_kind::EQ: - return m_args_value == m_bound; - case ineq_kind::NE: - return m_args_value != m_bound; - default: - return m_args_value < m_bound; - } - } - std::ostream& display(std::ostream& out) const { - bool first = true; - for (auto const& [c, v] : m_args) - out << (first ? "" : " + ") << c << " * v" << v, first = false; - switch (m_op) { - case ineq_kind::LE: - return out << " <= " << m_bound << "(" << m_args_value << ")"; - case ineq_kind::EQ: - return out << " == " << m_bound << "(" << m_args_value << ")"; - case ineq_kind::NE: - return out << " != " << m_bound << "(" << m_args_value << ")"; - default: - return out << " < " << m_bound << "(" << m_args_value << ")"; - } - } - }; - private: - - struct var_info { - int64_t m_value; - int64_t m_best_value; - var_kind m_kind = var_kind::INT; - svector> m_bool_vars; - }; - - solver& s; - ast_manager& m; - sat::ddfw* m_bool_search = nullptr; - stats m_stats; - config m_config; - scoped_ptr_vector m_bool_vars; - vector m_vars; - svector> m_terms; - bool m_dscore_mode = false; - - - indexed_uint_set& unsat() { return m_bool_search->unsat_set(); } - unsigned num_clauses() const { return m_bool_search->num_clauses(); } - sat::clause& get_clause(unsigned idx) { return *get_clause_info(idx).m_clause; } - sat::clause const& get_clause(unsigned idx) const { return *get_clause_info(idx).m_clause; } - sat::ddfw::clause_info& get_clause_info(unsigned idx) { return m_bool_search->get_clause_info(idx); } - sat::ddfw::clause_info const& get_clause_info(unsigned idx) const { return m_bool_search->get_clause_info(idx); } - bool is_true(sat::literal lit) { return lit.sign() != m_bool_search->get_value(lit.var()); } - bool sign(sat::bool_var v) const { return !m_bool_search->get_value(v); } - - void reset(); - ineq* atom(sat::bool_var bv) const { return m_bool_vars[bv]; } - - bool flip(bool sign, ineq const& ineq); - int64_t dtt(bool sign, ineq const& ineq) const { return dtt(sign, ineq.m_args_value, ineq); } - int64_t dtt(bool sign, int64_t args_value, ineq const& ineq) const; - int64_t dtt(bool sign, ineq const& ineq, var_t v, int64_t new_value) const; - int64_t dtt(bool sign, ineq const& ineq, int64_t coeff, int64_t old_value, int64_t new_value) const; - int64_t dts(unsigned cl, var_t v, int64_t new_value) const; - int64_t compute_dts(unsigned cl) const; - bool cm(bool sign, ineq const& ineq, var_t v, int64_t& new_value); - bool cm(bool sign, ineq const& ineq, var_t v, int64_t coeff, int64_t& new_value); - int cm_score(var_t v, int64_t new_value); - void update(var_t v, int64_t new_value); - double dscore_reward(sat::bool_var v); - double dtt_reward(sat::bool_var v); - double dscore(var_t v, int64_t new_value) const; - void save_best_values(); - void store_best_values(); - void add_vars(); - sls::ineq& new_ineq(ineq_kind op, int64_t const& bound); - void add_arg(sat::bool_var bv, ineq& ineq, int64_t const& c, var_t v); - void add_args(sat::bool_var bv, ineq& ineq, lp::lpvar j, euf::theory_var v, int64_t sign); - void init_bool_var(sat::bool_var v); - void init_bool_var_assignment(sat::bool_var v); - - int64_t value(var_t v) const { return m_vars[v].m_value; } - int64_t to_numeral(rational const& r); - - void check_ineqs(); - - std::ostream& display(std::ostream& out) const; - - public: - sls(solver& s); - void set(sat::ddfw* d); - void init_search() override; - void finish_search() override; - void flip(sat::bool_var v) override; - double reward(sat::bool_var v) override; - void on_rescale() override; - void on_save_model() override; - void on_restart() override; - }; - - inline std::ostream& operator<<(std::ostream& out, sls::ineq const& ineq) { - return ineq.display(out); - } -} diff --git a/src/sat/smt/arith_solver.cpp b/src/sat/smt/arith_solver.cpp index 3086d75f43f..2e5ee58b867 100644 --- a/src/sat/smt/arith_solver.cpp +++ b/src/sat/smt/arith_solver.cpp @@ -24,7 +24,6 @@ namespace arith { solver::solver(euf::solver& ctx, theory_id id) : th_euf_solver(ctx, symbol("arith"), id), m_model_eqs(DEFAULT_HASHTABLE_INITIAL_CAPACITY, var_value_hash(*this), var_value_eq(*this)), - m_local_search(*this), m_resource_limit(*this), m_bp(*this, m_implied_bounds), a(m), @@ -988,10 +987,10 @@ namespace arith { } bool solver::use_nra_model() { - return m_nla && m_nla->use_nra_model(); + return m_nla && m_use_nra_model && m_nla->use_nra_model(); } - bool solver::is_eq(theory_var v1, theory_var v2) { + bool solver::is_eq(theory_var v1, theory_var v2) { if (use_nra_model()) { return m_nla->am().eq(nl_value(v1, m_nla->tmp1()), nl_value(v2, m_nla->tmp2())); } @@ -1006,6 +1005,8 @@ namespace arith { IF_VERBOSE(12, verbose_stream() << "final-check " << lp().get_status() << "\n"); SASSERT(lp().ax_is_correct()); + m_use_nra_model = false; + if (!lp().is_feasible() || lp().has_changed_columns()) { switch (make_feasible()) { case l_false: @@ -1038,8 +1039,12 @@ namespace arith { break; } + if (!check_delayed_eqs()) + return sat::check_result::CR_CONTINUE; + switch (check_nla()) { case l_true: + m_use_nra_model = true; break; case l_false: return sat::check_result::CR_CONTINUE; @@ -1053,7 +1058,8 @@ namespace arith { ++m_stats.m_assume_eqs; return sat::check_result::CR_CONTINUE; } - if (!check_delayed_eqs()) + + if (!check_delayed_eqs()) return sat::check_result::CR_CONTINUE; if (!int_undef && !check_bv_terms()) @@ -1141,6 +1147,7 @@ namespace arith { new_eq_eh(e); else if (is_eq(e.v1(), e.v2())) { mk_diseq_axiom(e.v1(), e.v2()); + TRACE("arith", tout << mk_bounded_pp(e.eq(), m) << " " << use_nra_model() << "\n"); found_diseq = true; break; } @@ -1249,9 +1256,9 @@ namespace arith { for (auto ev : m_explanation) set_evidence(ev.ci()); - TRACE("arith", + TRACE("arith_conflict", tout << "Lemma - " << (is_conflict ? "conflict" : "propagation") << "\n"; - for (literal c : m_core) tout << c << ": " << literal2expr(c) << "\n"; + for (literal c : m_core) tout << c << ": " << literal2expr(c) << " := " << s().value(c) << "\n"; for (auto p : m_eqs) tout << ctx.bpp(p.first) << " == " << ctx.bpp(p.second) << "\n";); if (ctx.get_config().m_arith_validate) @@ -1271,6 +1278,10 @@ namespace arith { m_core.push_back(ctx.mk_literal(m.mk_eq(eq.first->get_expr(), eq.second->get_expr()))); for (literal& c : m_core) c.neg(); + + // it is possible if multiple lemmas are added at the same time. + if (any_of(m_core, [&](literal c) { return s().value(c) == l_true; })) + return; add_redundant(m_core, explain(ty)); } @@ -1508,6 +1519,7 @@ namespace arith { case l_undef: break; } + TRACE("arith", tout << "nla " << r << "\n"); return r; } @@ -1521,10 +1533,13 @@ namespace arith { } for (auto const& ineq : m_nla->literals()) { auto lit = mk_ineq_literal(ineq); + if (s().value(lit) == l_true) + continue; ctx.mark_relevant(lit); s().set_phase(lit); + // verbose_stream() << lit << ":= " << s().value(lit) << "\n"; // force trichotomy axiom for equality literals - if (ineq.cmp() == lp::EQ) { + if (ineq.cmp() == lp::EQ && false) { nla::lemma l; l.push_back(ineq); l.push_back(nla::ineq(lp::LT, ineq.term(), ineq.rs())); diff --git a/src/sat/smt/arith_solver.h b/src/sat/smt/arith_solver.h index d4850e67d72..3c4a588902e 100644 --- a/src/sat/smt/arith_solver.h +++ b/src/sat/smt/arith_solver.h @@ -28,8 +28,6 @@ Module Name: #include "math/polynomial/algebraic_numbers.h" #include "math/polynomial/polynomial.h" #include "sat/smt/sat_th.h" -#include "sat/smt/arith_sls.h" -#include "sat/sat_ddfw.h" namespace euf { class solver; @@ -186,8 +184,6 @@ namespace arith { coeffs().pop_back(); } }; - - sls m_local_search; typedef vector> var_coeffs; vector m_columns; @@ -234,6 +230,7 @@ namespace arith { // non-linear arithmetic scoped_ptr m_nla; + bool m_use_nra_model = false; // integer arithmetic scoped_ptr m_lia; @@ -513,8 +510,6 @@ namespace arith { bool enable_ackerman_axioms(euf::enode* n) const override { return !a.is_add(n->get_expr()); } bool has_unhandled() const override { return m_not_handled != nullptr; } - void set_bool_search(sat::ddfw* ddfw) override { m_local_search.set(ddfw); } - // bounds and equality propagation callbacks lp::lar_solver& lp() { return *m_solver; } lp::lar_solver const& lp() const { return *m_solver; } diff --git a/src/sat/smt/bv_ackerman.cpp b/src/sat/smt/bv_ackerman.cpp index 940b1ebb027..a709c16f3a3 100644 --- a/src/sat/smt/bv_ackerman.cpp +++ b/src/sat/smt/bv_ackerman.cpp @@ -118,8 +118,8 @@ namespace bv { } } - if (glue < max_glue) - v.m_glue = (sz > 6 && 2*glue <= sz) ? 0 : glue; + if (glue < max_glue) + v.m_glue = glue; // (sz > 6 && 2 * glue <= sz) ? 0 : glue; } void ackerman::remove(vv* p) { diff --git a/src/sat/smt/euf_internalize.cpp b/src/sat/smt/euf_internalize.cpp index ebb6e4b85fb..602364e7d25 100644 --- a/src/sat/smt/euf_internalize.cpp +++ b/src/sat/smt/euf_internalize.cpp @@ -525,8 +525,8 @@ namespace euf { return n; } - void solver::add_assertion(expr* f) { - m_assertions.push_back(f); - m_trail.push(push_back_vector(m_assertions)); + void solver::add_clause(unsigned n, sat::literal const* lits) { + m_top_level_clauses.push_back(sat::literal_vector(n, lits)); + m_trail.push(push_back_vector(m_top_level_clauses)); } } diff --git a/src/sat/smt/euf_local_search.cpp b/src/sat/smt/euf_local_search.cpp deleted file mode 100644 index ca450e513e3..00000000000 --- a/src/sat/smt/euf_local_search.cpp +++ /dev/null @@ -1,50 +0,0 @@ -/*++ -Copyright (c) 2020 Microsoft Corporation - -Module Name: - - euf_local_search.cpp - -Abstract: - - Local search dispatch for SMT - -Author: - - Nikolaj Bjorner (nbjorner) 2023-02-07 - ---*/ -#include "sat/sat_solver.h" -#include "sat/sat_ddfw.h" -#include "sat/smt/euf_solver.h" - - -namespace euf { - - lbool solver::local_search(bool_vector& phase) { - scoped_limits scoped_rl(m.limit()); - sat::ddfw bool_search; - bool_search.reinit(s(), phase); - bool_search.updt_params(s().params()); - bool_search.set_seed(rand()); - scoped_rl.push_child(&(bool_search.rlimit())); - - for (auto* th : m_solvers) - th->set_bool_search(&bool_search); - - bool_search.check(0, nullptr, nullptr); - - auto const& mdl = bool_search.get_model(); - for (unsigned i = 0; i < mdl.size(); ++i) - phase[i] = mdl[i] == l_true; - - if (bool_search.unsat_set().empty()) { - enable_trace("arith"); - enable_trace("sat"); - enable_trace("euf"); - TRACE("sat", s().display(tout)); - } - - return bool_search.unsat_set().empty() ? l_true : l_undef; - } -} diff --git a/src/sat/smt/euf_proof_checker.cpp b/src/sat/smt/euf_proof_checker.cpp index c001ee90f94..4351f65813a 100644 --- a/src/sat/smt/euf_proof_checker.cpp +++ b/src/sat/smt/euf_proof_checker.cpp @@ -21,7 +21,7 @@ Module Name: #include "ast/ast_ll_pp.h" #include "ast/arith_decl_plugin.h" #include "smt/smt_solver.h" -#include "sat/sat_params.hpp" +#include "params/sat_params.hpp" #include "sat/smt/euf_proof_checker.h" #include "sat/smt/arith_theory_checker.h" #include "sat/smt/q_theory_checker.h" diff --git a/src/sat/smt/euf_solver.cpp b/src/sat/smt/euf_solver.cpp index b443d836e8b..fbb1025b7d9 100644 --- a/src/sat/smt/euf_solver.cpp +++ b/src/sat/smt/euf_solver.cpp @@ -55,7 +55,6 @@ namespace euf { m_smt_proof_checker(m, p), m_clause(m), m_expr_args(m), - m_assertions(m), m_values(m) { updt_params(p); diff --git a/src/sat/smt/euf_solver.h b/src/sat/smt/euf_solver.h index 650208183c0..06ee45df0cb 100644 --- a/src/sat/smt/euf_solver.h +++ b/src/sat/smt/euf_solver.h @@ -100,15 +100,6 @@ namespace euf { scope(unsigned l) : m_var_lim(l) {} }; - struct local_search_config { - double cb = 0.0; - unsigned L = 20; - unsigned t = 45; - unsigned max_no_improve = 500000; - double sp = 0.0003; - }; - - size_t* to_ptr(sat::literal l) { return TAG(size_t*, reinterpret_cast((size_t)(l.index() << 4)), 1); } size_t* to_ptr(size_t jst) { return TAG(size_t*, reinterpret_cast(jst), 2); } bool is_literal(size_t* p) const { return GET_TAG(p) == 1; } @@ -127,7 +118,6 @@ namespace euf { sat::sat_internalizer& si; relevancy m_relevancy; smt_params m_config; - local_search_config m_ls_config; euf::egraph m_egraph; trail_stack m_trail; stats m_stats; @@ -174,7 +164,7 @@ namespace euf { symbol m_smt = symbol("smt"); expr_ref_vector m_clause; expr_ref_vector m_expr_args; - expr_ref_vector m_assertions; + vector m_top_level_clauses; // internalization @@ -356,7 +346,6 @@ namespace euf { void add_assumptions(sat::literal_set& assumptions) override; bool tracking_assumptions() override; std::string reason_unknown() override { return m_reason_unknown; } - lbool local_search(bool_vector& phase) override; void propagate(literal lit, ext_justification_idx idx); bool propagate(enode* a, enode* b, ext_justification_idx idx); @@ -485,8 +474,10 @@ namespace euf { bool enable_ackerman_axioms(expr* n) const; bool is_fixed(euf::enode* n, expr_ref& val, sat::literal_vector& explain); - void add_assertion(expr* f); - expr_ref_vector const& get_assertions() { return m_assertions; } + // void add_assertion(expr* f); + // expr_ref_vector const& get_assertions() { return m_assertions; } + void add_clause(unsigned n, sat::literal const* lits); + vector const& top_level_clauses() const { return m_top_level_clauses; } model_ref get_sls_model(); // relevancy diff --git a/src/sat/smt/intblast_solver.cpp b/src/sat/smt/intblast_solver.cpp index d6dfe57d7c1..caca1e48ce8 100644 --- a/src/sat/smt/intblast_solver.cpp +++ b/src/sat/smt/intblast_solver.cpp @@ -22,6 +22,10 @@ Module Name: namespace intblast { + void translator_trail::push(push_back_vector const& c) { ctx.push(c); } + void translator_trail::push(push_back_vector> const& c) { ctx.push(c); } + void translator_trail::push_idx(set_vector_idx_trail const& c) { ctx.push(c); } + solver::solver(euf::solver& ctx) : th_euf_solver(ctx, symbol("intblast"), ctx.get_manager().get_family_id("bv")), ctx(ctx), @@ -29,9 +33,8 @@ namespace intblast { m(ctx.get_manager()), bv(m), a(m), - m_translate(m), - m_args(m), - m_pinned(m) + trail(ctx), + m_translator(m, trail) {} euf::theory_var solver::mk_var(euf::enode* n) { @@ -85,49 +88,22 @@ namespace intblast { SASSERT(!n->is_attached_to(get_id())); mk_var(n); SASSERT(n->is_attached_to(get_id())); - internalize_bv(a); + m_translator.internalize_bv(a); return true; } void solver::eq_internalized(euf::enode* n) { - expr* e = n->get_expr(); - expr* x = nullptr, * y = nullptr; - VERIFY(m.is_eq(n->get_expr(), x, y)); - SASSERT(bv.is_bv(x)); - if (!is_translated(e)) { - ensure_translated(x); - ensure_translated(y); - m_args.reset(); - m_args.push_back(a.mk_sub(translated(x), translated(y))); - set_translated(e, m.mk_eq(umod(x, 0), a.mk_int(0))); - } - m_preds.push_back(e); - TRACE("bv", tout << mk_pp(e, m) << " " << mk_pp(translated(e), m) << "\n"); - ctx.push(push_back_vector(m_preds)); - } - - void solver::set_translated(expr* e, expr* r) { - SASSERT(r); - SASSERT(!is_translated(e)); - m_translate.setx(e->get_id(), r); - ctx.push(set_vector_idx_trail(m_translate, e->get_id())); - } - - void solver::internalize_bv(app* e) { - ensure_translated(e); - if (m.is_bool(e)) { - m_preds.push_back(e); - ctx.push(push_back_vector(m_preds)); - } + m_translator.translate_eq(n->get_expr()); } bool solver::add_bound_axioms() { - if (m_vars_qhead == m_vars.size()) + auto const& vars = m_translator.vars(); + if (m_vars_qhead == vars.size()) return false; ctx.push(value_trail(m_vars_qhead)); - for (; m_vars_qhead < m_vars.size(); ++m_vars_qhead) { - auto v = m_vars[m_vars_qhead]; - auto w = translated(v); + for (; m_vars_qhead < vars.size(); ++m_vars_qhead) { + auto v = vars[m_vars_qhead]; + auto w = m_translator.translated(v); auto sz = rational::power_of_two(bv.get_bv_size(v->get_sort())); auto lo = ctx.mk_literal(a.mk_ge(w, a.mk_int(0))); auto hi = ctx.mk_literal(a.mk_le(w, a.mk_int(sz - 1))); @@ -140,12 +116,13 @@ namespace intblast { } bool solver::add_predicate_axioms() { - if (m_preds_qhead == m_preds.size()) + auto const& preds = m_translator.preds(); + if (m_preds_qhead == preds.size()) return false; ctx.push(value_trail(m_preds_qhead)); - for (; m_preds_qhead < m_preds.size(); ++m_preds_qhead) { - expr* e = m_preds[m_preds_qhead]; - expr_ref r(translated(e), m); + for (; m_preds_qhead < preds.size(); ++m_preds_qhead) { + expr* e = preds[m_preds_qhead]; + expr_ref r(m_translator.translated(e), m); ctx.get_rewriter()(r); auto a = expr2literal(e); auto b = mk_literal(r); @@ -158,31 +135,7 @@ namespace intblast { bool solver::unit_propagate() { return add_bound_axioms() || add_predicate_axioms(); - } - void solver::ensure_translated(expr* e) { - if (m_translate.get(e->get_id(), nullptr)) - return; - ptr_vector todo; - ast_fast_mark1 visited; - todo.push_back(e); - visited.mark(e); - for (unsigned i = 0; i < todo.size(); ++i) { - expr* e = todo[i]; - if (!is_app(e)) - continue; - app* a = to_app(e); - if (m.is_bool(e) && a->get_family_id() != bv.get_family_id()) - continue; - for (auto arg : *a) - if (!visited.is_marked(arg) && !m_translate.get(arg->get_id(), nullptr)) { - visited.mark(arg); - todo.push_back(arg); - } - } - std::stable_sort(todo.begin(), todo.end(), [&](expr* a, expr* b) { return get_depth(a) < get_depth(b); }); - for (expr* e : todo) - translate_expr(e); - } + } lbool solver::check_axiom(sat::literal_vector const& lits) { sat::literal_vector core; @@ -198,14 +151,10 @@ namespace intblast { } lbool solver::check_core(sat::literal_vector const& lits, euf::enode_pair_vector const& eqs) { - m_core.reset(); - m_vars.reset(); m_is_plugin = false; + m_translator.reset(false); m_solver = mk_smt2_solver(m, s.params(), symbol::null); - for (unsigned i = 0; i < m_translate.size(); ++i) - m_translate[i] = nullptr; - expr_ref_vector es(m), original_es(m); for (auto lit : lits) es.push_back(ctx.literal2expr(lit)); @@ -222,8 +171,8 @@ namespace intblast { translate(es); - for (auto e : m_vars) { - auto v = translated(e); + for (auto e : m_translator.vars()) { + auto v = m_translator.translated(e); auto b = rational::power_of_two(bv.get_bv_size(e)); m_solver->assert_expr(a.mk_le(a.mk_int(0), v)); m_solver->assert_expr(a.mk_lt(v, a.mk_int(b))); @@ -331,8 +280,8 @@ namespace intblast { translate(es); - for (auto e : m_vars) { - auto v = translated(e); + for (auto e : m_translator.vars()) { + auto v = m_translator.translated(e); auto b = rational::power_of_two(bv.get_bv_size(e)); m_solver->assert_expr(a.mk_le(a.mk_int(0), v)); m_solver->assert_expr(a.mk_lt(v, a.mk_int(b))); @@ -377,7 +326,7 @@ namespace intblast { void solver::sorted_subterms(expr_ref_vector& es, ptr_vector& sorted) { expr_fast_mark1 visited; for (expr* e : es) { - if (is_translated(e)) + if (m_translator.is_translated(e)) continue; if (visited.is_marked(e)) continue; @@ -389,7 +338,7 @@ namespace intblast { if (is_app(e)) { app* a = to_app(e); for (expr* arg : *a) { - if (!visited.is_marked(arg) && !is_translated(arg)) { + if (!visited.is_marked(arg) && !m_translator.is_translated(arg)) { visited.mark(arg); sorted.push_back(arg); } @@ -399,7 +348,7 @@ namespace intblast { else if (is_quantifier(e)) { quantifier* q = to_quantifier(e); expr* b = q->get_expr(); - if (!visited.is_marked(b) && !is_translated(b)) { + if (!visited.is_marked(b) && !m_translator.is_translated(b)) { visited.mark(b); sorted.push_back(b); } @@ -414,20 +363,20 @@ namespace intblast { sorted_subterms(es, todo); for (expr* e : todo) - translate_expr(e); + m_translator.translate_expr(e); TRACE("bv", for (expr* e : es) - tout << mk_pp(e, m) << "\n->\n" << mk_pp(translated(e), m) << "\n"; + tout << mk_pp(e, m) << "\n->\n" << mk_pp(m_translator.translated(e), m) << "\n"; ); for (unsigned i = 0; i < es.size(); ++i) - es[i] = translated(es.get(i)); + es[i] = m_translator.translated(es.get(i)); } sat::check_result solver::check() { // ensure that bv2int is injective - for (auto e : m_bv2int) { + for (auto e : m_translator.bv2int()) { euf::enode* n = expr2enode(e); euf::enode* r1 = n->get_arg(0)->get_root(); for (auto sib : euf::enode_class(n)) { @@ -449,7 +398,7 @@ namespace intblast { } // ensure that int2bv respects values // bv2int(int2bv(x)) = x mod N - for (auto e : m_int2bv) { + for (auto e : m_translator.int2bv()) { auto n = expr2enode(e); auto x = n->get_arg(0)->get_expr(); auto bv2int = bv.mk_bv2int(e); @@ -469,595 +418,12 @@ namespace intblast { return sat::check_result::CR_DONE; } - bool solver::is_bounded(expr* x, rational const& N) { - return any_of(m_vars, [&](expr* v) { - return is_translated(v) && translated(v) == x && bv_size(v) <= N; - }); - } - - bool solver::is_non_negative(expr* bv_expr, expr* e) { - auto N = rational::power_of_two(bv.get_bv_size(bv_expr)); - rational r; - if (a.is_numeral(e, r)) - return r >= 0; - if (is_bounded(e, N)) - return true; - expr* x = nullptr, * y = nullptr; - if (a.is_mul(e, x, y)) - return is_non_negative(bv_expr, x) && is_non_negative(bv_expr, y); - if (a.is_add(e, x, y)) - return is_non_negative(bv_expr, x) && is_non_negative(bv_expr, y); - return false; - } - - expr* solver::umod(expr* bv_expr, unsigned i) { - expr* x = arg(i); - rational N = bv_size(bv_expr); - return amod(bv_expr, x, N); - } - - expr* solver::smod(expr* bv_expr, unsigned i) { - expr* x = arg(i); - auto N = bv_size(bv_expr); - auto shift = N / 2; - rational r; - if (a.is_numeral(x, r)) - return a.mk_int(mod(r + shift, N)); - return amod(bv_expr, add(x, a.mk_int(shift)), N); - } - - expr_ref solver::mul(expr* x, expr* y) { - expr_ref _x(x, m), _y(y, m); - if (a.is_zero(x)) - return _x; - if (a.is_zero(y)) - return _y; - if (a.is_one(x)) - return _y; - if (a.is_one(y)) - return _x; - rational v1, v2; - if (a.is_numeral(x, v1) && a.is_numeral(y, v2)) - return expr_ref(a.mk_int(v1 * v2), m); - _x = a.mk_mul(x, y); - return _x; - } - - expr_ref solver::add(expr* x, expr* y) { - expr_ref _x(x, m), _y(y, m); - if (a.is_zero(x)) - return _y; - if (a.is_zero(y)) - return _x; - rational v1, v2; - if (a.is_numeral(x, v1) && a.is_numeral(y, v2)) - return expr_ref(a.mk_int(v1 + v2), m); - _x = a.mk_add(x, y); - return _x; - } - - /* - * Perform simplifications that are claimed sound when the bit-vector interpretations of - * mod/div always guard the mod and dividend to be non-zero. - * Potentially shady area is for arithmetic expressions created by int2bv. - * They will be guarded by a modulus which does not disappear. - */ - expr* solver::amod(expr* bv_expr, expr* x, rational const& N) { - rational v; - expr* r = nullptr, *c = nullptr, * t = nullptr, * e = nullptr; - if (m.is_ite(x, c, t, e)) - r = m.mk_ite(c, amod(bv_expr, t, N), amod(bv_expr, e, N)); - else if (a.is_idiv(x, t, e) && a.is_numeral(t, v) && 0 <= v && v < N && is_non_negative(bv_expr, e)) - r = x; - else if (a.is_mod(x, t, e) && a.is_numeral(t, v) && 0 <= v && v < N) - r = x; - else if (a.is_numeral(x, v)) - r = a.mk_int(mod(v, N)); - else if (is_bounded(x, N)) - r = x; - else - r = a.mk_mod(x, a.mk_int(N)); - return r; - } - - rational solver::bv_size(expr* bv_expr) { - return rational::power_of_two(bv.get_bv_size(bv_expr->get_sort())); - } - - void solver::translate_expr(expr* e) { - if (is_quantifier(e)) - translate_quantifier(to_quantifier(e)); - else if (is_var(e)) - translate_var(to_var(e)); - else { - app* ap = to_app(e); - if (m_is_plugin && ap->get_family_id() == basic_family_id && m.is_bool(ap)) { - set_translated(e, e); - return; - } - m_args.reset(); - for (auto arg : *ap) - m_args.push_back(translated(arg)); - - if (ap->get_family_id() == basic_family_id) - translate_basic(ap); - else if (ap->get_family_id() == bv.get_family_id()) - translate_bv(ap); - else - translate_app(ap); - } - } - - void solver::translate_quantifier(quantifier* q) { - if (m_is_plugin) { - set_translated(q, q); - return; - } - if (is_lambda(q)) - throw default_exception("lambdas are not supported in intblaster"); - expr* b = q->get_expr(); - unsigned nd = q->get_num_decls(); - ptr_vector sorts; - for (unsigned i = 0; i < nd; ++i) { - auto s = q->get_decl_sort(i); - if (bv.is_bv_sort(s)) { - NOT_IMPLEMENTED_YET(); - sorts.push_back(a.mk_int()); - } - else - sorts.push_back(s); - } - b = translated(b); - // TODO if sorts contain integer, then created bounds variables. - set_translated(q, m.update_quantifier(q, b)); - } - - void solver::translate_var(var* v) { - if (bv.is_bv_sort(v->get_sort())) - set_translated(v, m.mk_var(v->get_idx(), a.mk_int())); - else - set_translated(v, v); - } - - // Translate functions that are not built-in or bit-vectors. - // Base method uses fresh functions. - // Other method could use bv2int, int2bv axioms and coercions. - // f(args) = bv2int(f(int2bv(args')) - // - - void solver::translate_app(app* e) { - - if (m_is_plugin && m.is_bool(e)) { - set_translated(e, e); - return; - } - - bool has_bv_sort = bv.is_bv(e); - func_decl* f = e->get_decl(); - - for (unsigned i = 0; i < m_args.size(); ++i) - if (bv.is_bv(e->get_arg(i))) - m_args[i] = bv.mk_int2bv(bv.get_bv_size(e->get_arg(i)), m_args.get(i)); - - if (has_bv_sort) - m_vars.push_back(e); - if (m_is_plugin) { - expr* r = m.mk_app(f, m_args); - if (has_bv_sort) { - ctx.push(push_back_vector(m_vars)); - r = bv.mk_bv2int(r); - } - set_translated(e, r); - return; - } - else if (has_bv_sort) { - if (f->get_family_id() != null_family_id) - throw default_exception("conversion for interpreted functions is not supported by intblast solver"); - func_decl* g = nullptr; - if (!m_new_funs.find(f, g)) { - g = m.mk_fresh_func_decl(e->get_decl()->get_name(), symbol("bv"), f->get_arity(), f->get_domain(), a.mk_int()); - m_new_funs.insert(f, g); - } - f = g; - m_pinned.push_back(f); - } - set_translated(e, m.mk_app(f, m_args)); - } - - void solver::translate_bv(app* e) { - - auto bnot = [&](expr* e) { - return a.mk_sub(a.mk_int(-1), e); - }; - - auto band = [&](expr_ref_vector const& args) { - expr* r = arg(0); - for (unsigned i = 1; i < args.size(); ++i) - r = a.mk_band(bv.get_bv_size(e), r, arg(i)); - return r; - }; - - auto rotate_left = [&](unsigned n) { - auto sz = bv.get_bv_size(e); - n = n % sz; - expr* r = arg(0); - if (n != 0 && sz != 1) { - // r[sz - n - 1 : 0] ++ r[sz - 1 : sz - n] - // r * 2^(sz - n) + (r div 2^n) mod 2^(sz - n)??? - // r * A + (r div B) mod A - auto N = bv_size(e); - auto A = rational::power_of_two(sz - n); - auto B = rational::power_of_two(n); - auto hi = mul(r, a.mk_int(A)); - auto lo = amod(e, a.mk_idiv(umod(e, 0), a.mk_int(B)), A); - r = add(hi, lo); - } - return r; - }; - - expr* bv_expr = e; - expr_ref r(m); - auto const& args = m_args; - switch (e->get_decl_kind()) { - case OP_BADD: - r = a.mk_add(args); - break; - case OP_BSUB: - r = a.mk_sub(args.size(), args.data()); - break; - case OP_BMUL: - r = a.mk_mul(args); - break; - case OP_ULEQ: - bv_expr = e->get_arg(0); - r = a.mk_le(umod(bv_expr, 0), umod(bv_expr, 1)); - break; - case OP_UGEQ: - bv_expr = e->get_arg(0); - r = a.mk_ge(umod(bv_expr, 0), umod(bv_expr, 1)); - break; - case OP_ULT: - bv_expr = e->get_arg(0); - r = a.mk_lt(umod(bv_expr, 0), umod(bv_expr, 1)); - break; - case OP_UGT: - bv_expr = e->get_arg(0); - r = a.mk_gt(umod(bv_expr, 0), umod(bv_expr, 1)); - break; - case OP_SLEQ: - bv_expr = e->get_arg(0); - r = a.mk_le(smod(bv_expr, 0), smod(bv_expr, 1)); - break; - case OP_SGEQ: - bv_expr = e->get_arg(0); - r = a.mk_ge(smod(bv_expr, 0), smod(bv_expr, 1)); - break; - case OP_SLT: - bv_expr = e->get_arg(0); - r = a.mk_lt(smod(bv_expr, 0), smod(bv_expr, 1)); - break; - case OP_SGT: - bv_expr = e->get_arg(0); - r = a.mk_gt(smod(bv_expr, 0), smod(bv_expr, 1)); - break; - case OP_BNEG: - r = a.mk_uminus(arg(0)); - break; - case OP_CONCAT: { - unsigned sz = 0; - expr_ref new_arg(m); - for (unsigned i = args.size(); i-- > 0;) { - expr* old_arg = e->get_arg(i); - new_arg = umod(old_arg, i); - if (sz > 0) { - new_arg = mul(new_arg, a.mk_int(rational::power_of_two(sz))); - r = add(r, new_arg); - } - else - r = new_arg; - sz += bv.get_bv_size(old_arg->get_sort()); - } - break; - } - case OP_EXTRACT: { - unsigned lo, hi; - expr* old_arg; - VERIFY(bv.is_extract(e, lo, hi, old_arg)); - r = arg(0); - if (lo > 0) - r = a.mk_idiv(r, a.mk_int(rational::power_of_two(lo))); - break; - } - case OP_BV_NUM: { - rational val; - unsigned sz; - VERIFY(bv.is_numeral(e, val, sz)); - r = a.mk_int(val); - break; - } - case OP_BUREM: - case OP_BUREM_I: { - expr* x = umod(e, 0), * y = umod(e, 1); - r = if_eq(y, 0, x, a.mk_mod(x, y)); - break; - } - case OP_BUDIV: - case OP_BUDIV_I: { - expr* x = umod(e, 0), * y = umod(e, 1); - r = if_eq(y, 0, a.mk_int(-1), a.mk_idiv(x, y)); - break; - } - case OP_BUMUL_NO_OVFL: { - bv_expr = e->get_arg(0); - r = a.mk_lt(mul(umod(bv_expr, 0), umod(bv_expr, 1)), a.mk_int(bv_size(bv_expr))); - break; - } - case OP_BSHL: { - if (!a.is_numeral(arg(0)) && !a.is_numeral(arg(1))) - r = a.mk_shl(bv.get_bv_size(e), arg(0),arg(1)); - else { - expr* x = arg(0), * y = umod(e, 1); - r = a.mk_int(0); - IF_VERBOSE(2, verbose_stream() << "shl " << mk_bounded_pp(e, m) << " " << bv.get_bv_size(e) << "\n"); - for (unsigned i = 0; i < bv.get_bv_size(e); ++i) - r = if_eq(y, i, mul(x, a.mk_int(rational::power_of_two(i))), r); - } - break; - } - case OP_BNOT: - r = bnot(arg(0)); - break; - case OP_BLSHR: - if (!a.is_numeral(arg(0)) && !a.is_numeral(arg(1))) - r = a.mk_lshr(bv.get_bv_size(e), arg(0), arg(1)); - else { - expr* x = arg(0), * y = umod(e, 1); - r = a.mk_int(0); - IF_VERBOSE(2, verbose_stream() << "lshr " << mk_bounded_pp(e, m) << " " << bv.get_bv_size(e) << "\n"); - for (unsigned i = 0; i < bv.get_bv_size(e); ++i) - r = if_eq(y, i, a.mk_idiv(x, a.mk_int(rational::power_of_two(i))), r); - } - break; - case OP_BASHR: - if (!a.is_numeral(arg(1))) - r = a.mk_ashr(bv.get_bv_size(e), arg(0), arg(1)); - else { - - // - // ashr(x, y) - // if y = k & x >= 0 -> x / 2^k - // if y = k & x < 0 -> (x / 2^k) - 2^{N-k} - // - unsigned sz = bv.get_bv_size(e); - rational N = bv_size(e); - expr* x = umod(e, 0), *y = umod(e, 1); - expr* signx = a.mk_ge(x, a.mk_int(N / 2)); - r = m.mk_ite(signx, a.mk_int(- 1), a.mk_int(0)); - IF_VERBOSE(1, verbose_stream() << "ashr " << mk_bounded_pp(e, m) << " " << bv.get_bv_size(e) << "\n"); - for (unsigned i = 0; i < sz; ++i) { - expr* d = a.mk_idiv(x, a.mk_int(rational::power_of_two(i))); - r = if_eq(y, i, - m.mk_ite(signx, add(d, a.mk_int(- rational::power_of_two(sz-i))), d), - r); - } - } - break; - case OP_BOR: - // p | q := (p + q) - band(p, q) - IF_VERBOSE(2, verbose_stream() << "bor " << mk_bounded_pp(e, m) << " " << bv.get_bv_size(e) << "\n"); - r = arg(0); - for (unsigned i = 1; i < args.size(); ++i) - r = a.mk_sub(add(r, arg(i)), a.mk_band(bv.get_bv_size(e), r, arg(i))); - break; - case OP_BNAND: - r = bnot(band(args)); - break; - case OP_BAND: - IF_VERBOSE(2, verbose_stream() << "band " << mk_bounded_pp(e, m) << " " << bv.get_bv_size(e) << "\n"); - r = band(args); - break; - case OP_BXNOR: - case OP_BXOR: { - // p ^ q := (p + q) - 2*band(p, q); - unsigned sz = bv.get_bv_size(e); - IF_VERBOSE(2, verbose_stream() << "bxor " << bv.get_bv_size(e) << "\n"); - r = arg(0); - for (unsigned i = 1; i < args.size(); ++i) { - expr* q = arg(i); - r = a.mk_sub(add(r, q), mul(a.mk_int(2), a.mk_band(sz, r, q))); - } - if (e->get_decl_kind() == OP_BXNOR) - r = bnot(r); - break; - } - case OP_ZERO_EXT: - bv_expr = e->get_arg(0); - r = umod(bv_expr, 0); - SASSERT(bv.get_bv_size(e) >= bv.get_bv_size(bv_expr)); - break; - case OP_SIGN_EXT: { - bv_expr = e->get_arg(0); - r = umod(bv_expr, 0); - SASSERT(bv.get_bv_size(e) >= bv.get_bv_size(bv_expr)); - unsigned arg_sz = bv.get_bv_size(bv_expr); - //unsigned sz = bv.get_bv_size(e); - // rational N = rational::power_of_two(sz); - rational M = rational::power_of_two(arg_sz); - expr* signbit = a.mk_ge(r, a.mk_int(M / 2)); - r = m.mk_ite(signbit, a.mk_sub(r, a.mk_int(M)), r); - break; - } - case OP_INT2BV: - m_int2bv.push_back(e); - ctx.push(push_back_vector(m_int2bv)); - r = arg(0); - break; - case OP_BV2INT: - m_bv2int.push_back(e); - ctx.push(push_back_vector(m_bv2int)); - r = umod(e->get_arg(0), 0); - break; - case OP_BCOMP: - bv_expr = e->get_arg(0); - r = m.mk_ite(m.mk_eq(umod(bv_expr, 0), umod(bv_expr, 1)), a.mk_int(1), a.mk_int(0)); - break; - case OP_BSMOD_I: - case OP_BSMOD: { - expr* x = umod(e, 0), *y = umod(e, 1); - rational N = bv_size(e); - expr* signx = a.mk_ge(x, a.mk_int(N/2)); - expr* signy = a.mk_ge(y, a.mk_int(N/2)); - expr* u = a.mk_mod(x, y); - // u = 0 -> 0 - // y = 0 -> x - // x < 0, y < 0 -> -u - // x < 0, y >= 0 -> y - u - // x >= 0, y < 0 -> y + u - // x >= 0, y >= 0 -> u - r = a.mk_uminus(u); - r = m.mk_ite(m.mk_and(m.mk_not(signx), signy), add(u, y), r); - r = m.mk_ite(m.mk_and(signx, m.mk_not(signy)), a.mk_sub(y, u), r); - r = m.mk_ite(m.mk_and(m.mk_not(signx), m.mk_not(signy)), u, r); - r = if_eq(u, 0, a.mk_int(0), r); - r = if_eq(y, 0, x, r); - break; - } - case OP_BSDIV_I: - case OP_BSDIV: { - // d = udiv(abs(x), abs(y)) - // y = 0, x > 0 -> 1 - // y = 0, x <= 0 -> -1 - // x = 0, y != 0 -> 0 - // x > 0, y < 0 -> -d - // x < 0, y > 0 -> -d - // x > 0, y > 0 -> d - // x < 0, y < 0 -> d - expr* x = umod(e, 0), * y = umod(e, 1); - rational N = bv_size(e); - expr* signx = a.mk_ge(x, a.mk_int(N / 2)); - expr* signy = a.mk_ge(y, a.mk_int(N / 2)); - x = m.mk_ite(signx, a.mk_sub(a.mk_int(N), x), x); - y = m.mk_ite(signy, a.mk_sub(a.mk_int(N), y), y); - expr* d = a.mk_idiv(x, y); - r = m.mk_ite(m.mk_iff(signx, signy), d, a.mk_uminus(d)); - r = if_eq(y, 0, m.mk_ite(signx, a.mk_int(1), a.mk_int(-1)), r); - break; - } - case OP_BSREM_I: - case OP_BSREM: { - // y = 0 -> x - // else x - sdiv(x, y) * y - expr* x = umod(e, 0), * y = umod(e, 1); - rational N = bv_size(e); - expr* signx = a.mk_ge(x, a.mk_int(N / 2)); - expr* signy = a.mk_ge(y, a.mk_int(N / 2)); - expr* absx = m.mk_ite(signx, a.mk_sub(a.mk_int(N), x), x); - expr* absy = m.mk_ite(signy, a.mk_sub(a.mk_int(N), y), y); - expr* d = a.mk_idiv(absx, absy); - d = m.mk_ite(m.mk_iff(signx, signy), d, a.mk_uminus(d)); - r = a.mk_sub(x, mul(d, y)); - r = if_eq(y, 0, x, r); - break; - } - case OP_ROTATE_LEFT: { - auto n = e->get_parameter(0).get_int(); - r = rotate_left(n); - break; - } - case OP_ROTATE_RIGHT: { - unsigned sz = bv.get_bv_size(e); - auto n = e->get_parameter(0).get_int(); - r = rotate_left(sz - n); - break; - } - case OP_EXT_ROTATE_LEFT: { - unsigned sz = bv.get_bv_size(e); - expr* y = umod(e, 1); - r = a.mk_int(0); - for (unsigned i = 0; i < sz; ++i) - r = if_eq(y, i, rotate_left(i), r); - break; - } - case OP_EXT_ROTATE_RIGHT: { - unsigned sz = bv.get_bv_size(e); - expr* y = umod(e, 1); - r = a.mk_int(0); - for (unsigned i = 0; i < sz; ++i) - r = if_eq(y, i, rotate_left(sz - i), r); - break; - } - case OP_REPEAT: { - unsigned n = e->get_parameter(0).get_int(); - expr* x = umod(e->get_arg(0), 0); - r = x; - rational N = bv_size(e->get_arg(0)); - rational N0 = N; - for (unsigned i = 1; i < n; ++i) - r = add(mul(a.mk_int(N), x), r), N *= N0; - break; - } - case OP_BREDOR: { - r = umod(e->get_arg(0), 0); - r = m.mk_not(m.mk_eq(r, a.mk_int(0))); - break; - } - case OP_BREDAND: { - rational N = bv_size(e->get_arg(0)); - r = umod(e->get_arg(0), 0); - r = m.mk_not(m.mk_eq(r, a.mk_int(N - 1))); - break; - } - default: - verbose_stream() << mk_pp(e, m) << "\n"; - NOT_IMPLEMENTED_YET(); - } - set_translated(e, r); - } - - expr_ref solver::if_eq(expr* n, unsigned k, expr* th, expr* el) { - rational r; - expr_ref _th(th, m), _el(el, m); - if (bv.is_numeral(n, r)) { - if (r == k) - return expr_ref(th, m); - else - return expr_ref(el, m); - } - return expr_ref(m.mk_ite(m.mk_eq(n, a.mk_int(k)), th, el), m); - } - - void solver::translate_basic(app* e) { - if (m.is_eq(e)) { - bool has_bv_arg = any_of(*e, [&](expr* arg) { return bv.is_bv(arg); }); - if (has_bv_arg) { - expr* bv_expr = e->get_arg(0); - rational N = rational::power_of_two(bv.get_bv_size(bv_expr)); - if (a.is_numeral(arg(0)) || a.is_numeral(arg(1)) || - is_bounded(arg(0), N) || is_bounded(arg(1), N)) { - set_translated(e, m.mk_eq(umod(bv_expr, 0), umod(bv_expr, 1))); - } - else { - m_args[0] = a.mk_sub(arg(0), arg(1)); - set_translated(e, m.mk_eq(umod(bv_expr, 0), a.mk_int(0))); - } - } - else - set_translated(e, m.mk_eq(arg(0), arg(1))); - } - else if (m.is_ite(e)) - set_translated(e, m.mk_ite(arg(0), arg(1), arg(2))); - else if (m_is_plugin) - set_translated(e, e); - else - set_translated(e, m.mk_app(e->get_decl(), m_args)); - } - rational solver::get_value(expr* e) const { SASSERT(bv.is_bv(e)); model_ref mdl; m_solver->get_model(mdl); expr_ref r(m); - r = translated(e); + r = m_translator.translated(e); rational val; if (!mdl->eval_expr(r, r, true)) return rational::zero(); @@ -1099,7 +465,7 @@ namespace intblast { } rational r, N = rational::power_of_two(bv.get_bv_size(e)); - expr* te = translated(e); + expr* te = m_translator.translated(e); model_ref mdlr; m_solver->get_model(mdlr); expr_ref value(m); @@ -1126,14 +492,12 @@ namespace intblast { else { expr_ref bv2int(bv.mk_bv2int(n->get_expr()), m); euf::enode* b2i = ctx.get_enode(bv2int); - if (!b2i) verbose_stream() << bv2int << "\n"; SASSERT(b2i); VERIFY(b2i); arith::arith_value av(ctx); rational r; VERIFY(av.get_value(b2i->get_expr(), r)); value = bv.mk_numeral(r, bv.get_bv_size(n->get_expr())); - verbose_stream() << ctx.bpp(n) << " := " << value << "\n"; } values.set(n->get_root_id(), value); TRACE("model", tout << "add_value " << ctx.bpp(n) << " := " << value << "\n"); @@ -1143,11 +507,11 @@ namespace intblast { return; for (auto n : ctx.get_egraph().nodes()) { auto e = n->get_expr(); - if (!is_translated(e)) + if (!m_translator.is_translated(e)) continue; if (!bv.is_bv(e)) continue; - auto t = translated(e); + auto t = m_translator.translated(e); expr_ref ei(bv.mk_bv2int(e), m); expr_ref ti(a.mk_mod(t, a.mk_int(rational::power_of_two(bv.get_bv_size(e)))), m); diff --git a/src/sat/smt/intblast_solver.h b/src/sat/smt/intblast_solver.h index ee8d6fb1935..d840e389f34 100644 --- a/src/sat/smt/intblast_solver.h +++ b/src/sat/smt/intblast_solver.h @@ -38,6 +38,7 @@ Module Name: #include "solver/solver.h" #include "sat/smt/sat_th.h" #include "util/statistics.h" +#include "ast/rewriter/bv2int_translator.h" namespace euf { class solver; @@ -45,18 +46,31 @@ namespace euf { namespace intblast { + class translator_trail : public bv2int_translator_trail { + euf::solver& ctx; + public: + translator_trail(euf::solver& ctx):ctx(ctx) {} + void push(push_back_vector const& c) override; + void push(push_back_vector> const& c) override; + void push_idx(set_vector_idx_trail const& c) override; + }; + class solver : public euf::th_euf_solver { euf::solver& ctx; sat::solver& s; ast_manager& m; bv_util bv; arith_util a; + translator_trail trail; + bv2int_translator m_translator; + scoped_ptr<::solver> m_solver; - obj_map m_new_funs; - expr_ref_vector m_translate, m_args; - ast_ref_vector m_pinned; + + //obj_map m_new_funs; + //expr_ref_vector m_translate, m_args; + //ast_ref_vector m_pinned; sat::literal_vector m_core; - ptr_vector m_bv2int, m_int2bv; + // ptr_vector m_bv2int, m_int2bv; statistics m_stats; bool m_is_plugin = true; // when the solver is used as a plugin, then do not translate below quantifiers. @@ -66,33 +80,6 @@ namespace intblast { - bool is_translated(expr* e) const { return !!m_translate.get(e->get_id(), nullptr); } - expr* translated(expr* e) const { expr* r = m_translate.get(e->get_id(), nullptr); SASSERT(r); return r; } - void set_translated(expr* e, expr* r); - expr* arg(unsigned i) { return m_args.get(i); } - - expr* umod(expr* bv_expr, unsigned i); - expr* smod(expr* bv_expr, unsigned i); - bool is_bounded(expr* v, rational const& N); - bool is_non_negative(expr* bv_expr, expr* e); - expr_ref mul(expr* x, expr* y); - expr_ref add(expr* x, expr* y); - expr_ref if_eq(expr* n, unsigned k, expr* th, expr* el); - expr* amod(expr* bv_expr, expr* x, rational const& N); - rational bv_size(expr* bv_expr); - - void translate_expr(expr* e); - void translate_bv(app* e); - void translate_basic(app* e); - void translate_app(app* e); - void translate_quantifier(quantifier* q); - void translate_var(var* v); - - void ensure_translated(expr* e); - void internalize_bv(app* e); - - unsigned m_vars_qhead = 0, m_preds_qhead = 0; - ptr_vector m_vars, m_preds; bool add_bound_axioms(); bool add_predicate_axioms(); @@ -101,6 +88,9 @@ namespace intblast { void add_value_plugin(euf::enode* n, model& mdl, expr_ref_vector& values); void add_value_solver(euf::enode* n, model& mdl, expr_ref_vector& values); + unsigned m_vars_qhead = 0, m_preds_qhead = 0; + + public: solver(euf::solver& ctx); diff --git a/src/sat/smt/sat_th.h b/src/sat/smt/sat_th.h index 24226eede6f..8228f1ce4af 100644 --- a/src/sat/smt/sat_th.h +++ b/src/sat/smt/sat_th.h @@ -18,7 +18,6 @@ Module Name: #include "util/top_sort.h" #include "sat/smt/sat_smt.h" -#include "sat/sat_ddfw.h" #include "ast/euf/euf_egraph.h" #include "model/model.h" #include "smt/params/smt_params.h" @@ -139,10 +138,6 @@ namespace euf { virtual euf::enode_pair get_justification_eq(size_t j); - /** - * Local search interface - */ - virtual void set_bool_search(sat::ddfw* ddfw) {} virtual void set_bounds_begin() {} diff --git a/src/sat/smt/sls_solver.cpp b/src/sat/smt/sls_solver.cpp index a507619ee32..0eb01ff8518 100644 --- a/src/sat/smt/sls_solver.cpp +++ b/src/sat/smt/sls_solver.cpp @@ -13,154 +13,118 @@ Module Name: Nikolaj Bjorner (nbjorner) 2024-02-21 + --*/ #include "sat/smt/sls_solver.h" #include "sat/smt/euf_solver.h" - - +#include "ast/sls/sls_context.h" +#include "ast/for_each_expr.h" namespace sls { -#ifdef SINGLE_THREAD solver::solver(euf::solver& ctx) : th_euf_solver(ctx, symbol("sls"), ctx.get_manager().mk_family_id("sls")) - {} + {} + +#ifdef SINGLE_THREAD #else - solver::solver(euf::solver& ctx): - th_euf_solver(ctx, symbol("sls"), ctx.get_manager().mk_family_id("sls")) - {} solver::~solver() { finalize(); } - void solver::finalize() { - if (!m_completed && m_sls) { - m_sls->cancel(); - m_thread.join(); - m_sls->collect_statistics(m_st); - m_sls = nullptr; - m_shared = nullptr; - m_slsm = nullptr; - m_units = nullptr; - } + params_ref solver::get_params() { + return s().params(); } - sat::check_result solver::check() { - return sat::check_result::CR_DONE; + void solver::initialize_value(expr* t, expr* v) { + ctx.user_propagate_initialize_value(t, v); } - bool solver::unit_propagate() { - force_push(); - sample_local_search(); - return false; + void solver::force_phase(sat::literal lit) { + ctx.s().set_phase(lit); } - bool solver::is_unit(expr* e) { - if (!e) - return false; - m.is_not(e, e); - if (is_uninterp_const(e)) - return true; - bv_util bu(m); - expr* s; - if (bu.is_bit2bool(e, s)) - return is_uninterp_const(s); + void solver::set_has_new_best_phase(bool b) { + + } + + bool solver::get_best_phase(sat::bool_var v) { return false; } - void solver::pop_core(unsigned n) { - for (; m_trail_lim < s().init_trail_size(); ++m_trail_lim) { - auto lit = s().trail_literal(m_trail_lim); - auto e = ctx.literal2expr(lit); - if (is_unit(e)) { - // IF_VERBOSE(1, verbose_stream() << "add unit " << mk_pp(e, m) << "\n"); - std::lock_guard lock(m_mutex); - ast_translation tr(m, *m_shared); - m_units->push_back(tr(e.get())); - m_has_units = true; - } - } - } + expr* solver::bool_var2expr(sat::bool_var v) { + return ctx.bool_var2expr(v); + } - void solver::init_search() { - if (m_sls) { - m_sls->cancel(); - m_thread.join(); - m_result = l_undef; - m_completed = false; - m_has_units = false; - m_model = nullptr; - m_units = nullptr; - } - // set up state for local search solver here + void solver::set_finished() { + ctx.s().set_canceled(); + } + + unsigned solver::get_num_bool_vars() const { + return s().num_vars(); + } - m_shared = alloc(ast_manager); - m_slsm = alloc(ast_manager); - m_units = alloc(expr_ref_vector, *m_shared); - ast_translation tr(m, *m_slsm); + void solver::finalize() { + if (!m_smt_plugin) + return; - m_completed = false; - m_result = l_undef; + m_smt_plugin->finalize(m_model, m_st); m_model = nullptr; - m_sls = alloc(bv::sls, *m_slsm, s().params()); - - for (expr* a : ctx.get_assertions()) - m_sls->assert_expr(tr(a)); + m_smt_plugin = nullptr; + } - std::function eval = [&](expr* e, unsigned r) { + bool solver::unit_propagate() { + force_push(); + if (m_smt_plugin && !m_checking) { + expr_ref_vector fmls(m); + m_checking = true; + m_smt_plugin->check(fmls, ctx.top_level_clauses()); + return true; + } + if (!m_smt_plugin) return false; - }; - - m_sls->init(); - m_sls->init_eval(eval); - m_sls->updt_params(s().params()); - m_sls->init_unit([&]() { - if (!m_has_units) - return expr_ref(*m_slsm); - expr_ref e(*m_slsm); - { - std::lock_guard lock(m_mutex); - if (m_units->empty()) - return expr_ref(*m_slsm); - ast_translation tr(*m_shared, *m_slsm); - e = tr(m_units->back()); - m_units->pop_back(); - } - return e; - }); - m_sls->set_model([&](model& mdl) { - std::lock_guard lock(m_mutex); - ast_translation tr(*m_shared, m); - m_model = mdl.translate(tr); - }); - - m_thread = std::thread([this]() { run_local_search(); }); + if (!m_smt_plugin->completed()) + return false; + m_smt_plugin->finalize(m_model, m_st); + m_smt_plugin = nullptr; + return true; } - void solver::sample_local_search() { - if (!m_completed) - return; - m_thread.join(); - m_completed = false; - m_sls->collect_statistics(m_st); - if (m_result == l_true) { - IF_VERBOSE(2, verbose_stream() << "(sat.sls :model-completed)\n";); - auto mdl = m_sls->get_model(); - ast_translation tr(*m_slsm, m); - m_model = mdl->translate(tr); - s().set_canceled(); + void solver::pop_core(unsigned n) { + if (!m_smt_plugin) + return; + + unsigned scope_lvl = s().scope_lvl(); + if (s().search_lvl() == scope_lvl - n) { + for (; m_trail_lim < s().init_trail_size(); ++m_trail_lim) { + auto lit = s().trail_literal(m_trail_lim); + m_smt_plugin->add_unit(lit); + } } - m_sls = nullptr; +#if 0 + if (ctx.has_new_best_phase()) + m_smt_plugin->import_phase_from_smt(); + +#endif + + m_smt_plugin->import_from_sls(); + } + + void solver::init_search() { + if (m_smt_plugin) + finalize(); + m_smt_plugin = alloc(sls::smt_plugin, *this); + m_checking = false; } - void solver::run_local_search() { - m_result = (*m_sls)(); - m_completed = true; + std::ostream& solver::display(std::ostream& out) const { + return out << "theory-sls\n"; } + #endif } diff --git a/src/sat/smt/sls_solver.h b/src/sat/smt/sls_solver.h index e1d8a95b51c..35f4e1b7d13 100644 --- a/src/sat/smt/sls_solver.h +++ b/src/sat/smt/sls_solver.h @@ -18,7 +18,7 @@ Module Name: #include "util/rlimit.h" -#include "ast/sls/bv_sls.h" +#include "ast/sls/sat_ddfw.h" #include "sat/smt/sat_th.h" @@ -52,8 +52,7 @@ namespace sls { #else -#include -#include +#include "ast/sls/sls_smt_plugin.h" namespace euf { class solver; @@ -61,24 +60,12 @@ namespace euf { namespace sls { - class solver : public euf::th_euf_solver { - std::atomic m_result; - std::atomic m_completed, m_has_units; - std::thread m_thread; - std::mutex m_mutex; - // m is accessed by the main thread - // m_slsm is accessed by the sls thread - // m_shared is only accessed at synchronization points - scoped_ptr m_shared, m_slsm; - scoped_ptr m_sls; - scoped_ptr m_units; + class solver : public euf::th_euf_solver, public sls::smt_context { model_ref m_model; + sls::smt_plugin* m_smt_plugin = nullptr; unsigned m_trail_lim = 0; - statistics m_st; - - void run_local_search(); - void sample_local_search(); - bool is_unit(expr*); + bool m_checking = false; + ::statistics m_st; public: solver(euf::solver& ctx); @@ -97,10 +84,21 @@ namespace sls { sat::literal internalize(expr* e, bool sign, bool root) override { UNREACHABLE(); return sat::null_literal; } void internalize(expr* e) override { UNREACHABLE(); } void get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector & r, bool probing) override { UNREACHABLE(); } - sat::check_result check() override; - std::ostream & display(std::ostream & out) const override { return out; } + sat::check_result check() override { return sat::check_result::CR_DONE; } + std::ostream& display(std::ostream& out) const override; std::ostream & display_justification(std::ostream & out, sat::ext_justification_idx idx) const override { UNREACHABLE(); return out; } std::ostream & display_constraint(std::ostream & out, sat::ext_constraint_idx idx) const override { UNREACHABLE(); return out; } + + + ast_manager& get_manager() override { return m; } + params_ref get_params() override; + void initialize_value(expr* t, expr* v) override; + void force_phase(sat::literal lit) override; + void set_has_new_best_phase(bool b) override; + bool get_best_phase(sat::bool_var v) override; + expr* bool_var2expr(sat::bool_var v) override; + void set_finished() override; + unsigned get_num_bool_vars() const override; }; diff --git a/src/sat/tactic/goal2sat.cpp b/src/sat/tactic/goal2sat.cpp index 57e3a89b5be..005e5e035f9 100644 --- a/src/sat/tactic/goal2sat.cpp +++ b/src/sat/tactic/goal2sat.cpp @@ -44,7 +44,7 @@ Module Name: #include "sat/smt/pb_solver.h" #include "sat/smt/euf_solver.h" #include "sat/smt/sat_th.h" -#include "sat/sat_params.hpp" +#include "params/sat_params.hpp" #include struct goal2sat::imp : public sat::sat_internalizer { @@ -139,10 +139,6 @@ struct goal2sat::imp : public sat::sat_internalizer { return m_euf && ensure_euf()->relevancy_enabled(); } - bool top_level_relevant() { - return m_top_level && relevancy_enabled(); - } - void mk_clause(sat::literal l1, sat::literal l2, euf::th_proof_hint* ph) { sat::literal lits[2] = { l1, l2 }; mk_clause(2, lits, ph); @@ -158,6 +154,7 @@ struct goal2sat::imp : public sat::sat_internalizer { if (relevancy_enabled()) ensure_euf()->add_aux(n, lits); m_solver.add_clause(n, lits, mk_status(ph)); + add_top_level_clause(n, lits); } void mk_root_clause(sat::literal l) { @@ -179,6 +176,7 @@ struct goal2sat::imp : public sat::sat_internalizer { if (relevancy_enabled()) ensure_euf()->add_root(n, lits); m_solver.add_clause(n, lits, ph ? mk_status(ph) : sat::status::input()); + add_top_level_clause(n, lits); } sat::bool_var add_var(bool is_ext, expr* n) { @@ -895,7 +893,6 @@ struct goal2sat::imp : public sat::sat_internalizer { process(n, true); CTRACE("goal2sat", !m_result_stack.empty(), tout << m_result_stack << "\n";); SASSERT(m_result_stack.empty()); - add_assertion(n); } void insert_dep(expr* dep0, expr* dep, bool sign) { @@ -990,10 +987,12 @@ struct goal2sat::imp : public sat::sat_internalizer { } } - void add_assertion(expr* f) { + void add_top_level_clause(unsigned n, sat::literal const* lits) { + if (!m_top_level) + return; auto* ext = dynamic_cast(m_solver.get_extension()); if (ext) - ext->add_assertion(f); + ext->add_clause(n, lits); } void update_model(model_ref& mdl) { diff --git a/src/sat/tactic/sat2goal.cpp b/src/sat/tactic/sat2goal.cpp index 899345ad88f..ead71f2ad68 100644 --- a/src/sat/tactic/sat2goal.cpp +++ b/src/sat/tactic/sat2goal.cpp @@ -44,7 +44,7 @@ Module Name: #include "sat/smt/pb_solver.h" #include "sat/smt/euf_solver.h" #include "sat/smt/sat_th.h" -#include "sat/sat_params.hpp" +#include "params/sat_params.hpp" #include sat2goal::mc::mc(ast_manager& m): m(m), m_var2expr(m) {} diff --git a/src/sat/tactic/sat_tactic.cpp b/src/sat/tactic/sat_tactic.cpp index 105fe94ba00..d43442ed9c4 100644 --- a/src/sat/tactic/sat_tactic.cpp +++ b/src/sat/tactic/sat_tactic.cpp @@ -16,13 +16,14 @@ Module Name: Notes: --*/ +#include "params/sat_params.hpp" #include "ast/ast_pp.h" #include "model/model_v2_pp.h" #include "tactic/tactical.h" #include "sat/tactic/goal2sat.h" #include "sat/tactic/sat2goal.h" #include "sat/sat_solver.h" -#include "sat/sat_params.hpp" + class sat_tactic : public tactic { diff --git a/src/shell/dimacs_frontend.cpp b/src/shell/dimacs_frontend.cpp index afeb604a855..6310c0dd18a 100644 --- a/src/shell/dimacs_frontend.cpp +++ b/src/shell/dimacs_frontend.cpp @@ -23,7 +23,7 @@ Revision History: #include "util/rlimit.h" #include "util/gparams.h" #include "sat/dimacs.h" -#include "sat/sat_params.hpp" +#include "params/sat_params.hpp" #include "sat/sat_solver.h" #include "sat/tactic/goal2sat.h" #include "sat/tactic/sat2goal.h" diff --git a/src/smt/CMakeLists.txt b/src/smt/CMakeLists.txt index e6ee970460f..42469c365c0 100644 --- a/src/smt/CMakeLists.txt +++ b/src/smt/CMakeLists.txt @@ -61,11 +61,13 @@ z3_add_component(smt theory_dl.cpp theory_dummy.cpp theory_fpa.cpp + theory_intblast.cpp theory_lra.cpp theory_opt.cpp theory_pb.cpp theory_recfun.cpp theory_seq.cpp + theory_sls.cpp theory_special_relations.cpp theory_str.cpp theory_str_mc.cpp diff --git a/src/smt/smt_context.cpp b/src/smt/smt_context.cpp index 9ceee136f2e..5bfcdfd9567 100644 --- a/src/smt/smt_context.cpp +++ b/src/smt/smt_context.cpp @@ -37,6 +37,7 @@ Revision History: #include "smt/uses_theory.h" #include "smt/theory_special_relations.h" #include "smt/theory_polymorphism.h" +#include "smt/theory_sls.h" #include "smt/smt_for_each_relevant_expr.h" #include "smt/smt_model_generator.h" #include "smt/smt_model_checker.h" @@ -103,6 +104,10 @@ namespace smt { */ bool context::get_cancel_flag() { + if (l_true == m_sls_completed && !m.limit().suspended()) { + m_last_search_failure = CANCELED; + return true; + } if (m.limit().inc()) return false; m_last_search_failure = CANCELED; @@ -3503,9 +3508,13 @@ namespace smt { m_case_split_queue->display(tout << "case splits\n"); ); display_profile(verbose_stream()); - if (r == l_true && get_cancel_flag()) { + if (r == l_true && get_cancel_flag()) r = l_undef; + if (r == l_undef && m_sls_completed == l_true && has_sls_model()) { + m_last_search_failure = OK; + r = l_true; } + m_sls_completed = l_false; if (r == l_true && gparams::get_value("model_validate") == "true") { recfun::util u(m); if (u.get_rec_funs().empty() && m_proto_model) { @@ -3581,6 +3590,17 @@ namespace smt { return r; } + bool context::has_sls_model() { + if (!m_fparams.m_sls_enable) + return false; + auto tid = m.get_family_id("sls"); + auto p = m_theories.get_plugin(tid); + if (!p) + return false; + m_model = dynamic_cast(p)->get_model(); + return m_model.get() != nullptr; + } + /** \brief Setup the logical context based on the current set of asserted formulas and execute the check command. @@ -3734,6 +3754,7 @@ namespace smt { m_phase_default = false; m_case_split_queue ->init_search_eh(); m_next_progress_sample = 0; + m_sls_completed = l_undef; if (m.has_type_vars() && !m_theories.get_plugin(poly_family_id)) register_plugin(alloc(theory_polymorphism, *this)); TRACE("literal_occ", display_literal_num_occs(tout);); diff --git a/src/smt/smt_context.h b/src/smt/smt_context.h index 715b28f23d2..f1d88093027 100644 --- a/src/smt/smt_context.h +++ b/src/smt/smt_context.h @@ -128,6 +128,7 @@ namespace smt { class parallel* m_par = nullptr; unsigned m_par_index = 0; bool m_internalizing_assertions = false; + lbool m_sls_completed = l_undef; // ----------------------------------- @@ -288,6 +289,11 @@ namespace smt { bool get_cancel_flag(); + void set_sls_completed() { + if (m_sls_completed == l_undef) + m_sls_completed = l_true; + } + region & get_region() { return m_region; } @@ -619,6 +625,9 @@ namespace smt { friend class set_var_theory_trail; void set_var_theory(bool_var v, theory_id tid); + + bool has_sls_model(); + // ----------------------------------- // // Backtracking support @@ -939,6 +948,8 @@ namespace smt { mk_th_clause(tid, num_lits, lits, num_params, params, CLS_TH_AXIOM); } + void mk_th_axiom(theory_id tid, literal l1, unsigned num_params = 0, parameter * params = nullptr); + void mk_th_axiom(theory_id tid, literal l1, literal l2, unsigned num_params = 0, parameter * params = nullptr); void mk_th_axiom(theory_id tid, literal l1, literal l2, literal l3, unsigned num_params = 0, parameter * params = nullptr); diff --git a/src/smt/smt_internalizer.cpp b/src/smt/smt_internalizer.cpp index 2b18d9a3f33..9181c902501 100644 --- a/src/smt/smt_internalizer.cpp +++ b/src/smt/smt_internalizer.cpp @@ -1562,6 +1562,10 @@ namespace smt { mk_clause(num_lits, lits, js, k); } + void context::mk_th_axiom(theory_id tid, literal l1, unsigned num_params, parameter * params) { + mk_th_axiom(tid, 1, &l1, num_params, params); + } + void context::mk_th_axiom(theory_id tid, literal l1, literal l2, unsigned num_params, parameter * params) { literal ls[2] = { l1, l2 }; mk_th_axiom(tid, 2, ls, num_params, params); diff --git a/src/smt/smt_setup.cpp b/src/smt/smt_setup.cpp index caeca965906..f1983364fd7 100644 --- a/src/smt/smt_setup.cpp +++ b/src/smt/smt_setup.cpp @@ -27,6 +27,7 @@ Revision History: #include "smt/theory_array.h" #include "smt/theory_array_full.h" #include "smt/theory_bv.h" +#include "smt/theory_intblast.h" #include "smt/theory_datatype.h" #include "smt/theory_recfun.h" #include "smt/theory_dummy.h" @@ -35,6 +36,7 @@ Revision History: #include "smt/theory_seq.h" #include "smt/theory_char.h" #include "smt/theory_special_relations.h" +#include "smt/theory_sls.h" #include "smt/theory_pb.h" #include "smt/theory_fpa.h" #include "smt/theory_str.h" @@ -67,6 +69,7 @@ namespace smt { case CFG_AUTO: setup_auto_config(); break; } setup_card(); + setup_sls(); } void setup::setup_default() { @@ -471,12 +474,12 @@ namespace smt { void setup::setup_QF_BV() { TRACE("setup", tout << "qf-bv\n";); m_params.setup_QF_BV(); - m_context.register_plugin(alloc(smt::theory_bv, m_context)); + setup_bv(); } void setup::setup_QF_AUFBV() { m_params.setup_QF_AUFBV(); - m_context.register_plugin(alloc(smt::theory_bv, m_context)); + setup_bv(); setup_arrays(); } @@ -693,7 +696,15 @@ namespace smt { family_id bv_fid = m_manager.mk_family_id("bv"); if (m_context.get_theory(bv_fid)) return; - switch(m_params.m_bv_mode) { + switch (m_params.m_bv_solver) { + case 2: + m_context.register_plugin(alloc(smt::theory_intblast, m_context)); + setup_lra_arith(); + return; + default: + break; + } + switch (m_params.m_bv_mode) { case BS_NO_BV: m_context.register_plugin(alloc(smt::theory_dummy, m_context, bv_fid, "no bit-vector")); break; @@ -766,6 +777,11 @@ namespace smt { m_context.register_plugin(alloc(theory_pb, m_context)); } + void setup::setup_sls() { + if (m_params.m_sls_enable) + m_context.register_plugin(alloc(theory_sls, m_context)); + } + void setup::setup_fpa() { setup_bv(); m_context.register_plugin(alloc(theory_fpa, m_context)); diff --git a/src/smt/smt_setup.h b/src/smt/smt_setup.h index bb4a81671ac..acbea59cb6c 100644 --- a/src/smt/smt_setup.h +++ b/src/smt/smt_setup.h @@ -103,6 +103,7 @@ namespace smt { void setup_seq(); void setup_char(); void setup_card(); + void setup_sls(); void setup_i_arith(); void setup_mi_arith(); void setup_lra_arith(); diff --git a/src/smt/theory_intblast.cpp b/src/smt/theory_intblast.cpp new file mode 100644 index 00000000000..0e8b937d728 --- /dev/null +++ b/src/smt/theory_intblast.cpp @@ -0,0 +1,191 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + theory_intblast + +Author: + + Nikolaj Bjorner (nbjorner) 2024-10-27 + +--*/ + +#include "smt/smt_context.h" +#include "smt/theory_intblast.h" +#include "smt/smt_model_generator.h" + +namespace smt { + + void theory_intblast::translator_trail::push(push_back_vector const& c) { + ctx.push_trail(c); + } + void theory_intblast::translator_trail::push(push_back_vector> const& c) { + ctx.push_trail(c); + } + + void theory_intblast::translator_trail::push_idx(set_vector_idx_trail const& c) { + ctx.push_trail(c); + } + + theory_intblast::theory_intblast(context& ctx): + theory(ctx, ctx.get_manager().mk_family_id("bv")), + m_trail(ctx), + m_translator(m, m_trail), + bv(m), + a(m) + {} + + theory_intblast::~theory_intblast() {} + + final_check_status theory_intblast::final_check_eh() { + for (auto e : m_translator.bv2int()) { + auto* n = ctx.get_enode(e); + auto* r1 = n->get_arg(0)->get_root(); + for (auto sib : *n) { + if (sib == n) + continue; + if (!bv.is_bv2int(sib->get_expr())) + continue; + if (sib->get_arg(0)->get_root() == r1) + continue; + if (bv.get_bv_size(r1->get_expr()) != bv.get_bv_size(sib->get_arg(0)->get_expr())) + continue; + auto a = mk_eq(n->get_expr(), sib->get_expr(), false); + auto b = mk_eq(sib->get_arg(0)->get_expr(), n->get_arg(0)->get_expr(), false); + ctx.mark_as_relevant(a); + ctx.mark_as_relevant(b); + ctx.mk_th_axiom(m_id, ~a, b); + return final_check_status::FC_CONTINUE; + } + } + // ensure that int2bv respects values + // bv2int(int2bv(x)) = x mod N + for (auto e : m_translator.int2bv()) { + auto n = ctx.get_enode(e); + auto x = n->get_arg(0)->get_expr(); + auto bv2int = bv.mk_bv2int(e); + ctx.internalize(bv2int, false); + auto N = rational::power_of_two(bv.get_bv_size(e)); + auto xModN = a.mk_mod(x, a.mk_int(N)); + ctx.internalize(xModN, false); + auto nBv2int = ctx.get_enode(bv2int); + auto nxModN = ctx.get_enode(xModN); + if (nBv2int->get_root() != nxModN->get_root()) { + auto a = mk_eq(nBv2int->get_expr(), nxModN->get_expr(), false); + ctx.mark_as_relevant(a); + ctx.mk_th_axiom(m_id, 1, &a); + return final_check_status::FC_CONTINUE; + } + } + return final_check_status::FC_DONE; + } + + bool theory_intblast::add_bound_axioms() { + auto const& vars = m_translator.vars(); + if (m_vars_qhead == vars.size()) + return false; + ctx.push_trail(value_trail(m_vars_qhead)); + for (; m_vars_qhead < vars.size(); ++m_vars_qhead) { + auto v = vars[m_vars_qhead]; + auto w = m_translator.translated(v); + auto sz = rational::power_of_two(bv.get_bv_size(v->get_sort())); + auto lo = mk_literal(a.mk_ge(w, a.mk_int(0))); + auto hi = mk_literal(a.mk_le(w, a.mk_int(sz - 1))); + ctx.mark_as_relevant(lo); + ctx.mark_as_relevant(hi); + ctx.mk_th_axiom(m_id, 1, &lo); + ctx.mk_th_axiom(m_id, 1, &hi); + } + return true; + } + + bool theory_intblast::add_predicate_axioms() { + auto const& preds = m_translator.preds(); + if (m_preds_qhead == preds.size()) + return false; + ctx.push_trail(value_trail(m_preds_qhead)); + for (; m_preds_qhead < preds.size(); ++m_preds_qhead) { + expr* e = preds[m_preds_qhead]; + expr_ref r(m_translator.translated(e), m); + ctx.get_rewriter()(r); + auto a = mk_literal(e); + auto b = mk_literal(r); + ctx.mark_as_relevant(a); + ctx.mark_as_relevant(b); + ctx.mk_th_axiom(m_id, ~a, b); + ctx.mk_th_axiom(m_id, a, ~b); + } + return true; + } + + bool theory_intblast::can_propagate() { + return m_preds_qhead < m_translator.preds().size() || m_vars_qhead < m_translator.vars().size(); + } + + void theory_intblast::propagate() { + add_bound_axioms(); + add_predicate_axioms(); + } + + bool theory_intblast::internalize_atom(app * atom, bool gate_ctx) { + return internalize_term(atom); + } + + void theory_intblast::apply_sort_cnstr(enode* n, sort* s) { + SASSERT(bv.is_bv_sort(s)); + if (!is_attached_to_var(n)) { + m_translator.internalize_bv(n->get_expr()); + auto v = mk_var(n); + ctx.attach_th_var(n, this, v); + } + } + + bool theory_intblast::internalize_term(app* term) { + + ctx.internalize(term->get_args(), term->get_num_args(), false); + m_translator.internalize_bv(term); + enode* n; + if (!ctx.e_internalized(term)) + n = ctx.mk_enode(term, false, false, false); + else + n = ctx.get_enode(term); + + if (!is_attached_to_var(n)) { + auto v = mk_var(n); + ctx.attach_th_var(n, this, v); + } + if (m.is_bool(term)) { + literal l(ctx.mk_bool_var(term)); + ctx.set_var_theory(l.var(), get_id()); + } + return true; + } + + void theory_intblast::internalize_eq_eh(app * atom, bool_var v) { + m_translator.translate_eq(atom); + } + + void theory_intblast::init_model(model_generator& mg) { + m_factory = alloc(bv_factory, m); + mg.register_factory(m_factory); + } + + model_value_proc* theory_intblast::mk_value(enode* n, model_generator& mg) { + expr* e = n->get_expr(); + SASSERT(bv.is_bv(e)); + rational r; + expr* ie = nullptr; + expr_ref val(m); + if (!bv.is_numeral(e, r)) { + for (enode* sib : *n) { + ie = m_translator.translated(sib->get_expr()); + if (ctx.e_internalized(ie) && ctx.get_value(ctx.get_enode(ie), val) && a.is_numeral(val, r)) + break; + } + } + return alloc(expr_wrapper_proc, m_factory->mk_num_value(r, bv.get_bv_size(e))); + } + + +} diff --git a/src/smt/theory_intblast.h b/src/smt/theory_intblast.h new file mode 100644 index 00000000000..b822593b743 --- /dev/null +++ b/src/smt/theory_intblast.h @@ -0,0 +1,73 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + theory_intblast + +Abstract: + + Intblast version of bit-vector solver + +Author: + + Nikolaj Bjorner (nbjorner) 2024-10-24 + +--*/ +#pragma once + + +#include "util/rlimit.h" +#include "ast/sls/sat_ddfw.h" +#include "smt/smt_theory.h" +#include "model/model.h" +#include "model/numeral_factory.h" +#include "ast/rewriter/bv2int_translator.h" + + +namespace smt { + + class theory_intblast : public theory { + + class translator_trail : public bv2int_translator_trail { + context& ctx; + public: + translator_trail(context& ctx):ctx(ctx) {} + void push(push_back_vector const& c) override; + void push(push_back_vector> const& c) override; + void push_idx(set_vector_idx_trail const& c) override; + }; + + translator_trail m_trail; + bv2int_translator m_translator; + bv_util bv; + arith_util a; + unsigned m_vars_qhead = 0, m_preds_qhead = 0; + bv_factory * m_factory = nullptr; + + bool add_bound_axioms(); + bool add_predicate_axioms(); + + public: + theory_intblast(context& ctx); + ~theory_intblast() override; + + char const* get_name() const override { return "bv-intblast"; } + smt::theory* mk_fresh(context* new_ctx) override { return alloc(theory_intblast, *new_ctx); } + final_check_status final_check_eh() override; + void display(std::ostream& out) const override {} + bool can_propagate() override; + void propagate() override; + bool internalize_atom(app * atom, bool gate_ctx) override; + bool internalize_term(app* term) override; + void internalize_eq_eh(app * atom, bool_var v) override; + void apply_sort_cnstr(enode* n, sort* s) override; + void init_model(model_generator& m) override; + model_value_proc* mk_value(enode* n, model_generator& m) override; + void new_eq_eh(theory_var v1, theory_var v2) override {} + void new_diseq_eh(theory_var v1, theory_var v2) override {} + + }; + +} + diff --git a/src/smt/theory_lra.cpp b/src/smt/theory_lra.cpp index 5a3cbbd1ad9..6b51e2f6922 100644 --- a/src/smt/theory_lra.cpp +++ b/src/smt/theory_lra.cpp @@ -154,6 +154,7 @@ class theory_lra::imp { svector m_asserted_atoms; ptr_vector m_not_handled; ptr_vector m_underspecified; + ptr_vector m_bv_terms; vector > m_use_list; // bounds where variables are used. // attributes for incremental version: @@ -470,6 +471,13 @@ class theory_lra::imp { st.to_ensure_var().push_back(n1); st.to_ensure_var().push_back(n2); } + else if (a.is_band(n) || a.is_shl(n) || a.is_ashr(n) || a.is_lshr(n)) { + m_bv_terms.push_back(to_app(n)); + ctx().push_trail(push_back_vector(m_bv_terms)); + mk_bv_axiom(to_app(n)); + for (expr* arg : *to_app(n)) + st.to_ensure_var().push_back(arg); + } else if (!a.is_div0(n)) { found_unsupported(n); } @@ -1611,6 +1619,7 @@ class theory_lra::imp { if (!lp().is_feasible() || lp().has_changed_columns()) is_sat = make_feasible(); final_check_status st = FC_DONE; + bool int_undef = false; switch (is_sat) { case l_true: TRACE("arith", display(tout)); @@ -1621,6 +1630,7 @@ class theory_lra::imp { case FC_CONTINUE: return FC_CONTINUE; case FC_GIVEUP: + int_undef = true; TRACE("arith", tout << "check-lia giveup\n";); if (ctx().get_fparams().m_arith_ignore_int) st = FC_CONTINUE; @@ -1642,6 +1652,9 @@ class theory_lra::imp { ++m_stats.m_assume_eqs; return FC_CONTINUE; } + + if (!int_undef && !check_bv_terms()) + return FC_CONTINUE; for (expr* e : m_not_handled) { if (!ctx().is_relevant(e)) @@ -2442,6 +2455,180 @@ class theory_lra::imp { return null_literal; } + bool check_bv_terms() { + for (app* n : m_bv_terms) { + if (!check_bv_term(n)) { + ++m_stats.m_bv_axioms; + return false; + } + } + return true; + } + + + bool check_bv_term(app* n) { + unsigned sz = 0; + expr* _x = nullptr, * _y = nullptr; + if (!ctx().is_relevant(ctx().get_enode(n))) + return true; + expr_ref vx(m), vy(m),vn(m); + rational valn, valx, valy; + bool is_int; + VERIFY(a.is_band(n, sz, _x, _y) || a.is_shl(n, sz, _x, _y) || a.is_ashr(n, sz, _x, _y) || a.is_lshr(n, sz, _x, _y)); + if (!get_value(ctx().get_enode(_x), vx) || !get_value(ctx().get_enode(_y), vy) || !get_value(ctx().get_enode(n), vn)) { + IF_VERBOSE(2, verbose_stream() << "could not get value of " << mk_pp(n, m) << "\n"); + found_unsupported(n); + return true; + } + if (!a.is_numeral(vn, valn, is_int) || !is_int || !a.is_numeral(vx, valx, is_int) || !is_int || !a.is_numeral(vy, valy, is_int) || !is_int) { + IF_VERBOSE(2, verbose_stream() << "could not get value of " << mk_pp(n, m) << "\n"); + found_unsupported(n); + return true; + } + rational N = rational::power_of_two(sz); + valx = mod(valx, N); + valy = mod(valy, N); + expr_ref x(a.mk_mod(_x, a.mk_int(N)), m); + expr_ref y(a.mk_mod(_y, a.mk_int(N)), m); + SASSERT(0 <= valn && valn < N); + + // x mod 2^{i + 1} >= 2^i means the i'th bit is 1. + auto bitof = [&](expr* x, unsigned i) { + expr_ref r(m); + r = a.mk_ge(a.mk_mod(x, a.mk_int(rational::power_of_two(i+1))), a.mk_int(rational::power_of_two(i))); + return mk_literal(r); + }; + + if (a.is_band(n)) { + IF_VERBOSE(2, verbose_stream() << "band: " << mk_bounded_pp(n, m) << " " << valn << " := " << valx << "&" << valy << "\n"); + for (unsigned i = 0; i < sz; ++i) { + bool xb = valx.get_bit(i); + bool yb = valy.get_bit(i); + bool nb = valn.get_bit(i); + if (xb && yb && !nb) + ctx().mk_th_axiom(get_id(), ~bitof(x, i), ~bitof(y, i), bitof(n, i)); + else if (nb && !xb) + ctx().mk_th_axiom(get_id(), ~bitof(n, i), bitof(x, i)); + else if (nb && !yb) + ctx().mk_th_axiom(get_id(), ~bitof(n, i), bitof(y, i)); + else + continue; + return false; + } + } + if (a.is_shl(n)) { + SASSERT(valy >= 0); + if (valy >= sz || valy == 0) + return true; + unsigned k = valy.get_unsigned(); + sat::literal eq = th.mk_eq(n, a.mk_mod(a.mk_mul(_x, a.mk_int(rational::power_of_two(k))), a.mk_int(N)), false); + if (ctx().get_assignment(eq) == l_true) + return true; + ctx().mk_th_axiom(get_id(), ~th.mk_eq(y, a.mk_int(k), false), eq); + IF_VERBOSE(2, verbose_stream() << "shl: " << mk_bounded_pp(n, m) << " " << valn << " := " << valx << " << " << valy << "\n"); + return false; + } + if (a.is_lshr(n)) { + SASSERT(valy >= 0); + if (valy >= sz || valy == 0) + return true; + unsigned k = valy.get_unsigned(); + sat::literal eq = th.mk_eq(n, a.mk_idiv(x, a.mk_int(rational::power_of_two(k))), false); + if (ctx().get_assignment(eq) == l_true) + return true; + ctx().mk_th_axiom(get_id(), ~th.mk_eq(y, a.mk_int(k), false), eq); + IF_VERBOSE(2, verbose_stream() << "lshr: " << mk_bounded_pp(n, m) << " " << valn << " := " << valx << " >>l " << valy << "\n"); + return false; + } + if (a.is_ashr(n)) { + SASSERT(valy >= 0); + if (valy >= sz || valy == 0) + return true; + unsigned k = valy.get_unsigned(); + sat::literal signx = mk_literal(a.mk_ge(x, a.mk_int(N/2))); + sat::literal eq; + expr* xdiv2k; + switch (ctx().get_assignment(signx)) { + case l_true: + // x < 0 & y = k -> n = (x div 2^k - 2^{N-k}) mod 2^N + xdiv2k = a.mk_idiv(x, a.mk_int(rational::power_of_two(k))); + eq = th.mk_eq(n, a.mk_mod(a.mk_add(xdiv2k, a.mk_int(-rational::power_of_two(sz - k))), a.mk_int(N)), false); + if (ctx().get_assignment(eq) == l_true) + return true; + break; + case l_false: + // x >= 0 & y = k -> n = x div 2^k + xdiv2k = a.mk_idiv(x, a.mk_int(rational::power_of_two(k))); + eq = th.mk_eq(n, xdiv2k, false); + if (ctx().get_assignment(eq) == l_true) + return true; + break; + case l_undef: + ctx().mark_as_relevant(signx); + return false; + } + ctx().mk_th_axiom(get_id(), ~th.mk_eq(y, a.mk_int(k), false), ~signx, eq); + return false; + } + return true; + } + + expr_ref mk_le(expr* x, expr* y) { + if (a.is_numeral(y)) + return expr_ref(a.mk_le(x, y), m); + if (a.is_numeral(x)) + return expr_ref(a.mk_ge(y, x), m); + return expr_ref(a.mk_le(a.mk_sub(x, y), a.mk_numeral(rational(0), x->get_sort())), m); + } + + void mk_bv_axiom(app* n) { + unsigned sz = 0; + expr* _x = nullptr, * _y = nullptr; + VERIFY(a.is_band(n, sz, _x, _y) || a.is_shl(n, sz, _x, _y) || a.is_ashr(n, sz, _x, _y) || a.is_lshr(n, sz, _x, _y)); + rational N = rational::power_of_two(sz); + expr_ref x(a.mk_mod(_x, a.mk_int(N)), m); + expr_ref y(a.mk_mod(_y, a.mk_int(N)), m); + + // 0 <= n < 2^sz + + ctx().mk_th_axiom(get_id(), mk_literal(a.mk_ge(n, a.mk_int(0)))); + ctx().mk_th_axiom(get_id(), mk_literal(a.mk_le(n, a.mk_int(N - 1)))); + + if (a.is_band(n)) { + + // x&y <= x + // x&y <= y + // TODO? x = y => x&y = x + + ctx().mk_th_axiom(get_id(), mk_literal(mk_le(n, x))); + ctx().mk_th_axiom(get_id(), mk_literal(mk_le(n, y))); + } + else if (a.is_shl(n)) { + // y >= sz => n = 0 + // y = 0 => n = x + ctx().mk_th_axiom(get_id(), ~mk_literal(a.mk_ge(y, a.mk_int(sz))), mk_literal(m.mk_eq(n, a.mk_int(0)))); + ctx().mk_th_axiom(get_id(), ~mk_literal(a.mk_eq(y, a.mk_int(0))), mk_literal(m.mk_eq(n, x))); + } + else if (a.is_lshr(n)) { + // y >= sz => n = 0 + // y = 0 => n = x + ctx().mk_th_axiom(get_id(), ~mk_literal(a.mk_ge(y, a.mk_int(sz))), mk_literal(m.mk_eq(n, a.mk_int(0)))); + ctx().mk_th_axiom(get_id(), ~mk_literal(a.mk_eq(y, a.mk_int(0))), mk_literal(m.mk_eq(n, x))); + } + else if (a.is_ashr(n)) { + // y >= sz & x < 2^{sz-1} => n = 0 + // y >= sz & x >= 2^{sz-1} => n = -1 + // y = 0 => n = x + auto signx = mk_literal(a.mk_ge(x, a.mk_int(N/2))); + ctx().mk_th_axiom(get_id(), ~mk_literal(a.mk_ge(a.mk_mod(y, a.mk_int(N)), a.mk_int(sz))), signx, mk_literal(m.mk_eq(n, a.mk_int(0)))); + ctx().mk_th_axiom(get_id(), ~mk_literal(a.mk_ge(a.mk_mod(y, a.mk_int(N)), a.mk_int(sz))), ~signx, mk_literal(m.mk_eq(n, a.mk_int(N-1)))); + ctx().mk_th_axiom(get_id(), ~mk_literal(a.mk_eq(a.mk_mod(y, a.mk_int(N)), a.mk_int(0))), mk_literal(m.mk_eq(n, x))); + } + else + UNREACHABLE(); + } + + void mk_bound_axioms(api_bound& b) { if (!ctx().is_searching()) { // @@ -3254,7 +3441,7 @@ class theory_lra::imp { tout << "@" << ctx().get_scope_level() << (is_conflict ? " conflict":" lemma"); for (auto const& p : m_params) tout << " " << p; tout << "\n"; - display_evidence(tout, m_explanation);); + display_evidence(tout << core << " ", m_explanation);); for (auto ev : m_explanation) set_evidence(ev.ci(), m_core, m_eqs); @@ -3276,6 +3463,8 @@ class theory_lra::imp { for (literal & c : m_core) { c.neg(); ctx().mark_as_relevant(c); + if (ctx().get_assignment(c) == l_true) + return; } TRACE("arith", ctx().display_literals_verbose(tout, m_core) << "\n";); ctx().mk_th_axiom(get_id(), m_core.size(), m_core.data()); diff --git a/src/smt/theory_sls.cpp b/src/smt/theory_sls.cpp new file mode 100644 index 00000000000..d0b0378bd9e --- /dev/null +++ b/src/smt/theory_sls.cpp @@ -0,0 +1,133 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + theory_sls + +Abstract: + + Interface to Concurrent SLS solver + +Author: + + Nikolaj Bjorner (nbjorner) 2024-10-24 + +--*/ + + +#include "smt/smt_context.h" +#include "ast/sls/sls_context.h" +#include "ast/for_each_expr.h" +#include "smt/theory_sls.h" + +namespace smt { + + theory_sls::theory_sls(smt::context& ctx) : + theory(ctx, ctx.get_manager().mk_family_id("sls")) + {} + +#ifndef SINGLE_THREAD + + theory_sls::~theory_sls() { + finalize(); + } + + params_ref theory_sls::get_params() { + return ctx.get_params(); + } + + void theory_sls::initialize_value(expr* t, expr* v) { + //ctx.user_propagate_initialize_value(t, v); + } + + void theory_sls::force_phase(sat::literal lit) { + // + // ctx.force_phase(lit); + } + + void theory_sls::set_has_new_best_phase(bool b) { + + } + + bool theory_sls::get_best_phase(sat::bool_var v) { + return false; + } + + expr* theory_sls::bool_var2expr(sat::bool_var v) { + return ctx.bool_var2expr(v); + } + + void theory_sls::set_finished() { + ctx.set_sls_completed(); + } + + unsigned theory_sls::get_num_bool_vars() const { + return ctx.get_num_bool_vars(); + } + + void theory_sls::finalize() { + if (!m_smt_plugin) + return; + + m_smt_plugin->finalize(m_model, m_st); + m_model = nullptr; + m_smt_plugin = nullptr; + } + + void theory_sls::propagate() { + if (m_smt_plugin && !m_checking) { + expr_ref_vector fmls(m); + for (unsigned i = 0; i < ctx.get_num_asserted_formulas(); ++i) + fmls.push_back(ctx.get_asserted_formula(i)); + m_checking = true; + vector clauses; + m_smt_plugin->check(fmls, clauses); + return; + } + if (!m_smt_plugin) + return; + if (!m_smt_plugin->completed()) + return; + m_smt_plugin->finalize(m_model, m_st); + m_smt_plugin = nullptr; + } + + void theory_sls::pop_scope_eh(unsigned n) { + if (!m_smt_plugin) + return; + + unsigned scope_lvl = ctx.get_scope_level(); + if (ctx.get_search_level() == scope_lvl - n) { + auto& lits = ctx.assigned_literals(); + for (; m_trail_lim < lits.size() && ctx.get_assign_level(lits[m_trail_lim]) == scope_lvl; ++m_trail_lim) + m_smt_plugin->add_unit(lits[m_trail_lim]); + } +#if 0 + if (ctx.has_new_best_phase()) + m_smt_plugin->import_phase_from_smt(); + +#endif + + // m_smt_plugin->import_from_sls(); + } + + void theory_sls::init() { + if (m_smt_plugin) + finalize(); + m_smt_plugin = alloc(sls::smt_plugin, *this); + m_checking = false; + } + + void theory_sls::collect_statistics(::statistics& st) const { + st.copy(m_st); + } + + void theory_sls::display(std::ostream& out) const { + out << "theory-sls\n"; + } + + + +#endif +} diff --git a/src/smt/theory_sls.h b/src/smt/theory_sls.h new file mode 100644 index 00000000000..581d7d4603b --- /dev/null +++ b/src/smt/theory_sls.h @@ -0,0 +1,93 @@ +/*++ +Copyright (c) 2020 Microsoft Corporation + +Module Name: + + theory_sls + +Abstract: + + Interface to Concurrent SLS solver + +Author: + + Nikolaj Bjorner (nbjorner) 2024-10-24 + +--*/ +#pragma once + + +#include "util/rlimit.h" +#include "ast/sls/sat_ddfw.h" +#include "smt/smt_theory.h" +#include "model/model.h" + + +#ifdef SINGLE_THREAD + +namespace smt { + class theory_sls : public theory { + model_ref m_model; + public: + theory_sls(context& ctx); + ~theory_sls() override {} + model_ref get_model() { return m_model; } + char const* get_name() const override { return "sls"; } + smt::theory* mk_fresh(context* new_ctx) override { return alloc(theory_sls, *new_ctx); } + void display(std::ostream& out) const override {} + bool internalize_atom(app* atom, bool gate_ctx) override { return false; } + bool internalize_term(app* term) override { return false; } + void new_eq_eh(theory_var v1, theory_var v2) override {} + void new_diseq_eh(theory_var v1, theory_var v2) override {} + }; +} + +#else + +#include "ast/sls/sls_smt_plugin.h" + + +namespace smt { + + class theory_sls : public theory, public sls::smt_context { + model_ref m_model; + sls::smt_plugin* m_smt_plugin = nullptr; + unsigned m_trail_lim = 0; + bool m_checking = false; + ::statistics m_st; + + void finalize(); + + public: + theory_sls(context& ctx); + ~theory_sls() override; + model_ref get_model() { return m_model; } + + // smt::theory interface + char const* get_name() const override { return "sls"; } + void init() override; + void pop_scope_eh(unsigned n) override; + smt::theory* mk_fresh(context* new_ctx) override { return alloc(theory_sls, *new_ctx); } + void collect_statistics(::statistics& st) const override; + void propagate() override; + void display(std::ostream& out) const override; + bool internalize_atom(app * atom, bool gate_ctx) override { return false; } + bool internalize_term(app* term) override { return false; } + void new_eq_eh(theory_var v1, theory_var v2) override {} + void new_diseq_eh(theory_var v1, theory_var v2) override {} + + // sls::smt_context interface + ast_manager& get_manager() override { return m; } + params_ref get_params() override; + void initialize_value(expr* t, expr* v) override; + void force_phase(sat::literal lit) override; + void set_has_new_best_phase(bool b) override; + bool get_best_phase(sat::bool_var v) override; + expr* bool_var2expr(sat::bool_var v) override; + void set_finished() override; + unsigned get_num_bool_vars() const override; + }; + +} + +#endif diff --git a/src/tactic/portfolio/smt_strategic_solver.cpp b/src/tactic/portfolio/smt_strategic_solver.cpp index 09e7aa047b8..0f952e016d7 100644 --- a/src/tactic/portfolio/smt_strategic_solver.cpp +++ b/src/tactic/portfolio/smt_strategic_solver.cpp @@ -47,7 +47,7 @@ Module Name: #include "solver/parallel_params.hpp" #include "params/tactic_params.hpp" #include "parsers/smt2/smt2parser.h" -#include "sat/sat_params.hpp" +#include "params/sat_params.hpp" tactic* mk_tactic_for_logic(ast_manager& m, params_ref const& p, symbol const& logic); diff --git a/src/tactic/sls/sls_tactic.cpp b/src/tactic/sls/sls_tactic.cpp index f484372c8bd..e72c01af0d7 100644 --- a/src/tactic/sls/sls_tactic.cpp +++ b/src/tactic/sls/sls_tactic.cpp @@ -28,37 +28,38 @@ Module Name: #include "util/stopwatch.h" #include "tactic/sls/sls_tactic.h" #include "params/sls_params.hpp" -#include "ast/sls/sls_engine.h" -#include "ast/sls/bv_sls.h" +#include "ast/sls/sls_bv_engine.h" +#include "ast/sls/sls_smt_solver.h" -class sls_tactic : public tactic { - ast_manager & m; - params_ref m_params; - sls_engine * m_engine; +class sls_smt_tactic : public tactic { + ast_manager& m; + params_ref m_params; + sls::smt_solver* m_sls; + statistics m_st; public: - sls_tactic(ast_manager & _m, params_ref const & p): + sls_smt_tactic(ast_manager& _m, params_ref const& p) : m(_m), m_params(p) { - m_engine = alloc(sls_engine, m, p); + m_sls = alloc(sls::smt_solver, m, p); } - tactic * translate(ast_manager & m) override { - return alloc(sls_tactic, m, m_params); + tactic* translate(ast_manager& m) override { + return alloc(sls_smt_tactic, m, m_params); } - ~sls_tactic() override { - dealloc(m_engine); + ~sls_smt_tactic() override { + dealloc(m_sls); } - char const* name() const override { return "sls"; } + char const* name() const override { return "sls-smt"; } - void updt_params(params_ref const & p) override { + void updt_params(params_ref const& p) override { m_params.append(p); - m_engine->updt_params(m_params); + m_sls->updt_params(m_params); } - void collect_param_descrs(param_descrs & r) override { + void collect_param_descrs(param_descrs& r) override { sls_params::collect_param_descrs(r); } @@ -66,24 +67,30 @@ class sls_tactic : public tactic { if (g->inconsistent()) { mc = nullptr; return; - } - + } + for (unsigned i = 0; i < g->size(); i++) - m_engine->assert_expr(g->form(i)); - - lbool res = m_engine->operator()(); - auto const& stats = m_engine->get_stats(); - if (res == l_true) { - report_tactic_progress("Number of flips:", stats.m_moves); - - for (unsigned i = 0; i < g->size(); i++) - if (!m_engine->get_mpz_manager().is_one(m_engine->get_value(g->form(i)))) { - verbose_stream() << "Terminated before all assertions were SAT!" << std::endl; - NOT_IMPLEMENTED_YET(); - } + m_sls->assert_expr(g->form(i)); + + m_st.reset(); + lbool res = l_undef; + try { + res = m_sls->check(); + } + catch (z3_exception&) { + m_sls->collect_statistics(m_st); + throw; + } + m_sls->collect_statistics(m_st); + +// report_tactic_progress("Number of flips:", m_sls->get_num_moves()); + IF_VERBOSE(10, verbose_stream() << res << "\n"); + IF_VERBOSE(10, m_sls->display(verbose_stream())); + + if (res == l_true) { if (g->models_enabled()) { - model_ref mdl = m_engine->get_model(); + model_ref mdl = m_sls->get_model(); mc = model2model_converter(mdl.get()); TRACE("sls_model", mc->display(tout);); } @@ -91,16 +98,16 @@ class sls_tactic : public tactic { } else mc = nullptr; - + } - - void operator()(goal_ref const & g, - goal_ref_buffer & result) override { + + void operator()(goal_ref const& g, + goal_ref_buffer& result) override { result.reset(); - + TRACE("sls", g->display(tout);); tactic_report report("sls", *g); - + model_converter_ref mc; run(g, mc); g->add(mc.get()); @@ -109,50 +116,53 @@ class sls_tactic : public tactic { } void cleanup() override { - sls_engine * d = alloc(sls_engine, m, m_params); - std::swap(d, m_engine); + auto* d = alloc(sls::smt_solver, m, m_params); + std::swap(d, m_sls); dealloc(d); } - - void collect_statistics(statistics & st) const override { - m_engine->collect_statistics(st); + + void collect_statistics(statistics& st) const override { + st.copy(m_st); } void reset_statistics() override { - m_engine->reset_statistics(); + m_sls->reset_statistics(); + m_st.reset(); } - }; -class bv_sls_tactic : public tactic { - ast_manager& m; - params_ref m_params; - bv::sls* m_sls; - statistics m_st; +tactic* mk_sls_smt_tactic(ast_manager& m, params_ref const& p) { + return alloc(sls_smt_tactic, m, p); +} + +class sls_tactic : public tactic { + ast_manager & m; + params_ref m_params; + sls_engine * m_engine; public: - bv_sls_tactic(ast_manager& _m, params_ref const& p) : + sls_tactic(ast_manager & _m, params_ref const & p): m(_m), m_params(p) { - m_sls = alloc(bv::sls, m, p); + m_engine = alloc(sls_engine, m, p); } - tactic* translate(ast_manager& m) override { - return alloc(bv_sls_tactic, m, m_params); + tactic * translate(ast_manager & m) override { + return alloc(sls_tactic, m, m_params); } - ~bv_sls_tactic() override { - dealloc(m_sls); + ~sls_tactic() override { + dealloc(m_engine); } - char const* name() const override { return "bv-sls"; } + char const* name() const override { return "sls"; } - void updt_params(params_ref const& p) override { + void updt_params(params_ref const & p) override { m_params.append(p); - m_sls->updt_params(m_params); + m_engine->updt_params(m_params); } - void collect_param_descrs(param_descrs& r) override { + void collect_param_descrs(param_descrs & r) override { sls_params::collect_param_descrs(r); } @@ -160,27 +170,24 @@ class bv_sls_tactic : public tactic { if (g->inconsistent()) { mc = nullptr; return; - } - + } + for (unsigned i = 0; i < g->size(); i++) - m_sls->assert_expr(g->form(i)); - - m_sls->init(); - std::function false_eval = [&](expr* e, unsigned idx) { - return false; - }; - m_sls->init_eval(false_eval); - - lbool res = m_sls->operator()(); - m_st.reset(); - m_sls->collect_statistics(m_st); - report_tactic_progress("Number of flips:", m_sls->get_num_moves()); - IF_VERBOSE(10, verbose_stream() << res << "\n"); - IF_VERBOSE(10, m_sls->display(verbose_stream())); + m_engine->assert_expr(g->form(i)); + + lbool res = m_engine->operator()(); + auto const& stats = m_engine->get_stats(); + if (res == l_true) { + report_tactic_progress("Number of flips:", stats.m_moves); + + for (unsigned i = 0; i < g->size(); i++) + if (!m_engine->get_mpz_manager().is_one(m_engine->get_value(g->form(i)))) { + verbose_stream() << "Terminated before all assertions were SAT!" << std::endl; + NOT_IMPLEMENTED_YET(); + } - if (res == l_true) { if (g->models_enabled()) { - model_ref mdl = m_sls->get_model(); + model_ref mdl = m_engine->get_model(); mc = model2model_converter(mdl.get()); TRACE("sls_model", mc->display(tout);); } @@ -188,16 +195,16 @@ class bv_sls_tactic : public tactic { } else mc = nullptr; - + } - - void operator()(goal_ref const& g, - goal_ref_buffer& result) override { + + void operator()(goal_ref const & g, + goal_ref_buffer & result) override { result.reset(); - + TRACE("sls", g->display(tout);); tactic_report report("sls", *g); - + model_converter_ref mc; run(g, mc); g->add(mc.get()); @@ -206,19 +213,17 @@ class bv_sls_tactic : public tactic { } void cleanup() override { - - auto* d = alloc(bv::sls, m, m_params); - std::swap(d, m_sls); + sls_engine * d = alloc(sls_engine, m, m_params); + std::swap(d, m_engine); dealloc(d); } - - void collect_statistics(statistics& st) const override { - st.copy(m_st); + + void collect_statistics(statistics & st) const override { + m_engine->collect_statistics(st); } void reset_statistics() override { - m_sls->reset_statistics(); - m_st.reset(); + m_engine->reset_statistics(); } }; @@ -228,12 +233,6 @@ static tactic * mk_sls_tactic(ast_manager & m, params_ref const & p) { clean(alloc(sls_tactic, m, p))); } -tactic* mk_bv_sls_tactic(ast_manager& m, params_ref const& p) { - return and_then(fail_if_not(mk_is_qfbv_probe()), // Currently only QF_BV is supported. - clean(alloc(bv_sls_tactic, m, p))); -} - - static tactic * mk_preamble(ast_manager & m, params_ref const & p) { params_ref simp2_p = p; @@ -268,10 +267,3 @@ tactic * mk_qfbv_sls_tactic(ast_manager & m, params_ref const & p) { return t; } -tactic* mk_qfbv_new_sls_tactic(ast_manager& m, params_ref const& p) { - params_ref q = p; - q.set_bool("elim_sign_ext", false); - tactic* t = and_then(mk_preamble(m, q), mk_bv_sls_tactic(m, q)); - t->updt_params(q); - return t; -} diff --git a/src/tactic/sls/sls_tactic.h b/src/tactic/sls/sls_tactic.h index d58d310e3c1..b6a42a78322 100644 --- a/src/tactic/sls/sls_tactic.h +++ b/src/tactic/sls/sls_tactic.h @@ -23,17 +23,11 @@ class ast_manager; class tactic; tactic * mk_qfbv_sls_tactic(ast_manager & m, params_ref const & p = params_ref()); - -tactic* mk_qfbv_new_sls_tactic(ast_manager& m, params_ref const& p = params_ref()); - -tactic* mk_bv_sls_tactic(ast_manager& m, params_ref const& p = params_ref()); +tactic * mk_sls_smt_tactic(ast_manager & m, params_ref const & p = params_ref()); /* ADD_TACTIC("qfbv-sls", "(try to) solve using stochastic local search for QF_BV.", "mk_qfbv_sls_tactic(m, p)") - - ADD_TACTIC("qfbv-new-sls", "(try to) solve using stochastic local search for QF_BV.", "mk_qfbv_new_sls_tactic(m, p)") - - ADD_TACTIC("qfbv-new-sls-core", "(try to) solve using stochastic local search for QF_BV.", "mk_bv_sls_tactic(m, p)") + ADD_TACTIC("sls-smt", "(try to) solve SMT formulas using local search.", "mk_sls_smt_tactic(m, p)") */ diff --git a/src/tactic/smtlogics/smt_tactic.cpp b/src/tactic/smtlogics/smt_tactic.cpp index aefe7ccadee..7bae01a8100 100644 --- a/src/tactic/smtlogics/smt_tactic.cpp +++ b/src/tactic/smtlogics/smt_tactic.cpp @@ -15,11 +15,11 @@ Module Name: --*/ -#include "smt/tactic/smt_tactic_core.h" -#include "sat/tactic/sat_tactic.h" -#include "sat/sat_params.hpp" +#include "params/sat_params.hpp" #include "solver/solver2tactic.h" #include "solver/solver.h" +#include "smt/tactic/smt_tactic_core.h" +#include "sat/tactic/sat_tactic.h" tactic * mk_smt_tactic(ast_manager & m, params_ref const & p) { sat_params sp(p); diff --git a/src/test/dlist.cpp b/src/test/dlist.cpp index 9ad04d6b28c..e378ddec721 100644 --- a/src/test/dlist.cpp +++ b/src/test/dlist.cpp @@ -106,6 +106,7 @@ static void test_insert_before() { std::cout << "test_insert_before passed." << std::endl; } +#if 0 // Test the remove_from() method static void test_remove_from() { TestNode* list = nullptr; @@ -119,6 +120,7 @@ static void test_remove_from() { SASSERT(node2.prev() == &node2); std::cout << "test_remove_from passed." << std::endl; } +#endif // Test the push_to_front() method static void test_push_to_front() { @@ -179,6 +181,6 @@ void tst_dlist() { test_detach(); test_invariant(); test_contains(); - (void)test_remove_from; + //test_remove_from; std::cout << "All tests passed." << std::endl; } diff --git a/src/test/sls_test.cpp b/src/test/sls_test.cpp index d99035398e3..fe96a5dd92c 100644 --- a/src/test/sls_test.cpp +++ b/src/test/sls_test.cpp @@ -1,10 +1,36 @@ -#include "ast/sls/bv_sls_eval.h" +#include "ast/sls/sls_bv_eval.h" +#include "ast/sls/sls_bv_terms.h" #include "ast/rewriter/th_rewriter.h" #include "ast/reg_decl_plugins.h" #include "ast/ast_pp.h" namespace bv { + + class my_sat_solver_context : public sls::sat_solver_context { + vector m_clauses; + indexed_uint_set s; + public: + my_sat_solver_context() {} + + vector const& clauses() const override { return m_clauses; } + sat::clause_info const& get_clause(unsigned idx) const override { return m_clauses[idx]; } + ptr_iterator get_use_list(sat::literal lit) override { return ptr_iterator(nullptr, nullptr); } + void flip(sat::bool_var v) override { } + double reward(sat::bool_var v) override { return 0; } + double get_weigth(unsigned clause_idx) override { return 0; } + bool is_true(sat::literal lit) override { return true; } + unsigned num_vars() const override { return 0; } + indexed_uint_set const& unsat() const override { return s; } + void on_model(model_ref& mdl) override {} + sat::bool_var add_var() override { return sat::null_bool_var;} + void add_clause(unsigned n, sat::literal const* lits) override {} + // void collect_statistics(statistics& st) const override {} + // void reset_statistics() override {} + void force_restart() override {} + std::ostream& display(std::ostream& out) override { return out; } + }; + class sls_test { ast_manager& m; bv_util bv; @@ -28,9 +54,14 @@ namespace bv { expr_ref_vector es(m); bv_util bv(m); es.push_back(e); - sls_eval ev(m); - ev.init_eval(es, value); - ev.tighten_range(es); + + my_sat_solver_context solver; + sls::context ctx(m, solver); + sls::bv_terms terms(ctx); + sls::bv_eval ev(terms, ctx); + for (auto e : es) + ev.register_term(e); + ev.init(); th_rewriter rw(m); expr_ref r(e, m); rw(r); @@ -142,9 +173,14 @@ namespace bv { rw(r); es.push_back(m.is_false(r) ? m.mk_not(e1) : e1); es.push_back(m.is_false(r) ? m.mk_not(e2) : e2); - sls_eval ev(m); - ev.init_eval(es, value); - ev.tighten_range(es); + + my_sat_solver_context solver; + sls::context ctx(m, solver); + sls::bv_terms terms(ctx); + sls::bv_eval ev(terms, ctx); + for (auto e : es) + ev.register_term(e); + ev.init(); if (m.is_bool(e1)) { SASSERT(m.is_true(r) || m.is_false(r)); @@ -152,14 +188,14 @@ namespace bv { auto val2 = ev.bval0(e2); if (val != val2) { ev.set(e2, val); - auto rep1 = ev.try_repair(to_app(e2), idx); + auto rep1 = ev.repair_down(to_app(e2), idx); if (!rep1) { verbose_stream() << "Not repaired " << mk_pp(e1, m) << " " << mk_pp(e2, m) << " r: " << r << "\n"; } auto val3 = ev.bval0(e2); if (val3 != val) { verbose_stream() << "Repaired but not corrected " << mk_pp(e2, m) << "\n"; - ev.display(std::cout, es); + ev.display(std::cout); exit(0); } //SASSERT(rep1); @@ -170,7 +206,7 @@ namespace bv { auto& val2 = ev.wval(e2); if (!val1.eq(val2)) { val2.set(val1.bits()); - auto rep2 = ev.try_repair(to_app(e2), idx); + auto rep2 = ev.repair_down(to_app(e2), idx); if (!rep2) { verbose_stream() << "Not repaired " << mk_pp(e2, m) << "\n"; } diff --git a/src/util/checked_int64.h b/src/util/checked_int64.h index 06b957fcf97..31ef5bdd6a9 100644 --- a/src/util/checked_int64.h +++ b/src/util/checked_int64.h @@ -25,23 +25,25 @@ Revision History: #include "util/z3_exception.h" #include "util/rational.h" +#include "util/mpn.h" + + +class overflow_exception : public z3_exception { + char const* msg() const override { return "checked_int64 overflow/underflow"; } +}; template class checked_int64 { int64_t m_value; typedef checked_int64 ci; - rational r64(int64_t i) { return rational(i, rational::i64()); } + rational r64(int64_t i) const { return rational(i, rational::i64()); } public: checked_int64(): m_value(0) {} checked_int64(int64_t v): m_value(v) {} - class overflow_exception : public z3_exception { - char const * msg() const override { return "checked_int64 overflow/underflow";} - }; - bool is_zero() const { return m_value == 0; } bool is_pos() const { return m_value > 0; } bool is_neg() const { return m_value < 0; } @@ -56,6 +58,7 @@ class checked_int64 { static checked_int64 minus_one() { return ci(-1);} int64_t get_int64() const { return m_value; } + rational to_rational() const { return r64(m_value); } checked_int64 abs() const { if (m_value >= 0) { @@ -118,8 +121,10 @@ class checked_int64 { uint64_t x = static_cast(m_value); uint64_t y = static_cast(other.m_value); int64_t r = static_cast(x + y); - if (m_value > 0 && other.m_value > 0 && r <= 0) throw overflow_exception(); - if (m_value < 0 && other.m_value < 0 && r >= 0) throw overflow_exception(); + if (m_value > 0 && other.m_value > 0 && r <= 0) + throw overflow_exception(); + if (m_value < 0 && other.m_value < 0 && r >= 0) + throw overflow_exception(); m_value = r; } else { @@ -133,8 +138,10 @@ class checked_int64 { uint64_t x = static_cast(m_value); uint64_t y = static_cast(other.m_value); int64_t r = static_cast(x - y); - if (m_value > 0 && other.m_value < 0 && r <= 0) throw overflow_exception(); - if (m_value < 0 && other.m_value > 0 && r >= 0) throw overflow_exception(); + if (m_value > 0 && other.m_value < 0 && r <= 0) + throw overflow_exception(); + if (m_value < 0 && other.m_value > 0 && r >= 0) + throw overflow_exception(); m_value = r; } else { @@ -148,13 +155,23 @@ class checked_int64 { if (INT_MIN < m_value && m_value <= INT_MAX && INT_MIN < other.m_value && other.m_value <= INT_MAX) { m_value *= other.m_value; } - // TBD: could be tuned by using known techniques or 128-bit arithmetic. + else if (m_value == 0 || other.m_value == 0 || m_value == 1 || other.m_value == 1) { + m_value *= other.m_value; + } + else if (m_value == INT64_MIN || other.m_value == INT64_MIN) + throw overflow_exception(); else { - rational r(r64(m_value) * r64(other.m_value)); - if (!r.is_int64()) { + uint64_t x = m_value < 0 ? -m_value : m_value; + uint64_t y = other.m_value < 0 ? -other.m_value : other.m_value; + uint64_t r = x * y; + if ((y != 0 && r / y != x) || r > INT64_MAX) throw overflow_exception(); - } - m_value = r.get_int64(); + int64_t old_value = m_value; + m_value = r; + if (old_value < 0 && other.m_value > 0) + m_value = -m_value; + else if (old_value > 0 && other.m_value < 0) + m_value = -m_value; } } else { @@ -163,6 +180,16 @@ class checked_int64 { return *this; } + checked_int64& operator/=(checked_int64 const& other) { + m_value /= other.m_value; + return *this; + } + + checked_int64& operator%=(checked_int64 const& other) { + m_value %= other.m_value; + return *this; + } + friend inline checked_int64 abs(checked_int64 const& i) { return i.abs(); } @@ -174,21 +201,42 @@ inline bool operator!=(checked_int64 const & i1, checked_int64 con return !operator==(i1, i2); } +template +inline bool operator!=(checked_int64 const& i1, int64_t const& i2) { + return !operator==(i1, i2); +} + template inline bool operator>(checked_int64 const & i1, checked_int64 const & i2) { return operator<(i2, i1); } +template +inline bool operator>(checked_int64 const& i1, int64_t i2) { + return operator<(i2, i1); +} + template inline bool operator<=(checked_int64 const & i1, checked_int64 const & i2) { return !operator>(i1, i2); } +template +inline bool operator<=(checked_int64 const& i1, int64_t const& i2) { + return !operator>(i1, i2); +} + template inline bool operator>=(checked_int64 const & i1, checked_int64 const & i2) { return !operator<(i1, i2); } + +template +inline bool operator>=(checked_int64 const& i1, int64_t const& i2) { + return !operator<(i1, i2); +} + template inline checked_int64 operator-(checked_int64 const& i) { checked_int64 result(i); @@ -202,6 +250,14 @@ inline checked_int64 operator+(checked_int64 const& a, checked_int return result; } +template +inline checked_int64 operator+(checked_int64 const& a, int64_t const& b) { + checked_int64 result(a); + checked_int64 _b(b); + result += _b; + return result; +} + template inline checked_int64 operator-(checked_int64 const& a, checked_int64 const& b) { checked_int64 result(a); @@ -209,9 +265,103 @@ inline checked_int64 operator-(checked_int64 const& a, checked_int return result; } +template +inline checked_int64 operator-(checked_int64 const& a, int64_t const& b) { + checked_int64 result(a); + checked_int64 _b(b); + result -= _b; + return result; +} + template inline checked_int64 operator*(checked_int64 const& a, checked_int64 const& b) { checked_int64 result(a); result *= b; return result; } + +template +inline checked_int64 operator*(int64_t const& a, checked_int64 const& b) { + checked_int64 result(a); + result *= b; + return result; +} + +template +inline checked_int64 operator*(checked_int64 const& a, int64_t const& b) { + checked_int64 result(a); + checked_int64 _b(b); + result *= _b; + return result; +} + +template +inline checked_int64 div(checked_int64 const& a, checked_int64 const& b) { + checked_int64 result(a); + result /= b; + return result; +} + +template +inline checked_int64 operator/(checked_int64 const& a, checked_int64 const& b) { + checked_int64 result(a); + result /= b; + return result; +} + +template +inline checked_int64 mod(checked_int64 const& a, checked_int64 const& b) { + checked_int64 result(a); + result %= b; + if (result < 0) { + if (b > 0) + result += b; + else + result -= b; + } + return result; +} + +template +inline bool divides(checked_int64 const& a, checked_int64 const& b) { + return mod(b, a) == 0; +} + +template +inline checked_int64 gcd(checked_int64 const& a, checked_int64 const& b) { + checked_int64 _a = abs(a); + checked_int64 _b = abs(b); + if (_a == 0) + return _b; + while (_b != 0) { + checked_int64 r = mod(_a, _b); + _a = _b; + _b = r; + } + return _a; +} + +// Compute the extended GCD such that ax + by = gcd(a,b) +template +inline checked_int64 gcd(checked_int64 const& a, checked_int64 const& b, + checked_int64& x, checked_int64& y) { + checked_int64 _a = a; + checked_int64 _b = b; + x = 0; + y = 0; + checked_int64 lastx = 1; + checked_int64 lasty = 0; + while (_b != 0) { + checked_int64 q = div(_a, _b); + checked_int64 r = mod(_a, _b); + _a = _b; + _b = r; + checked_int64 temp = x; + x = lastx - q * x; + lastx = temp; + temp = y; + y = lasty - q * y; + lasty = temp; + } + return _a; +} \ No newline at end of file diff --git a/src/util/mpz.cpp b/src/util/mpz.cpp index b1837662caf..7b9719b3668 100644 --- a/src/util/mpz.cpp +++ b/src/util/mpz.cpp @@ -615,6 +615,7 @@ void mpz_manager::div_gcd(mpz const& a, mpz const& b, mpz & c) { template void mpz_manager::div(mpz const & a, mpz const & b, mpz & c) { STRACE("mpz", tout << "[mpz-ext] div(" << to_string(a) << ", " << to_string(b) << ") == ";); + SASSERT(!is_zero(b)); if (is_one(b)) { set(c, a); } diff --git a/src/util/rlimit.cpp b/src/util/rlimit.cpp index ecc527681a1..ea2e68b311b 100644 --- a/src/util/rlimit.cpp +++ b/src/util/rlimit.cpp @@ -92,6 +92,19 @@ void reslimit::pop_child() { m_children.pop_back(); } +void reslimit::pop_child(reslimit* r) { + lock_guard lock(*g_rlimit_mux); + for (unsigned i = 0; i < m_children.size(); ++i) { + if (m_children[i] == r) { + m_count += r->m_count; + r->m_count = 0; + m_children.erase(m_children.begin() + i); + return; + } + } +} + + void reslimit::cancel() { lock_guard lock(*g_rlimit_mux); set_cancel(m_cancel+1); diff --git a/src/util/rlimit.h b/src/util/rlimit.h index 0abb06cb341..10ad90cd144 100644 --- a/src/util/rlimit.h +++ b/src/util/rlimit.h @@ -45,6 +45,7 @@ class reslimit { void pop(); void push_child(reslimit* r); void pop_child(); + void pop_child(reslimit* r); bool inc(); bool inc(unsigned offset); diff --git a/src/util/sat_sls.h b/src/util/sat_sls.h new file mode 100644 index 00000000000..82b84bd03dc --- /dev/null +++ b/src/util/sat_sls.h @@ -0,0 +1,41 @@ +/*++ +Copyright (c) 2011 Microsoft Corporation + +Module Name: + + sat_sls.h + +Abstract: + + Base types for SLS. + +Author: + + Nikolaj Bjorner (nbjorner) 2024-06027 + +--*/ +#pragma once + +#include "util/sat_literal.h" + +namespace sat { + + struct clause_info { + clause_info(unsigned n, literal const* lits, double init_weight): m_weight(init_weight), m_clause(n, lits) {} + double m_weight; // weight of clause + unsigned m_trues = 0; // set of literals that are true + unsigned m_num_trues = 0; // size of true set + literal_vector m_clause; + literal const* begin() const { return m_clause.begin(); } + literal const* end() const { return m_clause.end(); } + bool is_true() const { return m_num_trues > 0; } + void add(literal lit) { ++m_num_trues; m_trues += lit.index(); } + void del(literal lit) { SASSERT(m_num_trues > 0); --m_num_trues; m_trues -= lit.index(); } + }; + + inline std::ostream& operator<<(std::ostream& out, clause_info const& ci) { + return out << ci.m_clause << " w: " << ci.m_weight << " nt: " << ci.m_num_trues; + } +}; + + diff --git a/src/util/util.h b/src/util/util.h index cf0146b12c1..7d1265b3317 100644 --- a/src/util/util.h +++ b/src/util/util.h @@ -326,6 +326,16 @@ void force_ptr_array_size(T & v, unsigned sz) { } } +template +class ptr_iterator { + T const* b; + T const* e; +public: + ptr_iterator(T const* b, T const* e): b(b), e(e) {} + T const* begin() const { return b; } + T const* end() const { return e; } +}; + class random_gen { unsigned m_data; public: diff --git a/src/util/vector.h b/src/util/vector.h index 1f51ac77548..67f6218ed4d 100644 --- a/src/util/vector.h +++ b/src/util/vector.h @@ -723,6 +723,10 @@ class vector { } } + vector(std::initializer_list const& l) { + for (auto const& t : l) + push_back(t); + } ~vector() { destroy();