3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-18 14:49:01 +00:00

use common infrastructure for sls-smt

This commit is contained in:
Nikolaj Bjorner 2024-10-25 23:29:26 -07:00
parent 894bfc7e17
commit 848bfb14a1
5 changed files with 126 additions and 407 deletions

View file

@ -39,9 +39,8 @@ namespace sls {
smt_plugin::~smt_plugin() {
SASSERT(!m_ddfw);
}
void smt_plugin::check(expr_ref_vector const& fmls) {
void smt_plugin::check(expr_ref_vector const& fmls, vector <sat::literal_vector> const& clauses) {
SASSERT(!m_ddfw);
// set up state for local search theory_sls here
m_result = l_undef;
@ -53,29 +52,37 @@ namespace sls {
m_ddfw->set_plugin(this);
m_ddfw->updt_params(ctx.get_params());
for (auto const& clause : clauses) {
m_ddfw->add(clause.size(), clause.data());
for (auto lit : clause)
add_shared_var(lit.var(), lit.var());
}
for (auto v : m_shared_bool_vars) {
expr* e = ctx.bool_var2expr(v);
if (!e)
continue;
m_context.register_atom(v, m_smt2sls_tr(e));
for (auto t : subterms::all(expr_ref(e, m)))
add_shared_term(e);
}
for (auto fml : fmls)
m_context.add_constraint(m_smt2sls_tr(fml));
// m_context.display(verbose_stream());
for (unsigned v = 0; v < ctx.get_num_bool_vars(); ++v) {
expr* e = ctx.bool_var2expr(v);
if (!e)
continue;
for (auto t : subterms::all(expr_ref(e, m)))
add_shared_term(e);
expr_ref sls_e(m_sls);
sls_e = m_smt2sls_tr(e);
auto w = m_context.atom2bool_var(sls_e);
if (w != sat::null_bool_var) {
m_smt_bool_var2sls_bool_var.setx(v, w, sat::null_bool_var);
m_sls_bool_var2smt_bool_var.setx(w, v, sat::null_bool_var);
m_sls_phase.reserve(v + 1);
m_sat_phase.reserve(v + 1);
m_rewards.reserve(v + 1);
m_shared_bool_vars.insert(v);
}
if (w == sat::null_bool_var)
continue;
add_shared_var(v, w);
for (auto t : subterms::all(expr_ref(e, m)))
add_shared_term(e);
}
m_thread = std::thread([this]() { run(); });
@ -139,6 +146,14 @@ namespace sls {
return false;
}
void smt_plugin::add_shared_var(sat::bool_var v, sat::bool_var w) {
m_smt_bool_var2sls_bool_var.setx(v, w, sat::null_bool_var);
m_sls_bool_var2smt_bool_var.setx(w, v, sat::null_bool_var);
m_sls_phase.reserve(v + 1);
m_sat_phase.reserve(v + 1);
m_rewards.reserve(v + 1);
m_shared_bool_vars.insert(v);
}
void smt_plugin::add_unit(sat::literal lit) {
if (!is_shared(lit))

View file

@ -80,8 +80,8 @@ namespace sls {
bool is_shared(sat::literal lit);
void run();
void add_shared_term(expr* t);
void add_uninterp(expr* smt_t);
void add_shared_var(sat::bool_var v, sat::bool_var w);
void import_phase_from_smt();
void import_values_from_sls();
@ -93,7 +93,6 @@ namespace sls {
void export_activity_to_smt();
void export_phase_to_smt();
void export_from_sls();
friend class sat::ddfw;
@ -103,7 +102,7 @@ namespace sls {
smt_plugin(smt_context& ctx);
// interface to calling solver:
void check(expr_ref_vector const& fmls);
void check(expr_ref_vector const& fmls, vector <sat::literal_vector> const& clauses);
void finalize(model_ref& md, ::statistics& st);
void updt_params(params_ref& p) {}
std::ostream& display(std::ostream& out) override;

View file

@ -23,399 +23,108 @@ Author:
namespace sls {
#ifdef SINGLE_THREAD
solver::solver(euf::solver& ctx) :
th_euf_solver(ctx, symbol("sls"), ctx.get_manager().mk_family_id("sls"))
{}
{}
#ifdef SINGLE_THREAD
#else
solver::solver(euf::solver& ctx):
th_euf_solver(ctx, symbol("sls"), ctx.get_manager().mk_family_id("sls"))
{}
solver::~solver() {
finalize();
}
class solver::smt_plugin : public sat::local_search_plugin, public sls::sat_solver_context {
ast_manager& m;
solver& s;
sat::ddfw* m_ddfw;
sls::context m_context;
bool m_new_clause_added = false;
unsigned m_min_unsat_size = UINT_MAX;
ast_manager m_sync_manager;
obj_map<expr, expr*> m_sls2sync_uninterp; // hashtable from sls-uninterp to sync uninterp
obj_map<expr, expr*> m_smt2sync_uninterp; // hashtable from external uninterp to sync uninterp
ast_translation m_smt2sync_tr, m_smt2sls_tr;
expr_ref_vector m_sync_uninterp;
expr_ref_vector m_sync_values;
std::atomic<bool> m_has_new_sls_values = false;
// export from SAT to SLS:
// - unit literals
// - phase
bool export_units_to_sls() {
if (!s.m_has_units)
return false;
std::lock_guard<std::mutex> lock(s.m_mutex);
IF_VERBOSE(2, verbose_stream() << "SMT -> SLS units " << s.m_units << "\n");
for (auto lit : s.m_units) {
if (m_shared_bool_vars.contains(lit.var())) {
IF_VERBOSE(10, verbose_stream() << "unit " << lit << "\n");
m_ddfw->add(1, &lit);
}
else {
IF_VERBOSE(0, verbose_stream() << "value restriction " << lit << " " << mk_bounded_pp(s.ctx.bool_var2expr(lit.var()), s.ctx.get_manager()) << "\n");
}
}
s.m_has_units = false;
s.m_units.reset();
return true;
}
bool export_phase_to_sls() {
if (!m_has_new_sat_phase)
return false;
std::lock_guard<std::mutex> lock(s.m_mutex);
IF_VERBOSE(3, verbose_stream() << "SMT -> SLS phase\n");
for (auto i : m_shared_bool_vars) {
if (m_sat_phase[i] != is_true(sat::literal(i, false)))
flip(i);
m_ddfw->bias(i) = m_sat_phase[i] ? 1 : -1;
}
m_has_new_sat_phase = false;
return true;
}
bool export_to_sls() {
bool updated = false;
if (export_units_to_sls())
updated = true;
if (export_phase_to_sls())
updated = true;
return updated;
}
// import from SLS:
// - phase
// - values
// - activity
void import_from_sls() {
if (unsat().size() > m_min_unsat_size)
return;
m_min_unsat_size = unsat().size();
std::lock_guard<std::mutex> lock(s.m_mutex);
for (auto v : m_shared_bool_vars) {
m_rewards[v] = m_ddfw->get_reward_avg(v);
m_sls_phase[v] = l_true == m_ddfw->get_model()[v];
m_has_new_sls_phase = true;
}
// import_values_from_sls();
}
void import_values_from_sls() {
IF_VERBOSE(3, verbose_stream() << "import values from sls\n");
std::lock_guard<std::mutex> lock(s.m_mutex);
ast_translation tr(m, m_sync_manager);
for (auto const& [t, t_sync] : m_sls2sync_uninterp) {
expr_ref val_t = m_context.get_value(t);
m_sync_values.set(t_sync->get_id(), tr(val_t.get()));
}
m_has_new_sls_values = true;
}
void add_uninterp(expr* smt_t) {
auto sync_t = m_smt2sync_tr(smt_t);
auto sls_t = m_smt2sls_tr(smt_t);
m_sync_uninterp.push_back(sync_t);
m_smt2sync_uninterp.insert(smt_t, sync_t);
m_sls2sync_uninterp.insert(sls_t, sync_t);
}
public:
smt_plugin(ast_manager& m, solver& s, sat::ddfw* d) :
m(m), s(s), m_ddfw(d), m_context(m, *this),
m_sync_uninterp(m_sync_manager),
m_sync_values(m_sync_manager),
m_smt2sync_tr(s.ctx.get_manager(), m_sync_manager),
m_smt2sls_tr(s.ctx.get_manager(), m)
{
}
uint_set m_shared_bool_vars, m_shared_terms;
svector<bool> m_sat_phase;
std::atomic<bool> m_has_new_sat_phase = false;
std::atomic<bool> m_has_new_sls_phase = false;
svector<bool> m_sls_phase;
svector<double> m_rewards;
void add_shared_term(expr* t) {
m_shared_terms.insert(t->get_id());
if (is_uninterp(t))
add_uninterp(t);
}
void add_shared_var(sat::bool_var v) {
m_sls_phase.reserve(v + 1);
m_sat_phase.reserve(v + 1);
m_rewards.reserve(v + 1);
m_shared_bool_vars.insert(v);
}
void init_search() override {}
void finish_search() override {}
void on_rescale() override {}
void on_restart() override {
if (export_to_sls())
m_ddfw->reinit();
}
void on_save_model() override {
TRACE("sls", display(tout));
while (unsat().empty()) {
m_context.check();
if (!m_new_clause_added)
break;
m_ddfw->reinit();
m_new_clause_added = false;
}
//import_from_sls();
}
void on_model(model_ref& mdl) override {
IF_VERBOSE(3, verbose_stream() << "on-model " << "\n");
s.m_sls_model = mdl;
}
void register_atom(sat::bool_var v, expr* e) {
m_context.register_atom(v, e);
}
std::ostream& display(std::ostream& out) {
m_ddfw->display(out);
m_context.display(out);
return out;
}
vector<sat::clause_info> const& clauses() const override { return m_ddfw->clauses(); }
sat::clause_info const& get_clause(unsigned idx) const override { return m_ddfw->get_clause_info(idx); }
ptr_iterator<unsigned> get_use_list(sat::literal lit) override { return m_ddfw->use_list(lit); }
void flip(sat::bool_var v) override {
m_ddfw->flip(v);
}
double reward(sat::bool_var v) override { return m_ddfw->get_reward(v); }
double get_weigth(unsigned clause_idx) override { return m_ddfw->get_clause_info(clause_idx).m_weight; }
bool is_true(sat::literal lit) override {
return m_ddfw->get_value(lit.var()) != lit.sign();
}
unsigned num_vars() const override { return m_ddfw->num_vars(); }
indexed_uint_set const& unsat() const override { return m_ddfw->unsat_set(); }
sat::bool_var add_var() override {
return m_ddfw->add_var();
}
void add_clause(unsigned n, sat::literal const* lits) override {
m_ddfw->add(n, lits);
m_new_clause_added = true;
}
void force_restart() override { m_ddfw->force_restart(); }
void export_values_to_smt() {
if (!m_has_new_sls_values)
return;
IF_VERBOSE(3, verbose_stream() << "SLS -> SMT values\n");
std::lock_guard<std::mutex> lock(s.m_mutex);
ast_translation tr(m_sync_manager, s.ctx.get_manager());
for (auto const& [t, t_sync] : m_smt2sync_uninterp) {
expr* sync_val = m_sync_values.get(t_sync->get_id(), nullptr);
if (sync_val)
s.ctx.user_propagate_initialize_value(t, tr(sync_val));
}
m_has_new_sls_values = false;
}
void export_phase_to_smt() {
if (!m_has_new_sls_phase)
return;
IF_VERBOSE(3, verbose_stream() << "new SLS -> SMT phase\n");
std::lock_guard<std::mutex> lock(s.m_mutex);
for (unsigned i = 0; i < m_sls_phase.size(); ++i)
s.s().set_phase(sat::literal(i, !m_sls_phase[i]));
m_has_new_sls_phase = false;
}
void import_phase_from_smt() {
if (m_has_new_sat_phase)
return;
m_has_new_sat_phase = true;
IF_VERBOSE(3, verbose_stream() << "new SMT -> SLS phase\n");
s.s().set_has_new_best_phase(false);
std::lock_guard<std::mutex> lock(s.m_mutex);
for (auto v : m_shared_bool_vars)
m_sat_phase[v] = s.s().get_best_phase(v);
}
void export_activity_to_smt() {
// TODO
}
// determine if unit literal restricts values of shared subterms.
bool is_value_restriction(sat::literal lit) {
auto e = s.ctx.bool_var2expr(lit.var());
expr* t = nullptr;
if (!e)
return false;
bv_util bv(s.ctx.get_manager());
if (bv.is_bit2bool(e, t) && m_shared_terms.contains(t->get_id())) {
verbose_stream() << "shared bit2bool " << mk_bounded_pp(e, s.ctx.get_manager()) << "\n";
return true;
}
// if arith.is_le(e, s, t) && t is a numeral, s is shared-term....
return false;
}
};
void solver::finalize() {
if (!m_completed && m_ddfw) {
m_ddfw->rlimit().cancel();
m_thread.join();
m_ddfw->collect_statistics(m_st);
m_ddfw = nullptr;
m_slsm = nullptr;
m_smt_plugin = nullptr;
m_units.reset();
}
params_ref solver::get_params() {
return s().params();
}
sat::check_result solver::check() {
return sat::check_result::CR_DONE;
void solver::initialize_value(expr* t, expr* v) {
ctx.user_propagate_initialize_value(t, v);
}
void solver::force_phase(sat::literal lit) {
ctx.s().set_phase(lit);
}
void solver::set_has_new_best_phase(bool b) {
}
bool solver::get_best_phase(sat::bool_var v) {
return false;
}
expr* solver::bool_var2expr(sat::bool_var v) {
return ctx.bool_var2expr(v);
}
void solver::set_finished() {
m.limit().cancel();
}
unsigned solver::get_num_bool_vars() const {
return s().num_vars();
}
void solver::finalize() {
if (!m_smt_plugin)
return;
m_smt_plugin->finalize(m_model, m_st);
m_model = nullptr;
m_smt_plugin = nullptr;
}
bool solver::unit_propagate() {
force_push();
sample_local_search();
return false;
if (m_smt_plugin && !m_checking) {
expr_ref_vector fmls(m);
m_checking = true;
m_smt_plugin->check(fmls, ctx.top_level_clauses());
return true;
}
if (!m_smt_plugin)
return false;
if (!m_smt_plugin->completed())
return false;
m_smt_plugin->finalize(m_model, m_st);
m_smt_plugin = nullptr;
return true;
}
void solver::pop_core(unsigned n) {
if (!m_smt_plugin)
return;
for (; m_trail_lim < s().init_trail_size(); ++m_trail_lim) {
auto lit = s().trail_literal(m_trail_lim);
if (m_smt_plugin->is_value_restriction(lit) ||
m_smt_plugin->m_shared_bool_vars.contains(lit.var())) {
std::lock_guard<std::mutex> lock(m_mutex);
m_units.push_back(lit);
m_has_units = true;
}
}
if (s().has_new_best_phase())
m_smt_plugin->import_phase_from_smt();
m_smt_plugin->export_phase_to_smt();
m_smt_plugin->export_activity_to_smt();
m_smt_plugin->export_values_to_smt();
}
unsigned scope_lvl = s().scope_lvl();
if (s().search_lvl() == scope_lvl - n) {
for (; m_trail_lim < s().init_trail_size(); ++m_trail_lim) {
auto lit = s().trail_literal(m_trail_lim);
m_smt_plugin->add_unit(lit);
}
}
#if 0
if (ctx.has_new_best_phase())
m_smt_plugin->import_phase_from_smt();
#endif
m_smt_plugin->import_from_sls();
}
void solver::init_search() {
if (m_ddfw) {
m_ddfw->rlimit().cancel();
m_thread.join();
}
// set up state for local search solver here
m_result = l_undef;
m_completed = false;
m_slsm = alloc(ast_manager);
m_units.reset();
m_has_units = false;
m_model = nullptr;
m_sls_model = nullptr;
m_ddfw = alloc(sat::ddfw);
ast_translation tr(m, *m_slsm);
scoped_limits scoped_limits(m.limit());
scoped_limits.push_child(&m_slsm->limit());
scoped_limits.push_child(&m_ddfw->rlimit());
m_smt_plugin = alloc(smt_plugin, *m_slsm, *this, m_ddfw.get());
m_ddfw->set_plugin(m_smt_plugin);
m_ddfw->updt_params(s().params());
for (auto const& clause : ctx.top_level_clauses()) {
m_ddfw->add(clause.size(), clause.data());
for (auto lit : clause)
m_smt_plugin->add_shared_var(lit.var());
}
for (auto v : m_smt_plugin->m_shared_bool_vars) {
expr* e = ctx.bool_var2expr(v);
if (!e)
continue;
m_smt_plugin->register_atom(v, tr(e));
for (auto t : subterms::all(expr_ref(e, m)))
m_smt_plugin->add_shared_term(e);
}
m_thread = std::thread([this]() { run_local_search_async(); });
}
void solver::sample_local_search() {
if (!m_completed)
return;
m_thread.join();
local_search_done();
}
void solver::local_search_done() {
IF_VERBOSE(1, verbose_stream() << "local-search-done\n");
m_completed = false;
CTRACE("sls", m_smt_plugin, m_smt_plugin->display(tout));
if (m_ddfw)
m_ddfw->collect_statistics(m_st);
TRACE("sls", tout << "result " << m_result << "\n");
if (m_result == l_true && m_sls_model) {
ast_translation tr(*m_slsm, m);
m_model = m_sls_model->translate(tr);
TRACE("sls", tout << "model: " << *m_sls_model << "\n";);
s().set_canceled();
}
m_ddfw = nullptr;
m_smt_plugin = nullptr;
m_sls_model = nullptr;
}
void solver::run_local_search_async() {
if (m_ddfw) {
m_result = m_ddfw->check(0, nullptr);
IF_VERBOSE(1, verbose_stream() << "sls-result " << m_result << "\n");
m_completed = true;
}
}
void solver::run_local_search_sync() {
m_result = m_ddfw->check(0, nullptr);
local_search_done();
if (m_smt_plugin)
finalize();
m_smt_plugin = alloc(sls::smt_plugin, *this);
m_checking = false;
}
std::ostream& solver::display(std::ostream& out) const {
out << "sls-solver\n";
return out;
return out << "theory-sls\n";
}
#endif
}

View file

@ -54,6 +54,7 @@ namespace sls {
#include <thread>
#include <mutex>
#include "ast/sls/sls_smt_plugin.h"
namespace euf {
class solver;
@ -61,29 +62,12 @@ namespace euf {
namespace sls {
class solver : public euf::th_euf_solver {
class smt_plugin;
std::atomic<lbool> m_result;
std::atomic<bool> m_completed, m_has_units;
std::thread m_thread;
std::mutex m_mutex;
// m is accessed by the main thread
// m_slsm is accessed by the sls thread
scoped_ptr<ast_manager> m_slsm;
scoped_ptr<sat::ddfw> m_ddfw;
sat::literal_vector m_units;
smt_plugin* m_smt_plugin = nullptr;
model_ref m_model, m_sls_model;
class solver : public euf::th_euf_solver, public sls::smt_context {
model_ref m_model;
sls::smt_plugin* m_smt_plugin = nullptr;
unsigned m_trail_lim = 0;
statistics m_st;
void run_local_search_async();
void run_local_search_sync();
void sample_local_search();
void local_search_done();
bool m_checking = false;
::statistics m_st;
public:
solver(euf::solver& ctx);
@ -102,10 +86,21 @@ namespace sls {
sat::literal internalize(expr* e, bool sign, bool root) override { UNREACHABLE(); return sat::null_literal; }
void internalize(expr* e) override { UNREACHABLE(); }
void get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector & r, bool probing) override { UNREACHABLE(); }
sat::check_result check() override;
sat::check_result check() override { return sat::check_result::CR_DONE; }
std::ostream& display(std::ostream& out) const override;
std::ostream & display_justification(std::ostream & out, sat::ext_justification_idx idx) const override { UNREACHABLE(); return out; }
std::ostream & display_constraint(std::ostream & out, sat::ext_constraint_idx idx) const override { UNREACHABLE(); return out; }
ast_manager& get_manager() override { return m; }
params_ref get_params() override;
void initialize_value(expr* t, expr* v) override;
void force_phase(sat::literal lit) override;
void set_has_new_best_phase(bool b) override;
bool get_best_phase(sat::bool_var v) override;
expr* bool_var2expr(sat::bool_var v) override;
void set_finished() override;
unsigned get_num_bool_vars() const override;
};

View file

@ -79,7 +79,8 @@ namespace smt {
for (unsigned i = 0; i < ctx.get_num_asserted_formulas(); ++i)
fmls.push_back(ctx.get_asserted_formula(i));
m_checking = true;
m_smt_plugin->check(fmls);
vector<sat::literal_vector> clauses;
m_smt_plugin->check(fmls, clauses);
return;
}
if (!m_smt_plugin)