3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-08 02:15:19 +00:00

wip - integrating arithmetic local search

This commit is contained in:
Nikolaj Bjorner 2023-02-09 15:52:32 -08:00
parent 1b0c76e3f0
commit d22e4aa525
7 changed files with 280 additions and 169 deletions

View file

@ -129,6 +129,7 @@ namespace sat {
void ddfw::add(unsigned n, literal const* c) {
clause* cls = m_alloc.mk_clause(n, c, false);
unsigned idx = m_clauses.size();
m_clauses.push_back(clause_info(cls, m_config.m_init_clause_weight));
for (literal lit : *cls) {
m_use_list.reserve(2*(lit.var()+1));
@ -137,6 +138,18 @@ namespace sat {
}
}
/**
* Remove the last clause that was added
*/
void ddfw::del() {
auto& info = m_clauses.back();
for (literal lit : *info.m_clause)
m_use_list[lit.index()].pop_back();
m_alloc.del_clause(info.m_clause);
m_clauses.pop_back();
m_unsat.remove(m_clauses.size());
}
void ddfw::add(solver const& s) {
for (auto& ci : m_clauses)
m_alloc.del_clause(ci.m_clause);
@ -169,9 +182,17 @@ namespace sat {
}
void ddfw::add_assumptions() {
for (unsigned i = 0; i < m_assumptions.size(); ++i) {
add(1, m_assumptions.data() + i);
}
for (unsigned i = 0; i < m_assumptions.size(); ++i)
add(1, m_assumptions.data() + i);
}
void ddfw::remove_assumptions() {
for (unsigned i = 0; i < m_assumptions.size(); ++i)
del();
m_unsat_vars.reset();
for (auto idx : m_unsat)
for (auto lit : get_clause(idx))
m_unsat_vars.insert(lit.var());
}
void ddfw::init(unsigned sz, literal const* assumptions) {

View file

@ -32,7 +32,7 @@ namespace sat {
class parallel;
class ddfw : public i_local_search {
public:
struct clause_info {
clause_info(clause* cl, double init_weight): m_weight(init_weight), m_clause(cl) {}
double m_weight; // weight of clause
@ -43,6 +43,7 @@ namespace sat {
void add(literal lit) { ++m_num_trues; m_trues += lit.index(); }
void del(literal lit) { SASSERT(m_num_trues > 0); --m_num_trues; m_trues -= lit.index(); }
};
protected:
struct config {
config() { reset(); }
@ -197,6 +198,8 @@ namespace sat {
void add(unsigned sz, literal const* c);
void del();
void add_assumptions();
inline void transfer_weight(unsigned from, unsigned to, double w);
@ -232,6 +235,16 @@ namespace sat {
void collect_statistics(statistics& st) const override {}
double get_priority(bool_var v) const override { return m_probs[v]; }
// access clause information and state of Boolean search
indexed_uint_set& unsat_set() { return m_unsat; }
unsigned num_clauses() const { return m_clauses.size(); }
clause_info& get_clause_info(unsigned idx) { return m_clauses[idx]; }
void remove_assumptions();
};
}

View file

@ -347,6 +347,7 @@ namespace sat {
s.m_checkpoint_enabled = true;
}
};
unsigned select_watch_lit(clause const & cls, unsigned starting_at) const;
unsigned select_learned_watch_lit(clause const & cls) const;
bool simplify_clause(unsigned & num_lits, literal * lits) const;

View file

@ -36,22 +36,19 @@ namespace arith {
// need to init variables/atoms/ineqs
m.limit().push(m_max_arith_steps);
unsigned m_best_min_unsat = 1;
unsigned best = m_best_min_unsat;
while (m.inc() && m_best_min_unsat > 0) {
// unsigned prev = m_unsat.size();
m_best_min_unsat = unsat().size();
unsigned num_steps = 0;
while (m.inc() && m_best_min_unsat > 0 && num_steps < m_max_arith_steps) {
if (!flip())
return;
#if 0
if (m_unsat.size() < best) {
best = m_unsat.size();
++m_stats.m_num_flips;
++num_steps;
unsigned num_unsat = unsat().size();
if (num_unsat < m_best_min_unsat) {
m_best_min_unsat = num_unsat;
num_steps = 0;
}
if (m_unsat.size() < m_best_min_unsat)
save_best_values();
#endif
}
}
}
@ -68,7 +65,6 @@ namespace arith {
}
bool solver::sls::flip() {
++m_stats.m_num_flips;
log();
if (flip_unsat())
return true;
@ -141,45 +137,72 @@ namespace arith {
return false;
}
#if 0
bool solver::sls::flip_unsat() {
unsigned start = m_rand();
for (unsigned i = m_unsat.size(); i-- > 0; ) {
unsigned cl = m_unsat.elem_at((i + start) % m_unsat.size());
if (flip(m_clauses[cl]))
unsigned start = s.random();
unsigned sz = unsat().size();
for (unsigned i = sz; i-- > 0; ) {
unsigned cl = unsat().elem_at((i + start) % sz);
if (flip(cl))
return true;
}
return false;
}
bool solver::sls::flip(unsigned cl) {
auto const& clause = get_clause(cl);
rational new_value;
for (literal lit : clause) {
auto const* ai = atom(lit);
if (!ai)
continue;
ineq const& ineq = ai->m_ineq;
for (auto const& [coeff, v] : ineq.m_args) {
if (!ineq.is_true() && cm(ineq, v, new_value)) {
int score = cm_score(v, new_value);
if (score <= 0)
continue;
unsigned num_unsat = unsat().size();
update(v, new_value);
IF_VERBOSE(0,
verbose_stream() << "score " << v << " " << score << "\n"
<< num_unsat << " -> " << unsat().size() << "\n");
return true;
}
}
}
return false;
}
bool solver::sls::flip_clauses() {
unsigned start = m_rand();
for (unsigned i = m_clauses.size(); i-- > 0; )
if (flip_arith(m_clauses[(i + start) % m_clauses.size()]))
unsigned start = s.random();
for (unsigned i = num_clauses(); i-- > 0; )
if (flip((i + start) % num_clauses()))
return true;
return false;
}
bool solver::sls::flip_dscore() {
paws();
unsigned start = m_rand();
for (unsigned i = m_unsat.size(); i-- > 0; ) {
unsigned cl = m_unsat.elem_at((i + start) % m_unsat.size());
if (flip_dscore(m_clauses[cl]))
unsigned start = s.random();
for (unsigned i = unsat().size(); i-- > 0; ) {
unsigned cl = unsat().elem_at((i + start) % unsat().size());
if (flip_dscore(cl))
return true;
}
std::cout << "flip dscore\n";
IF_VERBOSE(2, verbose_stream() << "(sls " << m_stats.m_num_flips << " " << m_unsat.size() << ")\n");
IF_VERBOSE(2, verbose_stream() << "(sls " << m_stats.m_num_flips << " " << unsat().size() << ")\n");
return false;
}
bool solver::sls::flip_dscore(clause const& clause) {
bool solver::sls::flip_dscore(unsigned cl) {
auto const& clause = get_clause(cl);
rational new_value, min_value, min_score(-1);
var_t min_var = UINT_MAX;
for (auto a : clause.m_arith) {
auto const& ai = m_atoms[a];
ineq const& ineq = ai.m_ineq;
for (auto lit : clause) {
auto const* ai = atom(lit);
if (!ai)
continue;
ineq const& ineq = ai->m_ineq;
for (auto const& [coeff, v] : ineq.m_args) {
if (!ineq.is_true() && cm(ineq, v, new_value)) {
rational score = dscore(v, new_value);
@ -199,8 +222,9 @@ namespace arith {
}
void solver::sls::paws() {
for (auto& clause : m_clauses) {
bool above = 10000 * m_config.sp <= (m_rand() % 10000);
for (unsigned cl = num_clauses(); cl-- > 0; ) {
auto& clause = get_clause_info(cl);
bool above = 10000 * m_config.sp <= (s.random() % 10000);
if (!above && clause.is_true() && clause.m_weight > 1)
clause.m_weight -= 1;
if (above && !clause.is_true())
@ -208,103 +232,6 @@ namespace arith {
}
}
void solver::sls::update(var_t v, rational const& new_value) {
auto& vi = m_vars[v];
auto const& old_value = vi.m_value;
for (auto const& [coeff, atm] : vi.m_atoms) {
auto& ai = m_atoms[atm];
SASSERT(!ai.m_is_bool);
auto& clause = m_clauses[ai.m_clause_idx];
rational dtt_old = dtt(ai.m_ineq);
ai.m_ineq.m_args_value += coeff * (new_value - old_value);
rational dtt_new = dtt(ai.m_ineq);
bool was_true = clause.is_true();
if (dtt_new < clause.m_dts) {
if (was_true && clause.m_dts > 0 && dtt_new == 0 && 1 == clause.m_num_trues) {
for (auto lit : clause.m_bools) {
if (is_true(lit)) {
dec_break(lit);
break;
}
}
}
clause.m_dts = dtt_new;
if (!was_true && clause.is_true())
m_unsat.remove(ai.m_clause_idx);
}
else if (clause.m_dts == dtt_old && dtt_old < dtt_new) {
clause.m_dts = dts(clause);
if (was_true && !clause.is_true())
m_unsat.insert(ai.m_clause_idx);
if (was_true && clause.is_true() && clause.m_dts > 0 && dtt_old == 0 && 1 == clause.m_num_trues) {
for (auto lit : clause.m_bools) {
if (is_true(lit)) {
inc_break(lit);
break;
}
}
}
}
SASSERT(clause.m_dts >= 0);
}
vi.m_value = new_value;
}
bool solver::sls::flip_arith(clause const& clause) {
rational new_value;
for (auto a : clause.m_arith) {
auto const& ai = m_atoms[a];
ineq const& ineq = ai.m_ineq;
for (auto const& [coeff, v] : ineq.m_args) {
if (!ineq.is_true() && cm(ineq, v, new_value)) {
int score = cm_score(v, new_value);
if (score <= 0)
continue;
unsigned num_unsat = m_unsat.size();
update(v, new_value);
std::cout << "score " << v << " " << score << "\n";
std::cout << num_unsat << " -> " << m_unsat.size() << "\n";
return true;
}
}
}
return false;
}
rational solver::sls::dts(clause const& cl) const {
rational d(1), d2;
bool first = true;
for (auto a : cl.m_arith) {
auto const& ai = m_atoms[a];
d2 = dtt(ai.m_ineq);
if (first)
d = d2, first = false;
else
d = std::min(d, d2);
if (d == 0)
break;
}
return d;
}
rational solver::sls::dts(clause const& cl, var_t v, rational const& new_value) const {
rational d(1), d2;
bool first = true;
for (auto a : cl.m_arith) {
auto const& ai = m_atoms[a];
d2 = dtt(ai.m_ineq, v, new_value);
if (first)
d = d2, first = false;
else
d = std::min(d, d2);
if (d == 0)
break;
}
return d;
}
//
// dscore(op) = sum_c (dts(c,alpha) - dts(c,alpha_after)) * weight(c)
//
@ -312,9 +239,9 @@ namespace arith {
auto const& vi = m_vars[v];
rational score(0);
for (auto const& [coeff, atm] : vi.m_atoms) {
auto const& ai = m_atoms[atm];
auto const& cl = m_clauses[ai.m_clause_idx];
score += (cl.m_dts - dts(cl, v, new_value)) * rational(cl.m_weight);
auto const& ai = *m_atoms[atm];
auto const& cl = get_clause_info(ai.m_clause_idx);
// score += (dts(cl) - dts(cl, v, new_value)) * rational(cl.m_weight);
}
return score;
}
@ -323,8 +250,8 @@ namespace arith {
int score = 0;
auto& vi = m_vars[v];
for (auto const& [coeff, atm] : vi.m_atoms) {
auto const& ai = m_atoms[atm];
auto const& clause = m_clauses[ai.m_clause_idx];
auto const& ai = *m_atoms[atm];
auto const& clause = get_clause_info(ai.m_clause_idx);
rational dtt_old = dtt(ai.m_ineq);
rational dtt_new = dtt(ai.m_ineq, v, new_value);
if (!clause.is_true()) {
@ -335,8 +262,10 @@ namespace arith {
continue;
else {
bool has_true = false;
for (auto a : clause.m_arith) {
auto const& ai = m_atoms[a];
for (auto lit : *clause.m_clause) {
if (!atom(lit))
continue;
auto const& ai = *atom(lit);
rational d = dtt(ai.m_ineq, v, new_value);
has_true |= (d == 0);
}
@ -347,6 +276,121 @@ namespace arith {
return score;
}
rational solver::sls::dts(unsigned cl) const {
rational d(1), d2;
bool first = true;
for (auto a : get_clause(cl)) {
auto const* ai = atom(a);
if (!ai)
continue;
d2 = dtt(ai->m_ineq);
if (first)
d = d2, first = false;
else
d = std::min(d, d2);
if (d == 0)
break;
}
return d;
}
rational solver::sls::dts(unsigned cl, var_t v, rational const& new_value) const {
rational d(1), d2;
bool first = true;
for (auto lit : get_clause(cl)) {
auto const* ai = atom(lit);
if (!ai)
continue;
d2 = dtt(ai->m_ineq, v, new_value);
if (first)
d = d2, first = false;
else
d = std::min(d, d2);
if (d == 0)
break;
}
return d;
}
void solver::sls::update(var_t v, rational const& new_value) {
auto& vi = m_vars[v];
auto const& old_value = vi.m_value;
for (auto const& [coeff, atm] : vi.m_atoms) {
auto& ai = *m_atoms[atm];
SASSERT(!ai.m_is_bool);
auto& clause = get_clause_info(ai.m_clause_idx);
rational dtt_old = dtt(ai.m_ineq);
ai.m_ineq.m_args_value += coeff * (new_value - old_value);
rational dtt_new = dtt(ai.m_ineq);
bool was_true = clause.is_true();
auto& dts_value = dts(ai.m_clause_idx);
if (dtt_new < dts_value) {
if (was_true && dts_value > 0 && dtt_new == 0 && 1 == clause.m_num_trues) {
for (auto lit : *clause.m_clause) {
#if false
TODO
if (is_true(lit)) {
dec_break(lit);
break;
}
#endif
}
}
dts_value = dtt_new;
if (!was_true && clause.is_true())
unsat().remove(ai.m_clause_idx);
}
else if (dts_value == dtt_old && dtt_old < dtt_new) {
dts_value = dts(ai.m_clause_idx);
if (was_true && !clause.is_true())
unsat().insert(ai.m_clause_idx);
if (was_true && clause.is_true() && dts_value > 0 && dtt_old == 0 && 1 == clause.m_num_trues) {
for (auto lit : *clause.m_clause) {
#if false
TODO
if (is_true(lit)) {
inc_break(lit);
break;
}
#endif
}
}
}
SASSERT(dts_value >= 0);
}
vi.m_value = new_value;
}
#if 0
void solver::sls::add_clause(sat::clause* cl) {
unsigned clause_idx = m_clauses.size();
m_clauses.push_back({ cl, 1, rational::zero() });
clause& cls = m_clauses.back();
cls.m_dts = dts(cls);
for (sat::literal lit : *cl) {
if (is_true(lit))
cls.add(lit);
}
for (auto a : arith)
m_atoms[a].m_clause_idx = clause_idx;
if (!cl.is_true()) {
m_best_min_unsat++;
m_unsat.insert(clause_idx);
}
else if (cl.m_dts > 0 && cl.m_num_trues == 1)
inc_break(sat::to_literal(cl.m_trues));
}
#endif
}

View file

@ -30,6 +30,7 @@ Author:
#include "math/polynomial/algebraic_numbers.h"
#include "math/polynomial/polynomial.h"
#include "sat/smt/sat_th.h"
#include "sat/sat_ddfw.h"
namespace euf {
class solver;
@ -197,6 +198,14 @@ namespace arith {
typedef unsigned var_t;
typedef unsigned atom_t;
struct config {
double cb = 0.0;
unsigned L = 20;
unsigned t = 45;
unsigned max_no_improve = 500000;
double sp = 0.0003;
};
struct stats {
unsigned m_num_flips = 0;
};
@ -237,26 +246,49 @@ namespace arith {
unsigned m_breaks = 0;
};
solver& s;
ast_manager& m;
unsigned m_max_arith_steps = 0;
stats m_stats;
vector<atom_info> m_atoms;
vector<var_info> m_vars;
struct clause {
unsigned m_weight = 1;
rational m_dts = rational::one();
};
solver& s;
ast_manager& m;
sat::ddfw* m_bool_search = nullptr;
unsigned m_max_arith_steps = 0;
unsigned m_best_min_unsat = UINT_MAX;
stats m_stats;
config m_config;
scoped_ptr_vector<atom_info> m_atoms;
vector<var_info> m_vars;
vector<clause> m_clauses;
indexed_uint_set& unsat() { return m_bool_search->unsat_set(); }
unsigned num_clauses() const { return m_bool_search->num_clauses(); }
sat::clause& get_clause(unsigned idx) { return *get_clause_info(idx).m_clause; }
sat::clause const& get_clause(unsigned idx) const { return *get_clause_info(idx).m_clause; }
sat::ddfw::clause_info& get_clause_info(unsigned idx) { return m_bool_search->get_clause_info(idx); }
sat::ddfw::clause_info const& get_clause_info(unsigned idx) const { return m_bool_search->get_clause_info(idx); }
atom_info* atom(sat::literal lit) const { return m_atoms[lit.index()]; }
rational& dts(unsigned idx) { return m_clauses[idx].m_dts; }
bool flip();
void log() {}
bool flip_unsat() { return false; }
bool flip_clauses() { return false; }
bool flip_dscore() { return false; }
// bool flip_dscore(clause const&);
// bool flip(clause const&);
bool flip_unsat();
bool flip_clauses();
bool flip_dscore();
bool flip_dscore(unsigned cl);
bool flip(unsigned cl);
rational dtt(ineq const& ineq) const { return dtt(ineq.m_args_value, ineq); }
rational dtt(rational const& args, ineq const& ineq) const;
rational dtt(ineq const& ineq, var_t v, rational const& new_value) const;
// rational dts(clause const& cl, var_t v, rational const& new_value) const;
// rational dts(clause const& cl) const;
rational dts(unsigned cl, var_t v, rational const& new_value) const;
rational dts(unsigned cl) const;
bool cm(ineq const& ineq, var_t v, rational& new_value);
int cm_score(var_t v, rational const& new_value);
void update(var_t v, rational const& new_value);
void paws();
rational dscore(var_t v, rational const& new_value) const;
void save_best_values() {}
rational value(var_t v) const { return m_vars[v].m_value; }
public:
@ -265,6 +297,7 @@ namespace arith {
void set_bounds_begin();
void set_bounds_end(unsigned num_literals);
void set_bounds(enode* n);
void set(sat::ddfw* d) { m_bool_search = d; }
};
sls m_local_search;
@ -590,6 +623,7 @@ namespace arith {
void set_bounds_end(unsigned num_literals) override { m_local_search.set_bounds_end(num_literals); }
void set_bounds(enode* n) override { m_local_search.set_bounds(n); }
void local_search(bool_vector& phase) override { m_local_search(phase); }
void set_bool_search(sat::ddfw* ddfw) override { m_local_search.set(ddfw); }
// bounds and equality propagation callbacks
lp::lar_solver& lp() { return *m_solver; }

View file

@ -31,19 +31,24 @@ namespace euf {
unsigned max_rounds = 30;
for (auto* th : m_solvers)
th->set_bool_search(&bool_search);
for (unsigned rounds = 0; m.inc() && rounds < max_rounds; ++rounds) {
setup_bounds(phase);
bool_search.reinit(s(), phase);
setup_bounds(phase);
// Non-boolean literals are assumptions to Boolean search
literal_vector _lits;
literal_vector assumptions;
for (unsigned v = 0; v < phase.size(); ++v)
if (!is_propositional(literal(v)))
_lits.push_back(literal(v, !phase[v]));
assumptions.push_back(literal(v, !phase[v]));
bool_search.rlimit().push(m_max_bool_steps);
lbool r = bool_search.check(_lits.size(), _lits.data(), nullptr);
lbool r = bool_search.check(assumptions.size(), assumptions.data(), nullptr);
auto const& mdl = bool_search.get_model();
@ -85,8 +90,6 @@ namespace euf {
return phase[lit.var()] == !lit.sign();
};
svector<sat::solver::bin_clause> bin_clauses;
s().collect_bin_clauses(bin_clauses, false, false);
for (auto* cp : s().clauses()) {
if (any_of(*cp, [&](auto lit) { return is_true(lit); }))
continue;
@ -95,14 +98,6 @@ namespace euf {
init_literal(l);
}
for (auto [l1, l2] : bin_clauses) {
if (is_true(l1) || is_true(l2))
continue;
num_literals += 2;
init_literal(l1);
init_literal(l2);
};
m_max_bool_steps = (m_ls_config.L * num_bool) / num_literals;
for (auto* th : m_solvers)

View file

@ -18,6 +18,7 @@ Author:
#include "util/top_sort.h"
#include "sat/smt/sat_smt.h"
#include "sat/sat_ddfw.h"
#include "ast/euf/euf_egraph.h"
#include "model/model.h"
#include "smt/params/smt_params.h"
@ -139,6 +140,8 @@ namespace euf {
/**
* Local search interface
*/
virtual void set_bool_search(sat::ddfw* ddfw) {}
virtual void set_bounds_begin() {}
virtual void set_bounds_end(unsigned num_literals) {}