From d22e4aa5259137961bc5ae2c45403b9655aa5f1e Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Thu, 9 Feb 2023 15:52:32 -0800 Subject: [PATCH] wip - integrating arithmetic local search --- src/sat/sat_ddfw.cpp | 27 ++- src/sat/sat_ddfw.h | 15 +- src/sat/sat_solver.h | 1 + src/sat/smt/arith_local_search.cpp | 320 ++++++++++++++++------------- src/sat/smt/arith_solver.h | 60 ++++-- src/sat/smt/euf_local_search.cpp | 23 +-- src/sat/smt/sat_th.h | 3 + 7 files changed, 280 insertions(+), 169 deletions(-) diff --git a/src/sat/sat_ddfw.cpp b/src/sat/sat_ddfw.cpp index 723b38586..74a0e7777 100644 --- a/src/sat/sat_ddfw.cpp +++ b/src/sat/sat_ddfw.cpp @@ -129,6 +129,7 @@ namespace sat { void ddfw::add(unsigned n, literal const* c) { clause* cls = m_alloc.mk_clause(n, c, false); unsigned idx = m_clauses.size(); + m_clauses.push_back(clause_info(cls, m_config.m_init_clause_weight)); for (literal lit : *cls) { m_use_list.reserve(2*(lit.var()+1)); @@ -137,6 +138,18 @@ namespace sat { } } + /** + * Remove the last clause that was added + */ + void ddfw::del() { + auto& info = m_clauses.back(); + for (literal lit : *info.m_clause) + m_use_list[lit.index()].pop_back(); + m_alloc.del_clause(info.m_clause); + m_clauses.pop_back(); + m_unsat.remove(m_clauses.size()); + } + void ddfw::add(solver const& s) { for (auto& ci : m_clauses) m_alloc.del_clause(ci.m_clause); @@ -169,9 +182,17 @@ namespace sat { } void ddfw::add_assumptions() { - for (unsigned i = 0; i < m_assumptions.size(); ++i) { - add(1, m_assumptions.data() + i); - } + for (unsigned i = 0; i < m_assumptions.size(); ++i) + add(1, m_assumptions.data() + i); + } + + void ddfw::remove_assumptions() { + for (unsigned i = 0; i < m_assumptions.size(); ++i) + del(); + m_unsat_vars.reset(); + for (auto idx : m_unsat) + for (auto lit : get_clause(idx)) + m_unsat_vars.insert(lit.var()); } void ddfw::init(unsigned sz, literal const* assumptions) { diff --git a/src/sat/sat_ddfw.h b/src/sat/sat_ddfw.h index ce5ff9fdf..7dd69de81 100644 --- a/src/sat/sat_ddfw.h +++ b/src/sat/sat_ddfw.h @@ -32,7 +32,7 @@ namespace sat { class parallel; class ddfw : public i_local_search { - + public: struct clause_info { clause_info(clause* cl, double init_weight): m_weight(init_weight), m_clause(cl) {} double m_weight; // weight of clause @@ -43,6 +43,7 @@ namespace sat { void add(literal lit) { ++m_num_trues; m_trues += lit.index(); } void del(literal lit) { SASSERT(m_num_trues > 0); --m_num_trues; m_trues -= lit.index(); } }; + protected: struct config { config() { reset(); } @@ -197,6 +198,8 @@ namespace sat { void add(unsigned sz, literal const* c); + void del(); + void add_assumptions(); inline void transfer_weight(unsigned from, unsigned to, double w); @@ -232,6 +235,16 @@ namespace sat { void collect_statistics(statistics& st) const override {} double get_priority(bool_var v) const override { return m_probs[v]; } + + // access clause information and state of Boolean search + indexed_uint_set& unsat_set() { return m_unsat; } + + unsigned num_clauses() const { return m_clauses.size(); } + + clause_info& get_clause_info(unsigned idx) { return m_clauses[idx]; } + + void remove_assumptions(); + }; } diff --git a/src/sat/sat_solver.h b/src/sat/sat_solver.h index ca738ce9b..703b36dd0 100644 --- a/src/sat/sat_solver.h +++ b/src/sat/sat_solver.h @@ -347,6 +347,7 @@ namespace sat { s.m_checkpoint_enabled = true; } }; + unsigned select_watch_lit(clause const & cls, unsigned starting_at) const; unsigned select_learned_watch_lit(clause const & cls) const; bool simplify_clause(unsigned & num_lits, literal * lits) const; diff --git a/src/sat/smt/arith_local_search.cpp b/src/sat/smt/arith_local_search.cpp index c6d02e436..787280e73 100644 --- a/src/sat/smt/arith_local_search.cpp +++ b/src/sat/smt/arith_local_search.cpp @@ -36,22 +36,19 @@ namespace arith { // need to init variables/atoms/ineqs m.limit().push(m_max_arith_steps); - - unsigned m_best_min_unsat = 1; - unsigned best = m_best_min_unsat; - - while (m.inc() && m_best_min_unsat > 0) { - // unsigned prev = m_unsat.size(); + m_best_min_unsat = unsat().size(); + unsigned num_steps = 0; + while (m.inc() && m_best_min_unsat > 0 && num_steps < m_max_arith_steps) { if (!flip()) return; -#if 0 - if (m_unsat.size() < best) { - best = m_unsat.size(); + ++m_stats.m_num_flips; + ++num_steps; + unsigned num_unsat = unsat().size(); + if (num_unsat < m_best_min_unsat) { + m_best_min_unsat = num_unsat; num_steps = 0; - } - if (m_unsat.size() < m_best_min_unsat) save_best_values(); -#endif + } } } @@ -68,7 +65,6 @@ namespace arith { } bool solver::sls::flip() { - ++m_stats.m_num_flips; log(); if (flip_unsat()) return true; @@ -141,45 +137,72 @@ namespace arith { return false; } -#if 0 - bool solver::sls::flip_unsat() { - unsigned start = m_rand(); - for (unsigned i = m_unsat.size(); i-- > 0; ) { - unsigned cl = m_unsat.elem_at((i + start) % m_unsat.size()); - if (flip(m_clauses[cl])) + unsigned start = s.random(); + unsigned sz = unsat().size(); + for (unsigned i = sz; i-- > 0; ) { + unsigned cl = unsat().elem_at((i + start) % sz); + if (flip(cl)) return true; } return false; } + + bool solver::sls::flip(unsigned cl) { + auto const& clause = get_clause(cl); + rational new_value; + for (literal lit : clause) { + auto const* ai = atom(lit); + if (!ai) + continue; + ineq const& ineq = ai->m_ineq; + for (auto const& [coeff, v] : ineq.m_args) { + if (!ineq.is_true() && cm(ineq, v, new_value)) { + int score = cm_score(v, new_value); + if (score <= 0) + continue; + unsigned num_unsat = unsat().size(); + update(v, new_value); + IF_VERBOSE(0, + verbose_stream() << "score " << v << " " << score << "\n" + << num_unsat << " -> " << unsat().size() << "\n"); + return true; + } + } + } + return false; + } + bool solver::sls::flip_clauses() { - unsigned start = m_rand(); - for (unsigned i = m_clauses.size(); i-- > 0; ) - if (flip_arith(m_clauses[(i + start) % m_clauses.size()])) + unsigned start = s.random(); + for (unsigned i = num_clauses(); i-- > 0; ) + if (flip((i + start) % num_clauses())) return true; return false; } bool solver::sls::flip_dscore() { paws(); - unsigned start = m_rand(); - for (unsigned i = m_unsat.size(); i-- > 0; ) { - unsigned cl = m_unsat.elem_at((i + start) % m_unsat.size()); - if (flip_dscore(m_clauses[cl])) + unsigned start = s.random(); + for (unsigned i = unsat().size(); i-- > 0; ) { + unsigned cl = unsat().elem_at((i + start) % unsat().size()); + if (flip_dscore(cl)) return true; } - std::cout << "flip dscore\n"; - IF_VERBOSE(2, verbose_stream() << "(sls " << m_stats.m_num_flips << " " << m_unsat.size() << ")\n"); + IF_VERBOSE(2, verbose_stream() << "(sls " << m_stats.m_num_flips << " " << unsat().size() << ")\n"); return false; } - bool solver::sls::flip_dscore(clause const& clause) { + bool solver::sls::flip_dscore(unsigned cl) { + auto const& clause = get_clause(cl); rational new_value, min_value, min_score(-1); var_t min_var = UINT_MAX; - for (auto a : clause.m_arith) { - auto const& ai = m_atoms[a]; - ineq const& ineq = ai.m_ineq; + for (auto lit : clause) { + auto const* ai = atom(lit); + if (!ai) + continue; + ineq const& ineq = ai->m_ineq; for (auto const& [coeff, v] : ineq.m_args) { if (!ineq.is_true() && cm(ineq, v, new_value)) { rational score = dscore(v, new_value); @@ -199,8 +222,9 @@ namespace arith { } void solver::sls::paws() { - for (auto& clause : m_clauses) { - bool above = 10000 * m_config.sp <= (m_rand() % 10000); + for (unsigned cl = num_clauses(); cl-- > 0; ) { + auto& clause = get_clause_info(cl); + bool above = 10000 * m_config.sp <= (s.random() % 10000); if (!above && clause.is_true() && clause.m_weight > 1) clause.m_weight -= 1; if (above && !clause.is_true()) @@ -208,103 +232,6 @@ namespace arith { } } - void solver::sls::update(var_t v, rational const& new_value) { - auto& vi = m_vars[v]; - auto const& old_value = vi.m_value; - for (auto const& [coeff, atm] : vi.m_atoms) { - auto& ai = m_atoms[atm]; - SASSERT(!ai.m_is_bool); - auto& clause = m_clauses[ai.m_clause_idx]; - rational dtt_old = dtt(ai.m_ineq); - ai.m_ineq.m_args_value += coeff * (new_value - old_value); - rational dtt_new = dtt(ai.m_ineq); - bool was_true = clause.is_true(); - if (dtt_new < clause.m_dts) { - if (was_true && clause.m_dts > 0 && dtt_new == 0 && 1 == clause.m_num_trues) { - for (auto lit : clause.m_bools) { - if (is_true(lit)) { - dec_break(lit); - break; - } - } - } - clause.m_dts = dtt_new; - if (!was_true && clause.is_true()) - m_unsat.remove(ai.m_clause_idx); - } - else if (clause.m_dts == dtt_old && dtt_old < dtt_new) { - clause.m_dts = dts(clause); - if (was_true && !clause.is_true()) - m_unsat.insert(ai.m_clause_idx); - if (was_true && clause.is_true() && clause.m_dts > 0 && dtt_old == 0 && 1 == clause.m_num_trues) { - for (auto lit : clause.m_bools) { - if (is_true(lit)) { - inc_break(lit); - break; - } - } - } - } - SASSERT(clause.m_dts >= 0); - } - vi.m_value = new_value; - } - - bool solver::sls::flip_arith(clause const& clause) { - rational new_value; - for (auto a : clause.m_arith) { - auto const& ai = m_atoms[a]; - ineq const& ineq = ai.m_ineq; - for (auto const& [coeff, v] : ineq.m_args) { - if (!ineq.is_true() && cm(ineq, v, new_value)) { - int score = cm_score(v, new_value); - if (score <= 0) - continue; - unsigned num_unsat = m_unsat.size(); - update(v, new_value); - std::cout << "score " << v << " " << score << "\n"; - std::cout << num_unsat << " -> " << m_unsat.size() << "\n"; - return true; - } - } - } - return false; - } - - - - rational solver::sls::dts(clause const& cl) const { - rational d(1), d2; - bool first = true; - for (auto a : cl.m_arith) { - auto const& ai = m_atoms[a]; - d2 = dtt(ai.m_ineq); - if (first) - d = d2, first = false; - else - d = std::min(d, d2); - if (d == 0) - break; - } - return d; - } - - rational solver::sls::dts(clause const& cl, var_t v, rational const& new_value) const { - rational d(1), d2; - bool first = true; - for (auto a : cl.m_arith) { - auto const& ai = m_atoms[a]; - d2 = dtt(ai.m_ineq, v, new_value); - if (first) - d = d2, first = false; - else - d = std::min(d, d2); - if (d == 0) - break; - } - return d; - } - // // dscore(op) = sum_c (dts(c,alpha) - dts(c,alpha_after)) * weight(c) // @@ -312,9 +239,9 @@ namespace arith { auto const& vi = m_vars[v]; rational score(0); for (auto const& [coeff, atm] : vi.m_atoms) { - auto const& ai = m_atoms[atm]; - auto const& cl = m_clauses[ai.m_clause_idx]; - score += (cl.m_dts - dts(cl, v, new_value)) * rational(cl.m_weight); + auto const& ai = *m_atoms[atm]; + auto const& cl = get_clause_info(ai.m_clause_idx); + // score += (dts(cl) - dts(cl, v, new_value)) * rational(cl.m_weight); } return score; } @@ -323,8 +250,8 @@ namespace arith { int score = 0; auto& vi = m_vars[v]; for (auto const& [coeff, atm] : vi.m_atoms) { - auto const& ai = m_atoms[atm]; - auto const& clause = m_clauses[ai.m_clause_idx]; + auto const& ai = *m_atoms[atm]; + auto const& clause = get_clause_info(ai.m_clause_idx); rational dtt_old = dtt(ai.m_ineq); rational dtt_new = dtt(ai.m_ineq, v, new_value); if (!clause.is_true()) { @@ -335,8 +262,10 @@ namespace arith { continue; else { bool has_true = false; - for (auto a : clause.m_arith) { - auto const& ai = m_atoms[a]; + for (auto lit : *clause.m_clause) { + if (!atom(lit)) + continue; + auto const& ai = *atom(lit); rational d = dtt(ai.m_ineq, v, new_value); has_true |= (d == 0); } @@ -347,6 +276,121 @@ namespace arith { return score; } + rational solver::sls::dts(unsigned cl) const { + rational d(1), d2; + bool first = true; + for (auto a : get_clause(cl)) { + auto const* ai = atom(a); + if (!ai) + continue; + d2 = dtt(ai->m_ineq); + if (first) + d = d2, first = false; + else + d = std::min(d, d2); + if (d == 0) + break; + } + return d; + } + + rational solver::sls::dts(unsigned cl, var_t v, rational const& new_value) const { + rational d(1), d2; + bool first = true; + for (auto lit : get_clause(cl)) { + auto const* ai = atom(lit); + if (!ai) + continue; + d2 = dtt(ai->m_ineq, v, new_value); + if (first) + d = d2, first = false; + else + d = std::min(d, d2); + if (d == 0) + break; + } + return d; + } + + void solver::sls::update(var_t v, rational const& new_value) { + auto& vi = m_vars[v]; + auto const& old_value = vi.m_value; + for (auto const& [coeff, atm] : vi.m_atoms) { + auto& ai = *m_atoms[atm]; + SASSERT(!ai.m_is_bool); + auto& clause = get_clause_info(ai.m_clause_idx); + rational dtt_old = dtt(ai.m_ineq); + ai.m_ineq.m_args_value += coeff * (new_value - old_value); + rational dtt_new = dtt(ai.m_ineq); + bool was_true = clause.is_true(); + auto& dts_value = dts(ai.m_clause_idx); + if (dtt_new < dts_value) { + if (was_true && dts_value > 0 && dtt_new == 0 && 1 == clause.m_num_trues) { + for (auto lit : *clause.m_clause) { +#if false + TODO + if (is_true(lit)) { + dec_break(lit); + break; + } +#endif + } + } + dts_value = dtt_new; + if (!was_true && clause.is_true()) + unsat().remove(ai.m_clause_idx); + } + else if (dts_value == dtt_old && dtt_old < dtt_new) { + dts_value = dts(ai.m_clause_idx); + if (was_true && !clause.is_true()) + unsat().insert(ai.m_clause_idx); + if (was_true && clause.is_true() && dts_value > 0 && dtt_old == 0 && 1 == clause.m_num_trues) { + for (auto lit : *clause.m_clause) { +#if false + TODO + if (is_true(lit)) { + inc_break(lit); + break; + } +#endif + } + } + } + SASSERT(dts_value >= 0); + } + vi.m_value = new_value; + } + +#if 0 + + + + + + + void solver::sls::add_clause(sat::clause* cl) { + unsigned clause_idx = m_clauses.size(); + m_clauses.push_back({ cl, 1, rational::zero() }); + clause& cls = m_clauses.back(); + cls.m_dts = dts(cls); + for (sat::literal lit : *cl) { + if (is_true(lit)) + cls.add(lit); + } + + for (auto a : arith) + m_atoms[a].m_clause_idx = clause_idx; + + if (!cl.is_true()) { + m_best_min_unsat++; + m_unsat.insert(clause_idx); + } + else if (cl.m_dts > 0 && cl.m_num_trues == 1) + inc_break(sat::to_literal(cl.m_trues)); + + } + + #endif } diff --git a/src/sat/smt/arith_solver.h b/src/sat/smt/arith_solver.h index 1717b93d1..17207bd80 100644 --- a/src/sat/smt/arith_solver.h +++ b/src/sat/smt/arith_solver.h @@ -30,6 +30,7 @@ Author: #include "math/polynomial/algebraic_numbers.h" #include "math/polynomial/polynomial.h" #include "sat/smt/sat_th.h" +#include "sat/sat_ddfw.h" namespace euf { class solver; @@ -197,6 +198,14 @@ namespace arith { typedef unsigned var_t; typedef unsigned atom_t; + struct config { + double cb = 0.0; + unsigned L = 20; + unsigned t = 45; + unsigned max_no_improve = 500000; + double sp = 0.0003; + }; + struct stats { unsigned m_num_flips = 0; }; @@ -237,26 +246,49 @@ namespace arith { unsigned m_breaks = 0; }; - solver& s; - ast_manager& m; - unsigned m_max_arith_steps = 0; - stats m_stats; - vector m_atoms; - vector m_vars; + struct clause { + unsigned m_weight = 1; + rational m_dts = rational::one(); + }; + solver& s; + ast_manager& m; + sat::ddfw* m_bool_search = nullptr; + unsigned m_max_arith_steps = 0; + unsigned m_best_min_unsat = UINT_MAX; + stats m_stats; + config m_config; + scoped_ptr_vector m_atoms; + vector m_vars; + vector m_clauses; + + indexed_uint_set& unsat() { return m_bool_search->unsat_set(); } + unsigned num_clauses() const { return m_bool_search->num_clauses(); } + sat::clause& get_clause(unsigned idx) { return *get_clause_info(idx).m_clause; } + sat::clause const& get_clause(unsigned idx) const { return *get_clause_info(idx).m_clause; } + sat::ddfw::clause_info& get_clause_info(unsigned idx) { return m_bool_search->get_clause_info(idx); } + sat::ddfw::clause_info const& get_clause_info(unsigned idx) const { return m_bool_search->get_clause_info(idx); } + + atom_info* atom(sat::literal lit) const { return m_atoms[lit.index()]; } + rational& dts(unsigned idx) { return m_clauses[idx].m_dts; } bool flip(); void log() {} - bool flip_unsat() { return false; } - bool flip_clauses() { return false; } - bool flip_dscore() { return false; } -// bool flip_dscore(clause const&); -// bool flip(clause const&); + bool flip_unsat(); + bool flip_clauses(); + bool flip_dscore(); + bool flip_dscore(unsigned cl); + bool flip(unsigned cl); rational dtt(ineq const& ineq) const { return dtt(ineq.m_args_value, ineq); } rational dtt(rational const& args, ineq const& ineq) const; rational dtt(ineq const& ineq, var_t v, rational const& new_value) const; -// rational dts(clause const& cl, var_t v, rational const& new_value) const; -// rational dts(clause const& cl) const; + rational dts(unsigned cl, var_t v, rational const& new_value) const; + rational dts(unsigned cl) const; bool cm(ineq const& ineq, var_t v, rational& new_value); + int cm_score(var_t v, rational const& new_value); + void update(var_t v, rational const& new_value); + void paws(); + rational dscore(var_t v, rational const& new_value) const; + void save_best_values() {} rational value(var_t v) const { return m_vars[v].m_value; } public: @@ -265,6 +297,7 @@ namespace arith { void set_bounds_begin(); void set_bounds_end(unsigned num_literals); void set_bounds(enode* n); + void set(sat::ddfw* d) { m_bool_search = d; } }; sls m_local_search; @@ -590,6 +623,7 @@ namespace arith { void set_bounds_end(unsigned num_literals) override { m_local_search.set_bounds_end(num_literals); } void set_bounds(enode* n) override { m_local_search.set_bounds(n); } void local_search(bool_vector& phase) override { m_local_search(phase); } + void set_bool_search(sat::ddfw* ddfw) override { m_local_search.set(ddfw); } // bounds and equality propagation callbacks lp::lar_solver& lp() { return *m_solver; } diff --git a/src/sat/smt/euf_local_search.cpp b/src/sat/smt/euf_local_search.cpp index 90889ca3e..4ee6490b7 100644 --- a/src/sat/smt/euf_local_search.cpp +++ b/src/sat/smt/euf_local_search.cpp @@ -31,19 +31,24 @@ namespace euf { unsigned max_rounds = 30; + for (auto* th : m_solvers) + th->set_bool_search(&bool_search); + for (unsigned rounds = 0; m.inc() && rounds < max_rounds; ++rounds) { - setup_bounds(phase); + bool_search.reinit(s(), phase); + setup_bounds(phase); + // Non-boolean literals are assumptions to Boolean search - literal_vector _lits; + literal_vector assumptions; for (unsigned v = 0; v < phase.size(); ++v) if (!is_propositional(literal(v))) - _lits.push_back(literal(v, !phase[v])); + assumptions.push_back(literal(v, !phase[v])); bool_search.rlimit().push(m_max_bool_steps); - lbool r = bool_search.check(_lits.size(), _lits.data(), nullptr); + lbool r = bool_search.check(assumptions.size(), assumptions.data(), nullptr); auto const& mdl = bool_search.get_model(); @@ -85,8 +90,6 @@ namespace euf { return phase[lit.var()] == !lit.sign(); }; - svector bin_clauses; - s().collect_bin_clauses(bin_clauses, false, false); for (auto* cp : s().clauses()) { if (any_of(*cp, [&](auto lit) { return is_true(lit); })) continue; @@ -95,14 +98,6 @@ namespace euf { init_literal(l); } - for (auto [l1, l2] : bin_clauses) { - if (is_true(l1) || is_true(l2)) - continue; - num_literals += 2; - init_literal(l1); - init_literal(l2); - }; - m_max_bool_steps = (m_ls_config.L * num_bool) / num_literals; for (auto* th : m_solvers) diff --git a/src/sat/smt/sat_th.h b/src/sat/smt/sat_th.h index 2101dfd64..e226566b8 100644 --- a/src/sat/smt/sat_th.h +++ b/src/sat/smt/sat_th.h @@ -18,6 +18,7 @@ Author: #include "util/top_sort.h" #include "sat/smt/sat_smt.h" +#include "sat/sat_ddfw.h" #include "ast/euf/euf_egraph.h" #include "model/model.h" #include "smt/params/smt_params.h" @@ -139,6 +140,8 @@ namespace euf { /** * Local search interface */ + virtual void set_bool_search(sat::ddfw* ddfw) {} + virtual void set_bounds_begin() {} virtual void set_bounds_end(unsigned num_literals) {}