3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-14 04:48:45 +00:00

wip - local search - move to plugin model

This commit is contained in:
Nikolaj Bjorner 2023-02-15 13:32:30 -08:00
parent a1f73d3805
commit c1ecc49021
5 changed files with 225 additions and 59 deletions

View file

@ -27,6 +27,10 @@
#include "sat/sat_clause.h"
#include "sat/sat_types.h"
namespace arith {
class sls;
}
namespace sat {
class solver;
class parallel;
@ -44,6 +48,7 @@ namespace sat {
};
class ddfw : public i_local_search {
friend class arith::sls;
public:
struct clause_info {
clause_info(clause* cl, double init_weight): m_weight(init_weight), m_clause(cl) {}
@ -126,7 +131,7 @@ namespace sat {
stopwatch m_stopwatch;
parallel* m_par;
scoped_ptr< local_search_plugin> m_plugin;
local_search_plugin* m_plugin = nullptr;
void flatten_use_list();
@ -148,7 +153,7 @@ namespace sat {
inline double reward(bool_var v) const { return m_vars[v].m_reward; }
inline double plugin_reward(bool_var v) const { return m_plugin->reward(v); }
inline double plugin_reward(bool_var v) const { return is_external(v) ? m_plugin->reward(v) : reward(v); }
void set_external(bool_var v) { m_vars[v].m_external = true; }

View file

@ -0,0 +1,90 @@
/*++
Copyright (c) 2006 Microsoft Corporation
Module Name:
sat_smt_setup.h
Author:
Nikolaj Bjorner (nbjorner) 2023-01-17
--*/
#pragma once
#include "ast/ast.h"
#include "smt/params/smt_params.h"
#include "sat/sat_config.h"
#include "ast/simplifiers/dependent_expr_state.h"
struct static_features;
namespace sat_smt {
void setup_sat_config(smt_params const& p, sat::config& config);
class setup {
ast_manager& m;
dependent_expr_state& m_st;
smt_params& m_params;
symbol m_logic;
bool m_already_configured = false;
void setup_auto_config();
void setup_default();
//
// setup_<logic>() methods do not depend on static features of the formula. So, they are safe to use
// even in an incremental setting.
//
// setup_<logic>(static_features & st) can only be used if the logical context will perform a single
// check.
//
void setup_QF_DT();
void setup_QF_UF();
void setup_QF_UF(static_features const & st);
void setup_QF_RDL();
void setup_QF_RDL(static_features & st);
void setup_QF_IDL();
void setup_QF_IDL(static_features & st);
void setup_QF_UFIDL();
void setup_QF_UFIDL(static_features & st);
void setup_QF_LRA();
void setup_QF_LRA(static_features const & st);
void setup_QF_LIA();
void setup_QF_LIRA(static_features const& st);
void setup_QF_LIA(static_features const & st);
void setup_QF_UFLIA();
void setup_QF_UFLIA(static_features & st);
void setup_QF_UFLRA();
void setup_QF_BV();
void setup_QF_AUFBV();
void setup_QF_AX();
void setup_QF_AX(static_features const & st);
void setup_QF_AUFLIA();
void setup_QF_AUFLIA(static_features const & st);
void setup_QF_FP();
void setup_QF_FPBV();
void setup_QF_S();
void setup_LRA();
void setup_CSP();
void setup_AUFLIA(bool simple_array = true);
void setup_AUFLIA(static_features const & st);
void setup_AUFLIRA(bool simple_array = true);
void setup_UFNIA();
void setup_UFLRA();
void setup_AUFLIAp();
void setup_AUFNIRA();
void setup_QF_BVRE();
void setup_unknown();
void setup_unknown(static_features & st);
public:
setup(ast_manager& m, dependent_expr_state& st, smt_params & params);
void setk_already_configured() { m_already_configured = true; }
bool already_configured() const { return m_already_configured; }
symbol const & get_logic() const { return m_logic; }
void operator()();
};
};

View file

@ -1,5 +1,5 @@
/*++
Copyright (c) 2020 Microsoft Corporation
Copyright (c) 2023 Microsoft Corporation
Module Name:
@ -112,6 +112,8 @@ namespace arith {
for (unsigned v = 0; v < s.s().num_vars(); ++v)
init_bool_var_assignment(v);
m_best_min_unsat = std::numeric_limits<unsigned>::max();
d->set(this);
}
void sls::set_bounds_begin() {
@ -209,14 +211,13 @@ namespace arith {
unsigned start = s.random();
unsigned sz = unsat().size();
for (unsigned i = sz; i-- > 0; )
if (flip(unsat().elem_at((i + start) % sz)))
if (flip_clause(unsat().elem_at((i + start) % sz)))
return true;
return false;
}
bool sls::flip(unsigned cl) {
bool sls::flip_clause(unsigned cl) {
auto const& clause = get_clause(cl);
int64_t new_value;
for (literal lit : clause) {
if (is_true(lit))
continue;
@ -224,20 +225,32 @@ namespace arith {
if (!ineq)
continue;
SASSERT(!ineq->is_true());
for (auto const& [coeff, v] : ineq->m_args) {
if (!cm(*ineq, v, new_value))
continue;
int score = cm_score(v, new_value);
if (score <= 0)
continue;
unsigned num_unsat = unsat().size();
update(v, new_value);
IF_VERBOSE(2,
verbose_stream() << "v" << v << " score " << score << " "
<< num_unsat << " -> " << unsat().size() << "\n");
SASSERT(num_unsat > unsat().size());
if (flip(*ineq))
return true;
}
}
return false;
}
// flip on the first positive score
// it could be changed to flip on maximal positive score
// or flip on maximal non-negative score
// or flip on first non-negative score
bool sls::flip(ineq const& ineq) {
int64_t new_value;
for (auto const& [coeff, v] : ineq.m_args) {
if (!cm(ineq, v, new_value))
continue;
int score = cm_score(v, new_value);
if (score <= 0)
continue;
unsigned num_unsat = unsat().size();
update(v, new_value);
IF_VERBOSE(2,
verbose_stream() << "v" << v << " score " << score << " "
<< num_unsat << " -> " << unsat().size() << "\n");
SASSERT(num_unsat > unsat().size());
return true;
}
return false;
}
@ -246,7 +259,7 @@ namespace arith {
unsigned start = s.random();
unsigned sz = m_bool_search->num_clauses();
for (unsigned i = sz; i-- > 0; )
if (flip((i + start) % sz))
if (flip_clause((i + start) % sz))
return true;
return false;
}
@ -541,9 +554,85 @@ namespace arith {
void sls::init_literal_assignment(sat::literal lit) {
auto* ineq = m_literals.get(lit.index(), nullptr);
if (ineq && is_true(lit) != (dtt(*ineq) == 0))
m_bool_search->flip(lit.var());
}
void sls::init_search() {
on_restart();
}
void sls::finish_search() {
store_best_values();
}
void sls::flip(sat::bool_var v) {
sat::literal lit(v, m_bool_search->get_value(v));
SASSERT(!is_true(lit));
auto const* ineq = atom(lit);
if (!ineq)
IF_VERBOSE(0, verbose_stream() << "no inequality for variable " << v << "\n");
if (!ineq)
return;
IF_VERBOSE(1, verbose_stream() << "flip " << lit << "\n");
SASSERT(!ineq->is_true());
flip(*ineq);
}
double sls::reward(sat::bool_var v) {
if (m_dscore_mode)
return dscore_reward(v);
else
return dtt_reward(v);
}
double sls::dtt_reward(sat::bool_var v) {
sat::literal litv(v, m_bool_search->get_value(v));
auto const* ineq = atom(litv);
if (!ineq)
return 0;
int64_t new_value;
double result = 0;
for (auto const & [coeff, x] : ineq->m_args) {
if (!cm(*ineq, x, new_value))
continue;
for (auto const [coeff, lit] : m_vars[x].m_literals) {
auto dtt_old = dtt(*atom(lit));
auto dtt_new = dtt(*atom(lit), x, new_value);
if ((dtt_new == 0) != (dtt_old == 0))
result += m_bool_search->reward(lit.var());
}
}
return result;
}
double sls::dscore_reward(sat::bool_var x) {
m_dscore_mode = false;
sat::literal litv(x, m_bool_search->get_value(x));
auto const* ineq = atom(litv);
if (!ineq)
return 0;
SASSERT(!ineq->is_true());
int64_t new_value;
double result = 0;
for (auto const& [coeff, v] : ineq->m_args)
if (cm(*ineq, v, new_value))
result += dscore(v, new_value);
return result;
}
// switch to dscore mode
void sls::on_rescale() {
m_dscore_mode = true;
}
void sls::on_save_model() {
save_best_values();
}
void sls::on_restart() {
for (unsigned v = 0; v < s.s().num_vars(); ++v)
init_bool_var_assignment(v);
}
}

View file

@ -37,7 +37,7 @@ namespace arith {
class solver;
// local search portion for arithmetic
class sls {
class sls : public sat::local_search_plugin {
enum class ineq_kind { EQ, LE, LT, NE };
enum class var_kind { INT, REAL };
typedef unsigned var_t;
@ -78,7 +78,7 @@ namespace arith {
std::ostream& display(std::ostream& out) const {
bool first = true;
for (auto const& [c, v] : m_args)
out << (first? "": " + ") << c << " * v" << v, first = false;
out << (first ? "" : " + ") << c << " * v" << v, first = false;
switch (m_op) {
case ineq_kind::LE:
return out << " <= " << m_bound << "(" << m_args_value << ")";
@ -97,7 +97,7 @@ namespace arith {
int64_t m_value;
int64_t m_best_value;
var_kind m_kind = var_kind::INT;
vector<std::pair<int64_t, sat::literal>> m_literals;
svector<std::pair<int64_t, sat::literal>> m_literals;
};
struct clause {
@ -116,6 +116,7 @@ namespace arith {
vector<var_info> m_vars;
vector<clause> m_clauses;
svector<std::pair<lp::tv, euf::theory_var>> m_terms;
bool m_dscore_mode = false;
indexed_uint_set& unsat() { return m_bool_search->unsat_set(); }
@ -136,7 +137,8 @@ namespace arith {
bool flip_clauses();
bool flip_dscore();
bool flip_dscore(unsigned cl);
bool flip(unsigned cl);
bool flip_clause(unsigned cl);
bool flip(ineq const& ineq);
int64_t dtt(ineq const& ineq) const { return dtt(ineq.m_args_value, ineq); }
int64_t dtt(int64_t args, ineq const& ineq) const;
int64_t dtt(ineq const& ineq, var_t v, int64_t new_value) const;
@ -145,6 +147,8 @@ namespace arith {
bool cm(ineq const& ineq, var_t v, int64_t& new_value);
int cm_score(var_t v, int64_t new_value);
void update(var_t v, int64_t new_value);
double dscore_reward(sat::bool_var v);
double dtt_reward(sat::bool_var v);
void paws();
int64_t dscore(var_t v, int64_t new_value) const;
void save_best_values();
@ -163,11 +167,20 @@ namespace arith {
public:
sls(solver& s);
~sls() override {}
lbool operator ()(bool_vector& phase);
void set_bounds_begin();
void set_bounds_end(unsigned num_literals);
void set_bounds(euf::enode* n);
void set(sat::ddfw* d);
void init_search() override;
void finish_search() override;
void flip(sat::bool_var v) override;
double reward(sat::bool_var v) override;
void on_rescale() override;
void on_save_model() override;
void on_restart() override;
};
inline std::ostream& operator<<(std::ostream& out, sls::ineq const& ineq) {

View file

@ -29,44 +29,13 @@ namespace euf {
bool_search.set_seed(rand());
scoped_rl.push_child(&(bool_search.rlimit()));
unsigned max_rounds = 30;
for (auto* th : m_solvers)
th->set_bool_search(&bool_search);
for (unsigned rounds = 0; m.inc() && rounds < max_rounds; ++rounds) {
bool_search.rlimit().push(m_max_bool_steps);
lbool r = bool_search.check(0, nullptr, nullptr);
bool_search.rlimit().pop();
setup_bounds(bool_search, phase);
// Non-boolean literals are assumptions to Boolean search
literal_vector assumptions;
#if 0
for (unsigned v = 0; v < phase.size(); ++v)
if (!is_propositional(literal(v)))
assumptions.push_back(literal(v, !bool_search.get_value(v)));
#endif
verbose_stream() << "assumptions " << assumptions.size() << "\n";
bool_search.rlimit().push(m_max_bool_steps);
lbool r = bool_search.check(assumptions.size(), assumptions.data(), nullptr);
bool_search.rlimit().pop();
#if 0
// restore state to optimal model
auto const& mdl = bool_search.get_model();
for (unsigned i = 0; i < mdl.size(); ++i)
if ((mdl[i] == l_true) != bool_search.get_value(i))
bool_search.flip(i);
#endif
for (auto* th : m_solvers)
th->local_search(phase);
if (bool_search.unsat_set().empty())
break;
}
auto const& mdl = bool_search.get_model();
for (unsigned i = 0; i < mdl.size(); ++i)
phase[i] = mdl[i] == l_true;