From 5e30323b1ac7626634e2d1f167296d9f828ed676 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sat, 11 Feb 2023 15:46:39 -0800 Subject: [PATCH] wip - bounded local search for arithmetic --- src/sat/sat_ddfw.cpp | 10 ++-- src/sat/sat_extension.h | 2 +- src/sat/sat_solver.cpp | 9 +++ src/sat/smt/arith_sls.cpp | 95 ++++++++++++++++++++------------ src/sat/smt/arith_sls.h | 33 +++++++++-- src/sat/smt/arith_solver.h | 2 +- src/sat/smt/euf_local_search.cpp | 35 ++++++------ src/sat/smt/euf_solver.h | 4 +- 8 files changed, 124 insertions(+), 66 deletions(-) diff --git a/src/sat/sat_ddfw.cpp b/src/sat/sat_ddfw.cpp index 98e3ce2bd..747ea4940 100644 --- a/src/sat/sat_ddfw.cpp +++ b/src/sat/sat_ddfw.cpp @@ -148,7 +148,8 @@ namespace sat { m_use_list[lit.index()].pop_back(); m_alloc.del_clause(info.m_clause); m_clauses.pop_back(); - m_unsat.remove(m_clauses.size()); + if (m_unsat.contains(m_clauses.size())) + m_unsat.remove(m_clauses.size()); } void ddfw::add(solver const& s) { @@ -188,12 +189,11 @@ namespace sat { } void ddfw::remove_assumptions() { + if (m_assumptions.empty()) + return; for (unsigned i = 0; i < m_assumptions.size(); ++i) del(); - m_unsat_vars.reset(); - for (auto idx : m_unsat) - for (auto lit : get_clause(idx)) - m_unsat_vars.insert(lit.var()); + init(0, nullptr); } void ddfw::init(unsigned sz, literal const* assumptions) { diff --git a/src/sat/sat_extension.h b/src/sat/sat_extension.h index 3a1f363a3..ae99cae12 100644 --- a/src/sat/sat_extension.h +++ b/src/sat/sat_extension.h @@ -126,7 +126,7 @@ namespace sat { virtual void add_assumptions(literal_set& ext_assumptions) {} virtual bool tracking_assumptions() { return false; } virtual bool enable_self_propagate() const { return false; } - virtual void local_search(bool_vector& phase) {} + virtual lbool local_search(bool_vector& phase) { return l_undef; } virtual bool extract_pb(std::function& card, std::function& pb) { diff --git a/src/sat/sat_solver.cpp b/src/sat/sat_solver.cpp index 10aac6dcb..898c3a2a4 100644 --- a/src/sat/sat_solver.cpp +++ b/src/sat/sat_solver.cpp @@ -1302,6 +1302,9 @@ namespace sat { return l_undef; } + // uncomment this to test bounded local search: + // bounded_local_search(); + log_stats(); if (m_config.m_max_conflicts > 0 && m_config.m_burst_search > 0) { m_restart_threshold = m_config.m_burst_search; @@ -1360,6 +1363,12 @@ namespace sat { }; void solver::bounded_local_search() { + if (m_ext) { + verbose_stream() << "bounded local search\n"; + do_restart(true); + m_ext->local_search(m_best_phase); + return; + } literal_vector _lits; scoped_limits scoped_rl(rlimit()); m_local_search = alloc(ddfw); diff --git a/src/sat/smt/arith_sls.cpp b/src/sat/smt/arith_sls.cpp index 7e0997156..79e6aec54 100644 --- a/src/sat/smt/arith_sls.cpp +++ b/src/sat/smt/arith_sls.cpp @@ -20,21 +20,24 @@ Author: namespace arith { - - /// - /// need to initialize ineqs (arithmetical atoms) - /// - - sls::sls(solver& s): s(s), m(s.m) {} - void sls::operator()(bool_vector& phase) { + void sls::reset() { + m_literals.reset(); + m_vars.reset(); + m_clauses.reset(); + m_terms.reset(); + } + + lbool sls::operator()(bool_vector& phase) { unsigned num_steps = 0; for (unsigned v = 0; v < s.s().num_vars(); ++v) init_bool_var_assignment(v); m_best_min_unsat = unsat().size(); + verbose_stream() << "max arith steps " << m_max_arith_steps << "\n"; + //m_max_arith_steps = 10000; while (m.inc() && m_best_min_unsat > 0 && num_steps < m_max_arith_steps) { if (!flip()) break; @@ -47,24 +50,27 @@ namespace arith { save_best_values(); } } - IF_VERBOSE(2, verbose_stream() << "(sls " << m_stats.m_num_flips << " " << unsat().size() << ")\n"); + log(); + return unsat().empty() ? l_true : l_undef; + } + + void sls::log() { + IF_VERBOSE(2, verbose_stream() << "(sls :flips " << m_stats.m_num_flips << " :unsat " << unsat().size() << ")\n"); } void sls::save_best_values() { // first compute assignment to terms - // then update non-basic variables in tableau, assuming a sat solution was found. -#if false - for (auto const& [t, v] : terms) { + // then update non-basic variables in tableau. + for (auto const& [t, v] : m_terms) { rational val; - lp::lar_term const& term = lp().get_term(t); + lp::lar_term const& term = s.lp().get_term(t); for (lp::lar_term::ival arg : term) { - auto t2 = lp().column2tv(arg.column()); - auto w = lp().local_to_external(t2.id()); - val += arg.coeff() * local_search.value(w); + auto t2 = s.lp().column2tv(arg.column()); + auto w = s.lp().local_to_external(t2.id()); + val += arg.coeff() * value(w); } update(v, val); } -#endif for (unsigned v = 0; v < s.get_num_vars(); ++v) { if (s.is_bool(v)) @@ -87,6 +93,8 @@ namespace arith { void sls::set(sat::ddfw* d) { m_bool_search = d; + reset(); + m_literals.reserve(s.s().num_vars() * 2); add_vars(); m_clauses.resize(d->num_clauses()); for (unsigned i = 0; i < d->num_clauses(); ++i) @@ -151,12 +159,16 @@ namespace arith { bool sls::cm(ineq const& ineq, var_t v, rational& new_value) { SASSERT(!ineq.is_true()); auto delta = ineq.m_args_value - ineq.m_bound; + if (ineq.m_op == ineq_kind::NE || ineq.m_op == ineq_kind::LT) + delta--; for (auto const& [coeff, w] : ineq.m_args) { if (w == v) { + if (coeff > 0) new_value = value(v) - abs(ceil(delta / coeff)); else new_value = value(v) + abs(floor(delta / coeff)); + switch (ineq.m_op) { case ineq_kind::LE: SASSERT(delta + coeff * (new_value - value(v)) <= 0); @@ -189,9 +201,12 @@ namespace arith { auto const& clause = get_clause(cl); rational new_value; for (literal lit : clause) { - auto const* ineq = atom(lit); - if (!ineq || ineq->is_true()) + if (is_true(lit)) continue; + auto const* ineq = atom(lit); + if (!ineq) + continue; + SASSERT(!ineq->is_true()); for (auto const& [coeff, v] : ineq->m_args) { if (!cm(*ineq, v, new_value)) continue; @@ -201,8 +216,9 @@ namespace arith { unsigned num_unsat = unsat().size(); update(v, new_value); IF_VERBOSE(2, - verbose_stream() << "score " << v << " " << score << "\n" + verbose_stream() << "v" << v << " score " << score << " " << num_unsat << " -> " << unsat().size() << "\n"); + SASSERT(num_unsat > unsat().size()); return true; } } @@ -255,7 +271,8 @@ namespace arith { } /** - * redistribute weights of clauses. TODO - re-use ddfw weights instead. + * redistribute weights of clauses. + * TODO - re-use ddfw weights instead. */ void sls::paws() { for (unsigned cl = num_clauses(); cl-- > 0; ) { @@ -270,13 +287,15 @@ namespace arith { // // dscore(op) = sum_c (dts(c,alpha) - dts(c,alpha_after)) * weight(c) + // TODO - use cached dts instead of computed dts + // cached dts has to be updated when the score of literals are updated. // rational sls::dscore(var_t v, rational const& new_value) const { auto const& vi = m_vars[v]; rational score(0); for (auto const& [coeff, lit] : vi.m_literals) for (auto cl : m_bool_search->get_use_list(lit)) - score += (dts(cl) - dts(cl, v, new_value)) * rational(get_weight(cl)); + score += (compute_dts(cl) - dts(cl, v, new_value)) * rational(get_weight(cl)); return score; } @@ -290,10 +309,11 @@ namespace arith { for (auto cl : m_bool_search->get_use_list(lit)) { auto const& clause = get_clause_info(cl); if (!clause.is_true()) { + VERIFY(dtt_old != 0); if (dtt_new == 0) ++score; // false -> true } - else if (dtt_new == 0 || dtt_old > 0 || clause.m_num_trues > 0) // true -> true ?? TODO + else if (dtt_new == 0 || dtt_old > 0 || clause.m_num_trues > 1) // true -> true not really, same variable can be in multiple literals continue; else if (all_of(*clause.m_clause, [&](auto lit) { return !atom(lit) || dtt(*atom(lit), v, new_value) > 0; })) // ?? TODO --score; @@ -302,7 +322,7 @@ namespace arith { return score; } - rational sls::dts(unsigned cl) const { + rational sls::compute_dts(unsigned cl) const { rational d(1), d2; bool first = true; for (auto a : get_clause(cl)) { @@ -346,14 +366,20 @@ namespace arith { rational dtt_old = dtt(ineq); ineq.m_args_value += coeff * (new_value - old_value); rational dtt_new = dtt(ineq); - SASSERT(!(dtt_new == 0 && dtt_new < dtt_old) || m_bool_search->get_value(lit.var()) == lit.sign()); - SASSERT(!(dtt_old == 0 && dtt_new > dtt_old) || m_bool_search->get_value(lit.var()) != lit.sign()); + if ((dtt_new == 0) == is_true(lit)) { + dtt(ineq) = dtt_new; + continue; + } + VERIFY((dtt_old == 0) == is_true(lit)); + VERIFY(!(dtt_new == 0 && dtt_new < dtt_old) || !is_true(lit)); + VERIFY(!(dtt_old == 0 && dtt_new > dtt_old) || is_true(lit)); if (dtt_new == 0 && dtt_new < dtt_old) // flip from false to true m_bool_search->flip(lit.var()); else if (dtt_old == 0 && dtt_old < dtt_new) // flip from true to false m_bool_search->flip(lit.var()); dtt(ineq) = dtt_new; - SASSERT((dtt_new == 0) == (m_bool_search->get_value(lit.var()) != lit.sign())); + + VERIFY((dtt_new == 0) == is_true(lit)); } vi.m_value = new_value; } @@ -422,18 +448,18 @@ namespace arith { } - void sls::add_args(ineq& ineq, lp::tv t, theory_var v, rational sign) { + void sls::add_args(sat::literal lit, ineq& ineq, lp::tv t, theory_var v, rational sign) { if (t.is_term()) { lp::lar_term const& term = s.lp().get_term(t); for (lp::lar_term::ival arg : term) { auto t2 = s.lp().column2tv(arg.column()); auto w = s.lp().local_to_external(t2.id()); - ineq.m_args.push_back({ sign * arg.coeff(), w }); + add_arg(lit, ineq, sign * arg.coeff(), w); } } else - ineq.m_args.push_back({ sign, s.lp().local_to_external(t.id()) }); + add_arg(lit, ineq, sign, s.lp().local_to_external(t.id())); } @@ -465,7 +491,7 @@ namespace arith { bound.neg(); auto& ineq = new_ineq(op, bound); - add_args(ineq, t, b->get_var(), should_minus ? rational::minus_one() :rational::one()); + add_args(lit, ineq, t, b->get_var(), should_minus ? rational::minus_one() :rational::one()); m_literals.set(lit.index(), &ineq); return; } @@ -478,8 +504,8 @@ namespace arith { lp::tv tu = s.get_tv(u); lp::tv tv = s.get_tv(v); auto& ineq = new_ineq(lit.sign() ? sls::ineq_kind::NE : sls::ineq_kind::EQ, rational::zero()); - add_args(ineq, tu, u, rational::one()); - add_args(ineq, tv, v, -rational::one()); + add_args(lit, ineq, tu, u, rational::one()); + add_args(lit, ineq, tv, v, -rational::one()); m_literals.set(lit.index(), &ineq); return; } @@ -492,8 +518,9 @@ namespace arith { void sls::init_literal_assignment(sat::literal lit) { auto* ineq = m_literals.get(lit.index(), nullptr); - if (ineq && m_bool_search->get_value(lit.var()) != (dtt(*ineq) == 0)) - m_bool_search->flip(lit.var()); + + if (ineq && is_true(lit) != (dtt(*ineq) == 0)) + m_bool_search->flip(lit.var()); } } diff --git a/src/sat/smt/arith_sls.h b/src/sat/smt/arith_sls.h index 16a8549f9..9a4cfcd81 100644 --- a/src/sat/smt/arith_sls.h +++ b/src/sat/smt/arith_sls.h @@ -55,6 +55,7 @@ namespace arith { unsigned m_num_flips = 0; }; + public: // encode args <= bound, args = bound, args < bound struct ineq { vector> m_args; @@ -74,7 +75,23 @@ namespace arith { return m_args_value < m_bound; } } + std::ostream& display(std::ostream& out) const { + bool first = true; + for (auto const& [c, v] : m_args) + out << (first? "": " + ") << c << " * v" << v, first = false; + switch (m_op) { + case ineq_kind::LE: + return out << " <= " << m_bound << "(" << m_args_value << ")"; + case ineq_kind::EQ: + return out << " == " << m_bound << "(" << m_args_value << ")"; + case ineq_kind::NE: + return out << " != " << m_bound << "(" << m_args_value << ")"; + default: + return out << " < " << m_bound << "(" << m_args_value << ")"; + } + } }; + private: struct var_info { rational m_value; @@ -85,6 +102,7 @@ namespace arith { struct clause { unsigned m_weight = 1; + rational m_dts = rational::one(); }; solver& s; @@ -97,6 +115,8 @@ namespace arith { scoped_ptr_vector m_literals; vector m_vars; vector m_clauses; + svector> m_terms; + indexed_uint_set& unsat() { return m_bool_search->unsat_set(); } unsigned num_clauses() const { return m_bool_search->num_clauses(); } @@ -104,12 +124,14 @@ namespace arith { sat::clause const& get_clause(unsigned idx) const { return *get_clause_info(idx).m_clause; } sat::ddfw::clause_info& get_clause_info(unsigned idx) { return m_bool_search->get_clause_info(idx); } sat::ddfw::clause_info const& get_clause_info(unsigned idx) const { return m_bool_search->get_clause_info(idx); } + bool is_true(sat::literal lit) { return lit.sign() != m_bool_search->get_value(lit.var()); } + void reset(); ineq* atom(sat::literal lit) const { return m_literals[lit.index()]; } unsigned& get_weight(unsigned idx) { return m_clauses[idx].m_weight; } unsigned get_weight(unsigned idx) const { return m_clauses[idx].m_weight; } bool flip(); - void log() {} + void log(); bool flip_unsat(); bool flip_clauses(); bool flip_dscore(); @@ -119,7 +141,7 @@ namespace arith { rational dtt(rational const& args, ineq const& ineq) const; rational dtt(ineq const& ineq, var_t v, rational const& new_value) const; rational dts(unsigned cl, var_t v, rational const& new_value) const; - rational dts(unsigned cl) const; + rational compute_dts(unsigned cl) const; bool cm(ineq const& ineq, var_t v, rational& new_value); int cm_score(var_t v, rational const& new_value); void update(var_t v, rational const& new_value); @@ -130,7 +152,7 @@ namespace arith { sls::ineq& new_ineq(ineq_kind op, rational const& bound); void add_arg(sat::literal lit, ineq& ineq, rational const& c, var_t v); void add_bounds(sat::literal_vector& bounds); - void add_args(ineq& ineq, lp::tv t, euf::theory_var v, rational sign); + void add_args(sat::literal lit, ineq& ineq, lp::tv t, euf::theory_var v, rational sign); void init_literal(sat::literal lit); void init_bool_var_assignment(sat::bool_var v); void init_literal_assignment(sat::literal lit); @@ -138,11 +160,14 @@ namespace arith { rational value(var_t v) const { return m_vars[v].m_value; } public: sls(solver& s); - void operator ()(bool_vector& phase); + lbool operator ()(bool_vector& phase); void set_bounds_begin(); void set_bounds_end(unsigned num_literals); void set_bounds(euf::enode* n); void set(sat::ddfw* d); }; + inline std::ostream& operator<<(std::ostream& out, sls::ineq const& ineq) { + return ineq.display(out); + } } diff --git a/src/sat/smt/arith_solver.h b/src/sat/smt/arith_solver.h index 732e291b1..a31ca844a 100644 --- a/src/sat/smt/arith_solver.h +++ b/src/sat/smt/arith_solver.h @@ -515,7 +515,7 @@ namespace arith { void set_bounds_begin() override { m_local_search.set_bounds_begin(); } void set_bounds_end(unsigned num_literals) override { m_local_search.set_bounds_end(num_literals); } void set_bounds(enode* n) override { m_local_search.set_bounds(n); } - void local_search(bool_vector& phase) override { m_local_search(phase); } + lbool local_search(bool_vector& phase) override { return m_local_search(phase); } void set_bool_search(sat::ddfw* ddfw) override { m_local_search.set(ddfw); } // bounds and equality propagation callbacks diff --git a/src/sat/smt/euf_local_search.cpp b/src/sat/smt/euf_local_search.cpp index 873a64b7e..fab3b7815 100644 --- a/src/sat/smt/euf_local_search.cpp +++ b/src/sat/smt/euf_local_search.cpp @@ -21,7 +21,7 @@ Author: namespace euf { - void solver::local_search(bool_vector& phase) { + lbool solver::local_search(bool_vector& phase) { scoped_limits scoped_rl(m.limit()); sat::ddfw bool_search; bool_search.reinit(s(), phase); @@ -36,7 +36,7 @@ namespace euf { for (unsigned rounds = 0; m.inc() && rounds < max_rounds; ++rounds) { - setup_bounds(phase); + setup_bounds(bool_search, phase); // Non-boolean literals are assumptions to Boolean search literal_vector assumptions; @@ -44,6 +44,8 @@ namespace euf { if (!is_propositional(literal(v))) assumptions.push_back(literal(v, !bool_search.get_value(v))); + verbose_stream() << "assumptions " << assumptions.size() << "\n"; + bool_search.rlimit().push(m_max_bool_steps); lbool r = bool_search.check(assumptions.size(), assumptions.data(), nullptr); @@ -51,15 +53,15 @@ namespace euf { for (auto* th : m_solvers) th->local_search(phase); - // if is_sat break; + if (bool_search.unsat_set().empty()) + break; } - - auto const& mdl = bool_search.get_model(); for (unsigned i = 0; i < mdl.size(); ++i) - phase[i] = mdl[i] == l_true; - + phase[i] = mdl[i] == l_true; + + return bool_search.unsat_set().empty() ? l_true : l_undef; } bool solver::is_propositional(sat::literal lit) { @@ -67,13 +69,13 @@ namespace euf { return !e || is_uninterp_const(e) || !m_egraph.find(e); } - void solver::setup_bounds(bool_vector const& phase) { + void solver::setup_bounds(sat::ddfw& bool_search, bool_vector const& phase) { unsigned num_literals = 0; unsigned num_bool = 0; for (auto* th : m_solvers) th->set_bounds_begin(); - auto init_literal = [&](sat::literal l) { + auto count_literal = [&](sat::literal l) { if (is_propositional(l)) { ++num_bool; return; @@ -86,16 +88,11 @@ namespace euf { } }; - auto is_true = [&](auto lit) { - return phase[lit.var()] == !lit.sign(); - }; - - for (auto* cp : s().clauses()) { - if (any_of(*cp, [&](auto lit) { return is_true(lit); })) - continue; - num_literals += cp->size(); - for (auto l : *cp) - init_literal(l); + for (auto cl : bool_search.unsat_set()) { + auto& c = *bool_search.get_clause_info(cl).m_clause; + num_literals += c.size(); + for (auto l : c) + count_literal(l); } m_max_bool_steps = (m_ls_config.L * num_bool) / num_literals; diff --git a/src/sat/smt/euf_solver.h b/src/sat/smt/euf_solver.h index a19dbca5d..d62390329 100644 --- a/src/sat/smt/euf_solver.h +++ b/src/sat/smt/euf_solver.h @@ -265,7 +265,7 @@ namespace euf { // local search unsigned m_max_bool_steps = 10; bool is_propositional(sat::literal lit); - void setup_bounds(bool_vector const& mdl); + void setup_bounds(sat::ddfw& bool_search, bool_vector const& mdl); // user propagator void check_for_user_propagator() { @@ -353,7 +353,7 @@ namespace euf { void add_assumptions(sat::literal_set& assumptions) override; bool tracking_assumptions() override; std::string reason_unknown() override { return m_reason_unknown; } - void local_search(bool_vector& phase) override; + lbool local_search(bool_vector& phase) override; void propagate(literal lit, ext_justification_idx idx); bool propagate(enode* a, enode* b, ext_justification_idx idx);