3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-14 04:48:45 +00:00

wip - bounded local search for arithmetic

This commit is contained in:
Nikolaj Bjorner 2023-02-11 15:46:39 -08:00
parent 4b2c166e8b
commit 5e30323b1a
8 changed files with 124 additions and 66 deletions

View file

@ -148,7 +148,8 @@ namespace sat {
m_use_list[lit.index()].pop_back();
m_alloc.del_clause(info.m_clause);
m_clauses.pop_back();
m_unsat.remove(m_clauses.size());
if (m_unsat.contains(m_clauses.size()))
m_unsat.remove(m_clauses.size());
}
void ddfw::add(solver const& s) {
@ -188,12 +189,11 @@ namespace sat {
}
void ddfw::remove_assumptions() {
if (m_assumptions.empty())
return;
for (unsigned i = 0; i < m_assumptions.size(); ++i)
del();
m_unsat_vars.reset();
for (auto idx : m_unsat)
for (auto lit : get_clause(idx))
m_unsat_vars.insert(lit.var());
init(0, nullptr);
}
void ddfw::init(unsigned sz, literal const* assumptions) {

View file

@ -126,7 +126,7 @@ namespace sat {
virtual void add_assumptions(literal_set& ext_assumptions) {}
virtual bool tracking_assumptions() { return false; }
virtual bool enable_self_propagate() const { return false; }
virtual void local_search(bool_vector& phase) {}
virtual lbool local_search(bool_vector& phase) { return l_undef; }
virtual bool extract_pb(std::function<void(unsigned sz, literal const* c, unsigned k)>& card,
std::function<void(unsigned sz, literal const* c, unsigned const* coeffs, unsigned k)>& pb) {

View file

@ -1302,6 +1302,9 @@ namespace sat {
return l_undef;
}
// uncomment this to test bounded local search:
// bounded_local_search();
log_stats();
if (m_config.m_max_conflicts > 0 && m_config.m_burst_search > 0) {
m_restart_threshold = m_config.m_burst_search;
@ -1360,6 +1363,12 @@ namespace sat {
};
void solver::bounded_local_search() {
if (m_ext) {
verbose_stream() << "bounded local search\n";
do_restart(true);
m_ext->local_search(m_best_phase);
return;
}
literal_vector _lits;
scoped_limits scoped_rl(rlimit());
m_local_search = alloc(ddfw);

View file

@ -20,21 +20,24 @@ Author:
namespace arith {
///
/// need to initialize ineqs (arithmetical atoms)
///
sls::sls(solver& s):
s(s), m(s.m) {}
void sls::operator()(bool_vector& phase) {
void sls::reset() {
m_literals.reset();
m_vars.reset();
m_clauses.reset();
m_terms.reset();
}
lbool sls::operator()(bool_vector& phase) {
unsigned num_steps = 0;
for (unsigned v = 0; v < s.s().num_vars(); ++v)
init_bool_var_assignment(v);
m_best_min_unsat = unsat().size();
verbose_stream() << "max arith steps " << m_max_arith_steps << "\n";
//m_max_arith_steps = 10000;
while (m.inc() && m_best_min_unsat > 0 && num_steps < m_max_arith_steps) {
if (!flip())
break;
@ -47,24 +50,27 @@ namespace arith {
save_best_values();
}
}
IF_VERBOSE(2, verbose_stream() << "(sls " << m_stats.m_num_flips << " " << unsat().size() << ")\n");
log();
return unsat().empty() ? l_true : l_undef;
}
void sls::log() {
IF_VERBOSE(2, verbose_stream() << "(sls :flips " << m_stats.m_num_flips << " :unsat " << unsat().size() << ")\n");
}
void sls::save_best_values() {
// first compute assignment to terms
// then update non-basic variables in tableau, assuming a sat solution was found.
#if false
for (auto const& [t, v] : terms) {
// then update non-basic variables in tableau.
for (auto const& [t, v] : m_terms) {
rational val;
lp::lar_term const& term = lp().get_term(t);
lp::lar_term const& term = s.lp().get_term(t);
for (lp::lar_term::ival arg : term) {
auto t2 = lp().column2tv(arg.column());
auto w = lp().local_to_external(t2.id());
val += arg.coeff() * local_search.value(w);
auto t2 = s.lp().column2tv(arg.column());
auto w = s.lp().local_to_external(t2.id());
val += arg.coeff() * value(w);
}
update(v, val);
}
#endif
for (unsigned v = 0; v < s.get_num_vars(); ++v) {
if (s.is_bool(v))
@ -87,6 +93,8 @@ namespace arith {
void sls::set(sat::ddfw* d) {
m_bool_search = d;
reset();
m_literals.reserve(s.s().num_vars() * 2);
add_vars();
m_clauses.resize(d->num_clauses());
for (unsigned i = 0; i < d->num_clauses(); ++i)
@ -151,12 +159,16 @@ namespace arith {
bool sls::cm(ineq const& ineq, var_t v, rational& new_value) {
SASSERT(!ineq.is_true());
auto delta = ineq.m_args_value - ineq.m_bound;
if (ineq.m_op == ineq_kind::NE || ineq.m_op == ineq_kind::LT)
delta--;
for (auto const& [coeff, w] : ineq.m_args) {
if (w == v) {
if (coeff > 0)
new_value = value(v) - abs(ceil(delta / coeff));
else
new_value = value(v) + abs(floor(delta / coeff));
switch (ineq.m_op) {
case ineq_kind::LE:
SASSERT(delta + coeff * (new_value - value(v)) <= 0);
@ -189,9 +201,12 @@ namespace arith {
auto const& clause = get_clause(cl);
rational new_value;
for (literal lit : clause) {
auto const* ineq = atom(lit);
if (!ineq || ineq->is_true())
if (is_true(lit))
continue;
auto const* ineq = atom(lit);
if (!ineq)
continue;
SASSERT(!ineq->is_true());
for (auto const& [coeff, v] : ineq->m_args) {
if (!cm(*ineq, v, new_value))
continue;
@ -201,8 +216,9 @@ namespace arith {
unsigned num_unsat = unsat().size();
update(v, new_value);
IF_VERBOSE(2,
verbose_stream() << "score " << v << " " << score << "\n"
verbose_stream() << "v" << v << " score " << score << " "
<< num_unsat << " -> " << unsat().size() << "\n");
SASSERT(num_unsat > unsat().size());
return true;
}
}
@ -255,7 +271,8 @@ namespace arith {
}
/**
* redistribute weights of clauses. TODO - re-use ddfw weights instead.
* redistribute weights of clauses.
* TODO - re-use ddfw weights instead.
*/
void sls::paws() {
for (unsigned cl = num_clauses(); cl-- > 0; ) {
@ -270,13 +287,15 @@ namespace arith {
//
// dscore(op) = sum_c (dts(c,alpha) - dts(c,alpha_after)) * weight(c)
// TODO - use cached dts instead of computed dts
// cached dts has to be updated when the score of literals are updated.
//
rational sls::dscore(var_t v, rational const& new_value) const {
auto const& vi = m_vars[v];
rational score(0);
for (auto const& [coeff, lit] : vi.m_literals)
for (auto cl : m_bool_search->get_use_list(lit))
score += (dts(cl) - dts(cl, v, new_value)) * rational(get_weight(cl));
score += (compute_dts(cl) - dts(cl, v, new_value)) * rational(get_weight(cl));
return score;
}
@ -290,10 +309,11 @@ namespace arith {
for (auto cl : m_bool_search->get_use_list(lit)) {
auto const& clause = get_clause_info(cl);
if (!clause.is_true()) {
VERIFY(dtt_old != 0);
if (dtt_new == 0)
++score; // false -> true
}
else if (dtt_new == 0 || dtt_old > 0 || clause.m_num_trues > 0) // true -> true ?? TODO
else if (dtt_new == 0 || dtt_old > 0 || clause.m_num_trues > 1) // true -> true not really, same variable can be in multiple literals
continue;
else if (all_of(*clause.m_clause, [&](auto lit) { return !atom(lit) || dtt(*atom(lit), v, new_value) > 0; })) // ?? TODO
--score;
@ -302,7 +322,7 @@ namespace arith {
return score;
}
rational sls::dts(unsigned cl) const {
rational sls::compute_dts(unsigned cl) const {
rational d(1), d2;
bool first = true;
for (auto a : get_clause(cl)) {
@ -346,14 +366,20 @@ namespace arith {
rational dtt_old = dtt(ineq);
ineq.m_args_value += coeff * (new_value - old_value);
rational dtt_new = dtt(ineq);
SASSERT(!(dtt_new == 0 && dtt_new < dtt_old) || m_bool_search->get_value(lit.var()) == lit.sign());
SASSERT(!(dtt_old == 0 && dtt_new > dtt_old) || m_bool_search->get_value(lit.var()) != lit.sign());
if ((dtt_new == 0) == is_true(lit)) {
dtt(ineq) = dtt_new;
continue;
}
VERIFY((dtt_old == 0) == is_true(lit));
VERIFY(!(dtt_new == 0 && dtt_new < dtt_old) || !is_true(lit));
VERIFY(!(dtt_old == 0 && dtt_new > dtt_old) || is_true(lit));
if (dtt_new == 0 && dtt_new < dtt_old) // flip from false to true
m_bool_search->flip(lit.var());
else if (dtt_old == 0 && dtt_old < dtt_new) // flip from true to false
m_bool_search->flip(lit.var());
dtt(ineq) = dtt_new;
SASSERT((dtt_new == 0) == (m_bool_search->get_value(lit.var()) != lit.sign()));
VERIFY((dtt_new == 0) == is_true(lit));
}
vi.m_value = new_value;
}
@ -422,18 +448,18 @@ namespace arith {
}
void sls::add_args(ineq& ineq, lp::tv t, theory_var v, rational sign) {
void sls::add_args(sat::literal lit, ineq& ineq, lp::tv t, theory_var v, rational sign) {
if (t.is_term()) {
lp::lar_term const& term = s.lp().get_term(t);
for (lp::lar_term::ival arg : term) {
auto t2 = s.lp().column2tv(arg.column());
auto w = s.lp().local_to_external(t2.id());
ineq.m_args.push_back({ sign * arg.coeff(), w });
add_arg(lit, ineq, sign * arg.coeff(), w);
}
}
else
ineq.m_args.push_back({ sign, s.lp().local_to_external(t.id()) });
add_arg(lit, ineq, sign, s.lp().local_to_external(t.id()));
}
@ -465,7 +491,7 @@ namespace arith {
bound.neg();
auto& ineq = new_ineq(op, bound);
add_args(ineq, t, b->get_var(), should_minus ? rational::minus_one() :rational::one());
add_args(lit, ineq, t, b->get_var(), should_minus ? rational::minus_one() :rational::one());
m_literals.set(lit.index(), &ineq);
return;
}
@ -478,8 +504,8 @@ namespace arith {
lp::tv tu = s.get_tv(u);
lp::tv tv = s.get_tv(v);
auto& ineq = new_ineq(lit.sign() ? sls::ineq_kind::NE : sls::ineq_kind::EQ, rational::zero());
add_args(ineq, tu, u, rational::one());
add_args(ineq, tv, v, -rational::one());
add_args(lit, ineq, tu, u, rational::one());
add_args(lit, ineq, tv, v, -rational::one());
m_literals.set(lit.index(), &ineq);
return;
}
@ -492,8 +518,9 @@ namespace arith {
void sls::init_literal_assignment(sat::literal lit) {
auto* ineq = m_literals.get(lit.index(), nullptr);
if (ineq && m_bool_search->get_value(lit.var()) != (dtt(*ineq) == 0))
m_bool_search->flip(lit.var());
if (ineq && is_true(lit) != (dtt(*ineq) == 0))
m_bool_search->flip(lit.var());
}
}

View file

@ -55,6 +55,7 @@ namespace arith {
unsigned m_num_flips = 0;
};
public:
// encode args <= bound, args = bound, args < bound
struct ineq {
vector<std::pair<rational, var_t>> m_args;
@ -74,7 +75,23 @@ namespace arith {
return m_args_value < m_bound;
}
}
std::ostream& display(std::ostream& out) const {
bool first = true;
for (auto const& [c, v] : m_args)
out << (first? "": " + ") << c << " * v" << v, first = false;
switch (m_op) {
case ineq_kind::LE:
return out << " <= " << m_bound << "(" << m_args_value << ")";
case ineq_kind::EQ:
return out << " == " << m_bound << "(" << m_args_value << ")";
case ineq_kind::NE:
return out << " != " << m_bound << "(" << m_args_value << ")";
default:
return out << " < " << m_bound << "(" << m_args_value << ")";
}
}
};
private:
struct var_info {
rational m_value;
@ -85,6 +102,7 @@ namespace arith {
struct clause {
unsigned m_weight = 1;
rational m_dts = rational::one();
};
solver& s;
@ -97,6 +115,8 @@ namespace arith {
scoped_ptr_vector<ineq> m_literals;
vector<var_info> m_vars;
vector<clause> m_clauses;
svector<std::pair<lp::tv, euf::theory_var>> m_terms;
indexed_uint_set& unsat() { return m_bool_search->unsat_set(); }
unsigned num_clauses() const { return m_bool_search->num_clauses(); }
@ -104,12 +124,14 @@ namespace arith {
sat::clause const& get_clause(unsigned idx) const { return *get_clause_info(idx).m_clause; }
sat::ddfw::clause_info& get_clause_info(unsigned idx) { return m_bool_search->get_clause_info(idx); }
sat::ddfw::clause_info const& get_clause_info(unsigned idx) const { return m_bool_search->get_clause_info(idx); }
bool is_true(sat::literal lit) { return lit.sign() != m_bool_search->get_value(lit.var()); }
void reset();
ineq* atom(sat::literal lit) const { return m_literals[lit.index()]; }
unsigned& get_weight(unsigned idx) { return m_clauses[idx].m_weight; }
unsigned get_weight(unsigned idx) const { return m_clauses[idx].m_weight; }
bool flip();
void log() {}
void log();
bool flip_unsat();
bool flip_clauses();
bool flip_dscore();
@ -119,7 +141,7 @@ namespace arith {
rational dtt(rational const& args, ineq const& ineq) const;
rational dtt(ineq const& ineq, var_t v, rational const& new_value) const;
rational dts(unsigned cl, var_t v, rational const& new_value) const;
rational dts(unsigned cl) const;
rational compute_dts(unsigned cl) const;
bool cm(ineq const& ineq, var_t v, rational& new_value);
int cm_score(var_t v, rational const& new_value);
void update(var_t v, rational const& new_value);
@ -130,7 +152,7 @@ namespace arith {
sls::ineq& new_ineq(ineq_kind op, rational const& bound);
void add_arg(sat::literal lit, ineq& ineq, rational const& c, var_t v);
void add_bounds(sat::literal_vector& bounds);
void add_args(ineq& ineq, lp::tv t, euf::theory_var v, rational sign);
void add_args(sat::literal lit, ineq& ineq, lp::tv t, euf::theory_var v, rational sign);
void init_literal(sat::literal lit);
void init_bool_var_assignment(sat::bool_var v);
void init_literal_assignment(sat::literal lit);
@ -138,11 +160,14 @@ namespace arith {
rational value(var_t v) const { return m_vars[v].m_value; }
public:
sls(solver& s);
void operator ()(bool_vector& phase);
lbool operator ()(bool_vector& phase);
void set_bounds_begin();
void set_bounds_end(unsigned num_literals);
void set_bounds(euf::enode* n);
void set(sat::ddfw* d);
};
inline std::ostream& operator<<(std::ostream& out, sls::ineq const& ineq) {
return ineq.display(out);
}
}

View file

@ -515,7 +515,7 @@ namespace arith {
void set_bounds_begin() override { m_local_search.set_bounds_begin(); }
void set_bounds_end(unsigned num_literals) override { m_local_search.set_bounds_end(num_literals); }
void set_bounds(enode* n) override { m_local_search.set_bounds(n); }
void local_search(bool_vector& phase) override { m_local_search(phase); }
lbool local_search(bool_vector& phase) override { return m_local_search(phase); }
void set_bool_search(sat::ddfw* ddfw) override { m_local_search.set(ddfw); }
// bounds and equality propagation callbacks

View file

@ -21,7 +21,7 @@ Author:
namespace euf {
void solver::local_search(bool_vector& phase) {
lbool solver::local_search(bool_vector& phase) {
scoped_limits scoped_rl(m.limit());
sat::ddfw bool_search;
bool_search.reinit(s(), phase);
@ -36,7 +36,7 @@ namespace euf {
for (unsigned rounds = 0; m.inc() && rounds < max_rounds; ++rounds) {
setup_bounds(phase);
setup_bounds(bool_search, phase);
// Non-boolean literals are assumptions to Boolean search
literal_vector assumptions;
@ -44,6 +44,8 @@ namespace euf {
if (!is_propositional(literal(v)))
assumptions.push_back(literal(v, !bool_search.get_value(v)));
verbose_stream() << "assumptions " << assumptions.size() << "\n";
bool_search.rlimit().push(m_max_bool_steps);
lbool r = bool_search.check(assumptions.size(), assumptions.data(), nullptr);
@ -51,15 +53,15 @@ namespace euf {
for (auto* th : m_solvers)
th->local_search(phase);
// if is_sat break;
if (bool_search.unsat_set().empty())
break;
}
auto const& mdl = bool_search.get_model();
for (unsigned i = 0; i < mdl.size(); ++i)
phase[i] = mdl[i] == l_true;
phase[i] = mdl[i] == l_true;
return bool_search.unsat_set().empty() ? l_true : l_undef;
}
bool solver::is_propositional(sat::literal lit) {
@ -67,13 +69,13 @@ namespace euf {
return !e || is_uninterp_const(e) || !m_egraph.find(e);
}
void solver::setup_bounds(bool_vector const& phase) {
void solver::setup_bounds(sat::ddfw& bool_search, bool_vector const& phase) {
unsigned num_literals = 0;
unsigned num_bool = 0;
for (auto* th : m_solvers)
th->set_bounds_begin();
auto init_literal = [&](sat::literal l) {
auto count_literal = [&](sat::literal l) {
if (is_propositional(l)) {
++num_bool;
return;
@ -86,16 +88,11 @@ namespace euf {
}
};
auto is_true = [&](auto lit) {
return phase[lit.var()] == !lit.sign();
};
for (auto* cp : s().clauses()) {
if (any_of(*cp, [&](auto lit) { return is_true(lit); }))
continue;
num_literals += cp->size();
for (auto l : *cp)
init_literal(l);
for (auto cl : bool_search.unsat_set()) {
auto& c = *bool_search.get_clause_info(cl).m_clause;
num_literals += c.size();
for (auto l : c)
count_literal(l);
}
m_max_bool_steps = (m_ls_config.L * num_bool) / num_literals;

View file

@ -265,7 +265,7 @@ namespace euf {
// local search
unsigned m_max_bool_steps = 10;
bool is_propositional(sat::literal lit);
void setup_bounds(bool_vector const& mdl);
void setup_bounds(sat::ddfw& bool_search, bool_vector const& mdl);
// user propagator
void check_for_user_propagator() {
@ -353,7 +353,7 @@ namespace euf {
void add_assumptions(sat::literal_set& assumptions) override;
bool tracking_assumptions() override;
std::string reason_unknown() override { return m_reason_unknown; }
void local_search(bool_vector& phase) override;
lbool local_search(bool_vector& phase) override;
void propagate(literal lit, ext_justification_idx idx);
bool propagate(enode* a, enode* b, ext_justification_idx idx);