diff --git a/src/ast/sls/bv_sls.cpp b/src/ast/sls/bv_sls.cpp index df741dac3..9af87d3c5 100644 --- a/src/ast/sls/bv_sls.cpp +++ b/src/ast/sls/bv_sls.cpp @@ -93,6 +93,18 @@ namespace bv { if (m_to_repair.empty()) return; + + // add fresh units, if any + bool new_assertion = false; + while (m_get_unit) { + auto e = m_get_unit(); + if (!e) + break; + new_assertion = true; + assert_expr(e); + } + if (new_assertion) + init(); std::function eval = [&](expr* e, unsigned i) { unsigned id = e->get_id(); @@ -212,9 +224,7 @@ namespace bv { void sls::try_repair_down(app* e) { unsigned n = e->get_num_args(); if (n == 0) { - m_eval.commit_eval(e); - - IF_VERBOSE(3, verbose_stream() << "done\n"); + m_eval.commit_eval(e); for (auto p : m_terms.parents(e)) m_repair_up.insert(p->get_id()); diff --git a/src/ast/sls/bv_sls.h b/src/ast/sls/bv_sls.h index fe31125a7..690b618bf 100644 --- a/src/ast/sls/bv_sls.h +++ b/src/ast/sls/bv_sls.h @@ -53,6 +53,7 @@ namespace bv { sls_engine m_engine; bool m_engine_model = false; bool m_engine_init = false; + std::function m_get_unit; std::pair next_to_repair(); @@ -81,7 +82,6 @@ namespace bv { /* * Invoke init after all expressions are asserted. - * No other expressions can be asserted after init. */ void init(); @@ -91,6 +91,11 @@ namespace bv { */ void init_eval(std::function& eval); + /** + * add callback to retrieve new units + */ + void init_unit(std::function get_unit) { m_get_unit = get_unit; } + /** * Run (bounded) local search to find feasible assignments. */ diff --git a/src/ast/sls/bv_sls_fixed.cpp b/src/ast/sls/bv_sls_fixed.cpp index be587a8bd..9f897a7bd 100644 --- a/src/ast/sls/bv_sls_fixed.cpp +++ b/src/ast/sls/bv_sls_fixed.cpp @@ -58,16 +58,16 @@ namespace bv { expr* t, * s; rational v; if (bv.is_concat(e, t, s)) { - auto& val = wval(s); - if (val.lo() != val.hi() && (val.lo() < val.hi() || val.hi() == 0)) + auto& vals = wval(s); + if (vals.lo() != vals.hi() && (vals.lo() < vals.hi() || vals.hi() == 0)) // lo <= e - add_range(e, val.lo(), rational::zero(), false); + add_range(e, vals.lo(), rational::zero(), false); auto valt = wval(t); -#if 0 - if (val.lo() < val.hi()) - // e < (2^|s|) * hi - add_range(e, rational::zero(), val.hi() * rational::power_of_two(bv.get_bv_size(s)), false); -#endif + if (valt.lo() != valt.hi() && (valt.lo() < valt.hi() || valt.hi() == 0)) { + // (2^|s|) * lo <= e < (2^|s|) * hi + auto p = rational::power_of_two(bv.get_bv_size(s)); + add_range(e, valt.lo() * p, valt.hi() * p, false); + } } else if (bv.is_bv_add(e, s, t) && bv.is_numeral(s, v)) { auto& val = wval(t); diff --git a/src/ast/sls/bv_sls_terms.cpp b/src/ast/sls/bv_sls_terms.cpp index 4624ab85c..ed1bf2396 100644 --- a/src/ast/sls/bv_sls_terms.cpp +++ b/src/ast/sls/bv_sls_terms.cpp @@ -206,6 +206,7 @@ namespace bv { m_todo.push_back(arg); } // populate parents + m_parents.reset(); m_parents.reserve(m_terms.size()); for (expr* e : m_terms) { if (!e || !is_app(e)) @@ -213,6 +214,7 @@ namespace bv { for (expr* arg : *to_app(e)) m_parents[arg->get_id()].push_back(e); } + m_assertion_set.reset(); for (auto a : m_assertions) m_assertion_set.insert(a->get_id()); } diff --git a/src/ast/sls/sls_valuation.cpp b/src/ast/sls/sls_valuation.cpp index c46b2ef77..99b3921f3 100644 --- a/src/ast/sls/sls_valuation.cpp +++ b/src/ast/sls/sls_valuation.cpp @@ -245,6 +245,7 @@ namespace bv { } bool sls_valuation::set_random_at_most(bvect const& src, random_gen& r) { + m_tmp.set_bw(bw); if (!get_at_most(src, m_tmp)) return false; @@ -639,6 +640,14 @@ namespace bv { if (has_range() && !in_range(m_bits)) m_bits = m_lo; + + if (mod(lo() + 1, rational::power_of_two(bw)) == hi()) + for (unsigned i = 0; i < nw; ++i) + fixed[i] = ~0; + if (lo() < hi() && hi() < rational::power_of_two(bw - 1)) + for (unsigned i = 0; i < bw; ++i) + if (hi() < rational::power_of_two(i)) + fixed.set(i, true); SASSERT(well_formed()); } diff --git a/src/sat/sat_solver.cpp b/src/sat/sat_solver.cpp index 2d2962940..5829b18d1 100644 --- a/src/sat/sat_solver.cpp +++ b/src/sat/sat_solver.cpp @@ -1314,7 +1314,7 @@ namespace sat { } bool solver::should_cancel() { - if (limit_reached() || memory_exceeded()) { + if (limit_reached() || memory_exceeded() || m_solver_canceled) { return true; } if (m_config.m_restart_max <= m_restarts) { @@ -1959,6 +1959,7 @@ namespace sat { void solver::init_search() { m_model_is_current = false; + m_solver_canceled = false; m_phase_counter = 0; m_search_state = s_unsat; m_search_unsat_conflicts = m_config.m_search_unsat_conflicts; diff --git a/src/sat/sat_solver.h b/src/sat/sat_solver.h index 0361fc157..57477f686 100644 --- a/src/sat/sat_solver.h +++ b/src/sat/sat_solver.h @@ -177,6 +177,7 @@ namespace sat { clause_wrapper_vector m_clauses_to_reinit; std::string m_reason_unknown; bool m_trim = false; + bool m_solver_canceled = false; visit_helper m_visited; @@ -287,6 +288,7 @@ namespace sat { random_gen& rand() { return m_rand; } void set_trim() { m_trim = true; } + void set_canceled() { m_solver_canceled = true; } protected: void reset_var(bool_var v, bool ext, bool dvar); diff --git a/src/sat/sat_solver/sat_smt_solver.cpp b/src/sat/sat_solver/sat_smt_solver.cpp index ab0e71cc3..19b10eb3e 100644 --- a/src/sat/sat_solver/sat_smt_solver.cpp +++ b/src/sat/sat_solver/sat_smt_solver.cpp @@ -197,10 +197,16 @@ public: case l_false: extract_core(); break; - default: + default: { + auto* ext = get_euf(); + if (ext && ext->get_sls_model()) { + r = l_true; + break; + } set_reason_unknown(m_solver.get_reason_unknown()); break; } + } return r; } @@ -576,6 +582,7 @@ private: void add_assumption(expr* a) { init_goal2sat(); m_dep.insert(a, m_goal2sat.internalize(a)); + get_euf()->add_assertion(a); } void internalize_assumptions(expr_ref_vector const& asms) { @@ -632,6 +639,11 @@ private: void get_model_core(model_ref & mdl) override { TRACE("sat", tout << "retrieve model " << (m_solver.model_is_current()?"present":"absent") << "\n";); mdl = nullptr; + auto ext = get_euf(); + if (ext) + mdl = ext->get_sls_model(); + if (mdl) + return; if (!m_solver.model_is_current()) return; if (m_fmls.size() > m_qhead) diff --git a/src/sat/smt/euf_internalize.cpp b/src/sat/smt/euf_internalize.cpp index f750f186d..ebb6e4b85 100644 --- a/src/sat/smt/euf_internalize.cpp +++ b/src/sat/smt/euf_internalize.cpp @@ -525,4 +525,8 @@ namespace euf { return n; } + void solver::add_assertion(expr* f) { + m_assertions.push_back(f); + m_trail.push(push_back_vector(m_assertions)); + } } diff --git a/src/sat/smt/euf_model.cpp b/src/sat/smt/euf_model.cpp index 2035e16b6..ac7ef1522 100644 --- a/src/sat/smt/euf_model.cpp +++ b/src/sat/smt/euf_model.cpp @@ -18,6 +18,7 @@ Author: #include "ast/ast_pp.h" #include "ast/ast_ll_pp.h" #include "sat/smt/euf_solver.h" +#include "sat/smt/sls_solver.h" #include "model/value_factory.h" namespace euf { @@ -67,6 +68,14 @@ namespace euf { m_qmodel = mdl; } + model_ref solver::get_sls_model() { + model_ref mdl; + auto s = get_solver(m.mk_family_id("sls"), nullptr); + if (s) + mdl = dynamic_cast(s)->get_model(); + return mdl; + } + void solver::update_model(model_ref& mdl, bool validate) { TRACE("model", tout << "create model\n";); if (m_qmodel) { diff --git a/src/sat/smt/euf_solver.cpp b/src/sat/smt/euf_solver.cpp index b108430d8..efd091f66 100644 --- a/src/sat/smt/euf_solver.cpp +++ b/src/sat/smt/euf_solver.cpp @@ -28,6 +28,7 @@ Author: #include "sat/smt/q_solver.h" #include "sat/smt/fpa_solver.h" #include "sat/smt/dt_solver.h" +#include "sat/smt/sls_solver.h" #include "sat/smt/recfun_solver.h" #include "sat/smt/specrel_solver.h" @@ -54,6 +55,7 @@ namespace euf { m_smt_proof_checker(m, p), m_clause(m), m_expr_args(m), + m_assertions(m), m_values(m) { updt_params(p); @@ -77,6 +79,7 @@ namespace euf { }; m_egraph.set_on_merge(on_merge); } + } void solver::updt_params(params_ref const& p) { @@ -185,7 +188,9 @@ namespace euf { IF_VERBOSE(0, verbose_stream() << mk_pp(f, m) << " not handled\n"); } - void solver::init_search() { + void solver::init_search() { + if (get_config().m_sls_enable) + add_solver(alloc(sls::solver, *this)); TRACE("before_search", s().display(tout);); m_reason_unknown.clear(); for (auto* s : m_solvers) @@ -217,7 +222,7 @@ namespace euf { mark_relevant(lit); s().assign(lit, sat::justification::mk_ext_justification(s().scope_lvl(), idx)); } - + lbool solver::resolve_conflict() { for (auto* s : m_solvers) { lbool r = s->resolve_conflict(); @@ -664,7 +669,9 @@ namespace euf { if (give_up) return sat::check_result::CR_GIVEUP; if (m_qsolver && m_config.m_arith_ignore_int) - return sat::check_result::CR_GIVEUP; + return sat::check_result::CR_GIVEUP; + for (auto s : m_solvers) + s->finalize(); return sat::check_result::CR_DONE; } diff --git a/src/sat/smt/euf_solver.h b/src/sat/smt/euf_solver.h index 7d2d01473..ec89667d5 100644 --- a/src/sat/smt/euf_solver.h +++ b/src/sat/smt/euf_solver.h @@ -154,6 +154,7 @@ namespace euf { svector m_scopes; scoped_ptr_vector m_solvers; ptr_vector m_id2solver; + constraint* m_conflict = nullptr; constraint* m_eq = nullptr; @@ -173,6 +174,7 @@ namespace euf { symbol m_smt = symbol("smt"); expr_ref_vector m_clause; expr_ref_vector m_expr_args; + expr_ref_vector m_assertions; // internalization @@ -482,6 +484,10 @@ namespace euf { bool enable_ackerman_axioms(expr* n) const; bool is_fixed(euf::enode* n, expr_ref& val, sat::literal_vector& explain); + void add_assertion(expr* f); + expr_ref_vector const& get_assertions() { return m_assertions; } + model_ref get_sls_model(); + // relevancy bool relevancy_enabled() const { return m_relevancy.enabled(); } diff --git a/src/sat/smt/sat_th.h b/src/sat/smt/sat_th.h index e226566b8..cdd1292a3 100644 --- a/src/sat/smt/sat_th.h +++ b/src/sat/smt/sat_th.h @@ -148,6 +148,8 @@ namespace euf { virtual void set_bounds(enode* n) {} + virtual void finalize() {} + }; class th_proof_hint : public sat::proof_hint { @@ -225,6 +227,7 @@ namespace euf { void push() override { m_num_scopes++; } void pop(unsigned n) override; + unsigned random(); }; diff --git a/src/sat/smt/sls_solver.cpp b/src/sat/smt/sls_solver.cpp index ae8620e28..e12ff5ba7 100644 --- a/src/sat/smt/sls_solver.cpp +++ b/src/sat/smt/sls_solver.cpp @@ -23,38 +23,79 @@ Author: namespace sls { solver::solver(euf::solver& ctx): - th_euf_solver(ctx, symbol("sls"), ctx.get_manager().mk_family_id("sls")) {} + th_euf_solver(ctx, symbol("sls"), ctx.get_manager().mk_family_id("sls")), + m_units(m) {} solver::~solver() { - if (m_bvsls) { - m_bvsls->cancel(); + finalize(); + } + + void solver::finalize() { + if (!m_completed && m_bvsls) { + m_bvsls->cancel(); m_thread.join(); + m_bvsls->collect_statistics(m_st); + m_bvsls = nullptr; } } - void solver::push_core() { - if (s().scope_lvl() == s().search_lvl() + 1) - init_local_search(); - } - - void solver::pop_core(unsigned n) { - if (s().scope_lvl() - n <= s().search_lvl()) - sample_local_search(); + sat::check_result solver::check() { + + return sat::check_result::CR_DONE; } - void solver::simplify() { - + void solver::simplify() { + } + + bool solver::unit_propagate() { + force_push(); + sample_local_search(); + return false; + } + + bool solver::is_unit(expr* e) { + if (!e) + return false; + m.is_not(e, e); + if (is_uninterp_const(e)) + return true; + bv_util bu(m); + expr* s; + if (bu.is_bit2bool(e, s)) + return is_uninterp_const(s); + return false; + } + + void solver::push_core() { + + } + + void solver::pop_core(unsigned n) { + for (; m_trail_lim < s().init_trail_size(); ++m_trail_lim) { + auto lit = s().trail_literal(m_trail_lim); + auto e = ctx.literal2expr(lit); + if (is_unit(e)) { + // IF_VERBOSE(1, verbose_stream() << "add unit " << mk_pp(e, m) << "\n"); + std::lock_guard lock(m_mutex); + m_units.push_back(e); + m_has_units = true; + } + } } + void solver::init_search() { + init_local_search(); + } void solver::init_local_search() { if (m_bvsls) { m_bvsls->cancel(); m_thread.join(); - if (m_result == l_true) { - verbose_stream() << "Found model using local search - INIT\n"; - exit(1); - } + m_result = l_undef; + m_completed = false; + m_has_units = false; + m_model = nullptr; + m_units.reset(); } // set up state for local search solver here @@ -64,43 +105,12 @@ namespace sls { params_ref p; m_completed = false; m_result = l_undef; + m_model = nullptr; m_bvsls = alloc(bv::sls, *m_m, p); - // walk clauses, add them - // walk trail stack until search level, add units - // encapsulate bvsls within the arguments of run-local-search. - // ensure bvsls does not touch ast-manager. + + for (expr* a : ctx.get_assertions()) + m_bvsls->assert_expr(tr(a)); - unsigned trail_sz = s().trail_size(); - for (unsigned i = 0; i < trail_sz; ++i) { - auto lit = s().trail_literal(i); - if (s().lvl(lit) > s().search_lvl()) - break; - expr_ref fml = literal2expr(lit); - m_bvsls->assert_expr(tr(fml.get())); - } - unsigned num_vars = s().num_vars(); - for (unsigned i = 0; i < 2*num_vars; ++i) { - auto l1 = ~sat::to_literal(i); - auto const& wlist = s().get_wlist(l1); - for (sat::watched const& w : wlist) { - if (!w.is_binary_non_learned_clause()) - continue; - sat::literal l2 = w.get_literal(); - if (l1.index() > l2.index()) - continue; - expr_ref fml(m.mk_or(literal2expr(l1), literal2expr(l2)), m); - m_bvsls->assert_expr(tr(fml.get())); - } - } - for (auto clause : s().clauses()) { - expr_ref_vector cls(m); - for (auto lit : *clause) - cls.push_back(literal2expr(lit)); - expr_ref fml(m.mk_or(cls), m); - m_bvsls->assert_expr(tr(fml.get())); - } - - // use phase assignment from literals? std::function eval = [&](expr* e, unsigned r) { return false; }; @@ -108,23 +118,42 @@ namespace sls { m_bvsls->init(); m_bvsls->init_eval(eval); m_bvsls->updt_params(s().params()); - + m_bvsls->init_unit([&]() { + if (!m_has_units) + return expr_ref(*m_m); + expr_ref e(m); + { + std::lock_guard lock(m_mutex); + if (m_units.empty()) + return expr_ref(*m_m); + e = m_units.back(); + m_units.pop_back(); + } + ast_translation tr(m, *m_m); + return expr_ref(tr(e.get()), *m_m); + }); + m_thread = std::thread([this]() { run_local_search(); }); } void solver::sample_local_search() { - if (m_completed) { - m_thread.join(); - if (m_result == l_true) { - verbose_stream() << "Found model using local search\n"; - exit(1); - } + if (!m_completed) + return; + m_thread.join(); + m_completed = false; + m_bvsls->collect_statistics(m_st); + if (m_result == l_true) { + IF_VERBOSE(2, verbose_stream() << "(sat.sls :model-completed)\n";); + auto mdl = m_bvsls->get_model(); + ast_translation tr(*m_m, m); + m_model = mdl->translate(tr); + s().set_canceled(); } + m_bvsls = nullptr; } void solver::run_local_search() { - lbool r = (*m_bvsls)(); - m_result = r; + m_result = (*m_bvsls)(); m_completed = true; } diff --git a/src/sat/smt/sls_solver.h b/src/sat/smt/sls_solver.h index c473264ac..55e98fac7 100644 --- a/src/sat/smt/sls_solver.h +++ b/src/sat/smt/sls_solver.h @@ -30,30 +30,43 @@ namespace sls { class solver : public euf::th_euf_solver { std::atomic m_result; - std::atomic m_completed; + std::atomic m_completed, m_has_units; std::thread m_thread; + std::mutex m_mutex; scoped_ptr m_m; scoped_ptr m_bvsls; + model_ref m_model; + unsigned m_trail_lim = 0; + expr_ref_vector m_units; + statistics m_st; void run_local_search(); void init_local_search(); void sample_local_search(); + + bool is_unit(expr*); public: solver(euf::solver& ctx); ~solver(); + void simplify() override; + void init_search() override; + void push_core() override; void pop_core(unsigned n) override; - void simplify() override; sat::literal internalize(expr* e, bool sign, bool root) override { UNREACHABLE(); return sat::null_literal; } void internalize(expr* e) override { UNREACHABLE(); } th_solver* clone(euf::solver& ctx) override { return alloc(solver, ctx); } - + void collect_statistics(statistics& st) const override { st.copy(m_st); } - bool unit_propagate() override { return false; } + model_ref get_model() { return m_model; } + + void finalize() override; + + bool unit_propagate() override; void get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector & r, bool probing) override { UNREACHABLE(); } - sat::check_result check() override { return sat::check_result::CR_DONE; } + sat::check_result check() override; std::ostream & display(std::ostream & out) const override { return out; } std::ostream & display_justification(std::ostream & out, sat::ext_justification_idx idx) const override { UNREACHABLE(); return out; } std::ostream & display_constraint(std::ostream & out, sat::ext_constraint_idx idx) const override { UNREACHABLE(); return out; } diff --git a/src/sat/tactic/goal2sat.cpp b/src/sat/tactic/goal2sat.cpp index 267888804..57e3a89b5 100644 --- a/src/sat/tactic/goal2sat.cpp +++ b/src/sat/tactic/goal2sat.cpp @@ -895,6 +895,7 @@ struct goal2sat::imp : public sat::sat_internalizer { process(n, true); CTRACE("goal2sat", !m_result_stack.empty(), tout << m_result_stack << "\n";); SASSERT(m_result_stack.empty()); + add_assertion(n); } void insert_dep(expr* dep0, expr* dep, bool sign) { @@ -989,6 +990,12 @@ struct goal2sat::imp : public sat::sat_internalizer { } } + void add_assertion(expr* f) { + auto* ext = dynamic_cast(m_solver.get_extension()); + if (ext) + ext->add_assertion(f); + } + void update_model(model_ref& mdl) { auto* ext = dynamic_cast(m_solver.get_extension()); if (ext) diff --git a/src/smt/params/smt_params.cpp b/src/smt/params/smt_params.cpp index ef617f724..02919b287 100644 --- a/src/smt/params/smt_params.cpp +++ b/src/smt/params/smt_params.cpp @@ -49,6 +49,7 @@ void smt_params::updt_local_params(params_ref const & _p) { m_threads_max_conflicts = p.threads_max_conflicts(); m_threads_cube_frequency = p.threads_cube_frequency(); m_core_validate = p.core_validate(); + m_sls_enable = p.sls_enable(); m_logic = _p.get_sym("logic", m_logic); m_string_solver = p.string_solver(); m_up_persist_clauses = p.up_persist_clauses(); @@ -66,6 +67,7 @@ void smt_params::updt_local_params(params_ref const & _p) { m_lemmas2console = sp.lemmas2console(); m_instantiations2console = sp.instantiations2console(); m_proof_log = sp.proof_log(); + } void smt_params::updt_params(params_ref const & p) { diff --git a/src/smt/params/smt_params.h b/src/smt/params/smt_params.h index c678b7536..0ef063e4a 100644 --- a/src/smt/params/smt_params.h +++ b/src/smt/params/smt_params.h @@ -114,6 +114,7 @@ struct smt_params : public preprocessor_params, bool m_induction = false; bool m_clause_proof = false; symbol m_proof_log; + bool m_sls_enable = false; // ----------------------------------- // diff --git a/src/smt/params/smt_params_helper.pyg b/src/smt/params/smt_params_helper.pyg index 708fa87d8..4e498b2c4 100644 --- a/src/smt/params/smt_params_helper.pyg +++ b/src/smt/params/smt_params_helper.pyg @@ -135,6 +135,7 @@ def_module_params(module_name='smt', ('str.regex_automata_length_attempt_threshold', UINT, 10, 'number of length/path constraint attempts before checking unsatisfiability of regex terms'), ('str.fixed_length_refinement', BOOL, False, 'use abstraction refinement in fixed-length equation solver (Z3str3 only)'), ('str.fixed_length_naive_cex', BOOL, True, 'construct naive counterexamples when fixed-length model construction fails for a given length assignment (Z3str3 only)'), + ('sls.enable', BOOL, False, 'enable sls co-processor with SMT engine'), ('core.minimize', BOOL, False, 'minimize unsat core produced by SMT context'), ('core.extend_patterns', BOOL, False, 'extend unsat core with literals that trigger (potential) quantifier instances'), ('core.extend_patterns.max_distance', UINT, UINT_MAX, 'limits the distance of a pattern-extended unsat core'),