mirror of
https://github.com/Z3Prover/z3
synced 2025-04-22 16:45:31 +00:00
add plugin to smt_context, factor out sls_smt_plugin functionality.
This commit is contained in:
parent
f453cdec92
commit
ef95b4eaf2
14 changed files with 744 additions and 31 deletions
|
@ -15,7 +15,8 @@ z3_add_component(ast_sls
|
|||
sls_context.cpp
|
||||
sls_datatype_plugin.cpp
|
||||
sls_euf_plugin.cpp
|
||||
sls_smt_solver.cpp
|
||||
sls_smt_plugin.cpp
|
||||
sls_smt_solver.cpp
|
||||
COMPONENT_DEPENDENCIES
|
||||
ast
|
||||
euf
|
||||
|
|
|
@ -321,6 +321,8 @@ namespace sls {
|
|||
expr_ref _e(f, m);
|
||||
expr* g, * h, * k;
|
||||
sat::literal_vector clause;
|
||||
if (m.is_true(f))
|
||||
return;
|
||||
if (m.is_not(f, g) && m.is_not(g, g)) {
|
||||
add_clause(g);
|
||||
return;
|
||||
|
@ -486,7 +488,7 @@ namespace sls {
|
|||
for (sat::literal lit : m_unit_literals)
|
||||
m_unit_indices.insert(lit.index());
|
||||
|
||||
verbose_stream() << "UNITS " << m_unit_literals << "\n";
|
||||
IF_VERBOSE(0, verbose_stream() << "UNITS " << m_unit_literals << "\n");
|
||||
for (unsigned i = 0; i < m_atoms.size(); ++i)
|
||||
if (m_atoms.get(i))
|
||||
register_terms(m_atoms.get(i));
|
||||
|
|
263
src/ast/sls/sls_smt_plugin.cpp
Normal file
263
src/ast/sls/sls_smt_plugin.cpp
Normal file
|
@ -0,0 +1,263 @@
|
|||
/*++
|
||||
Copyright (c) 2024 Microsoft Corporation
|
||||
|
||||
Module Name:
|
||||
|
||||
sls_smt_plugin.cpp
|
||||
|
||||
Abstract:
|
||||
|
||||
A Stochastic Local Search (SLS) Plugin.
|
||||
|
||||
Author:
|
||||
|
||||
Nikolaj Bjorner (nbjorner) 2024-07-10
|
||||
|
||||
--*/
|
||||
|
||||
|
||||
#include "ast/sls/sls_smt_plugin.h"
|
||||
#include "ast/for_each_expr.h"
|
||||
#include "ast/bv_decl_plugin.h"
|
||||
|
||||
namespace sls {
|
||||
|
||||
smt_plugin::smt_plugin(smt_context& ctx) :
|
||||
ctx(ctx),
|
||||
m(ctx.get_manager()),
|
||||
m_sls(),
|
||||
m_sync(),
|
||||
m_smt2sync_tr(m, m_sync),
|
||||
m_smt2sls_tr(m, m_sls),
|
||||
m_sync_uninterp(m_sync),
|
||||
m_sls_uninterp(m_sls),
|
||||
m_sync_values(m_sync),
|
||||
m_context(m_sls, *this)
|
||||
{
|
||||
}
|
||||
|
||||
smt_plugin::~smt_plugin() {
|
||||
SASSERT(!m_ddfw);
|
||||
}
|
||||
|
||||
|
||||
void smt_plugin::check(expr_ref_vector const& fmls) {
|
||||
SASSERT(!m_ddfw);
|
||||
// set up state for local search theory_sls here
|
||||
m_result = l_undef;
|
||||
m_completed = false;
|
||||
m_units.reset();
|
||||
m_has_units = false;
|
||||
m_model = nullptr;
|
||||
m_sls_model = nullptr;
|
||||
m_ddfw = alloc(sat::ddfw);
|
||||
m_ddfw->set_plugin(this);
|
||||
m_ddfw->updt_params(ctx.get_params());
|
||||
|
||||
for (auto fml : fmls)
|
||||
m_context.add_constraint(m_smt2sls_tr(fml));
|
||||
|
||||
// m_context.display(verbose_stream());
|
||||
|
||||
for (unsigned v = 0; v < ctx.get_num_bool_vars(); ++v) {
|
||||
expr* e = ctx.bool_var2expr(v);
|
||||
if (!e)
|
||||
continue;
|
||||
for (auto t : subterms::all(expr_ref(e, m)))
|
||||
add_shared_term(e);
|
||||
|
||||
expr_ref sls_e(m_sls);
|
||||
sls_e = m_smt2sls_tr(e);
|
||||
auto w = m_context.atom2bool_var(sls_e);
|
||||
if (w != sat::null_bool_var) {
|
||||
verbose_stream() << mk_bounded_pp(e, m) << ": " << v << " -> " << w << "\n";
|
||||
m_smt_bool_var2sls_bool_var.setx(v, w, sat::null_bool_var);
|
||||
m_sls_bool_var2smt_bool_var.setx(w, v, sat::null_bool_var);
|
||||
|
||||
add_shared_var(w);
|
||||
}
|
||||
}
|
||||
|
||||
m_thread = std::thread([this]() { run(); });
|
||||
}
|
||||
|
||||
void smt_plugin::run() {
|
||||
if (!m_ddfw)
|
||||
return;
|
||||
m_result = m_ddfw->check(0, nullptr);
|
||||
IF_VERBOSE(1, verbose_stream() << "sls-result " << m_result << "\n");
|
||||
m_completed = true;
|
||||
}
|
||||
|
||||
void smt_plugin::finalize(model_ref& mdl, ::statistics& st) {
|
||||
auto* d = m_ddfw;
|
||||
if (!d)
|
||||
return;
|
||||
bool canceled = !m_completed;
|
||||
IF_VERBOSE(0, verbose_stream() << "finalize\n");
|
||||
mdl = m_model;
|
||||
if (!m_completed) {
|
||||
d->rlimit().cancel();
|
||||
if (m_thread.joinable())
|
||||
m_thread.join();
|
||||
}
|
||||
if (m_result == l_true && m_sls_model) {
|
||||
ast_translation tr(m_sls, m);
|
||||
m_model = m_sls_model->translate(tr);
|
||||
TRACE("sls", tout << "model: " << *m_sls_model << "\n";);
|
||||
if (!canceled)
|
||||
ctx.set_finished();
|
||||
}
|
||||
m_ddfw = nullptr;
|
||||
// m_ddfw owns the pointer to smt_plugin and destructs it.
|
||||
dealloc(d);
|
||||
}
|
||||
|
||||
void smt_plugin::collect_statistics(statistics& st) {
|
||||
|
||||
}
|
||||
std::ostream& smt_plugin::display(std::ostream& out) {
|
||||
m_ddfw->display(out);
|
||||
m_context.display(out);
|
||||
return out;
|
||||
}
|
||||
|
||||
bool smt_plugin::is_shared(sat::literal lit) {
|
||||
auto w = m_smt_bool_var2sls_bool_var.get(lit.var(), sat::null_bool_var);
|
||||
if (w != sat::null_bool_var)
|
||||
return true;
|
||||
auto e = ctx.bool_var2expr(lit.var());
|
||||
expr* t = nullptr;
|
||||
if (!e)
|
||||
return false;
|
||||
bv_util bv(m);
|
||||
if (bv.is_bit2bool(e, t) && m_shared_terms.contains(t->get_id())) {
|
||||
verbose_stream() << "shared bit2bool " << mk_bounded_pp(e, ctx.get_manager()) << "\n";
|
||||
return true;
|
||||
}
|
||||
|
||||
// if arith.is_le(e, s, t) && t is a numeral, s is shared-term....
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
void smt_plugin::add_unit(sat::literal lit) {
|
||||
if (!is_shared(lit))
|
||||
return;
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
m_units.push_back(lit);
|
||||
m_has_units = true;
|
||||
}
|
||||
|
||||
void smt_plugin::import_phase_from_smt() {
|
||||
if (m_has_new_sat_phase)
|
||||
return;
|
||||
m_has_new_sat_phase = true;
|
||||
IF_VERBOSE(3, verbose_stream() << "new SMT -> SLS phase\n");
|
||||
ctx.set_has_new_best_phase(false);
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
for (auto v : m_shared_bool_vars)
|
||||
m_sat_phase[v] = ctx.get_best_phase(v);
|
||||
}
|
||||
|
||||
bool smt_plugin::export_to_sls() {
|
||||
bool updated = false;
|
||||
if (export_units_to_sls())
|
||||
updated = true;
|
||||
if (export_phase_to_sls())
|
||||
updated = true;
|
||||
return updated;
|
||||
}
|
||||
|
||||
bool smt_plugin::export_phase_to_sls() {
|
||||
if (!m_has_new_sat_phase)
|
||||
return false;
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
IF_VERBOSE(3, verbose_stream() << "SMT -> SLS phase\n");
|
||||
for (auto i : m_shared_bool_vars) {
|
||||
auto v = m_smt_bool_var2sls_bool_var[i];
|
||||
if (m_sat_phase[v] != is_true(sat::literal(v, false)))
|
||||
flip(v);
|
||||
m_ddfw->bias(v) = m_sat_phase[v] ? 1 : -1;
|
||||
}
|
||||
m_has_new_sat_phase = false;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool smt_plugin::export_units_to_sls() {
|
||||
if (!m_has_units)
|
||||
return false;
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
IF_VERBOSE(2, verbose_stream() << "SMT -> SLS units " << m_units << "\n");
|
||||
for (auto lit : m_units) {
|
||||
if (m_shared_bool_vars.contains(lit.var())) {
|
||||
sat::literal sls_lit(m_smt_bool_var2sls_bool_var[lit.var()], false);
|
||||
IF_VERBOSE(10, verbose_stream() << "unit " << sls_lit << "\n");
|
||||
m_ddfw->add(1, &sls_lit);
|
||||
}
|
||||
else {
|
||||
IF_VERBOSE(0, verbose_stream() << "value restriction " << lit << " "
|
||||
<< mk_bounded_pp(ctx.bool_var2expr(lit.var()), m) << "\n");
|
||||
}
|
||||
}
|
||||
m_has_units = false;
|
||||
m_units.reset();
|
||||
return true;
|
||||
}
|
||||
|
||||
void smt_plugin::import_from_sls() {
|
||||
if (unsat().size() > m_min_unsat_size)
|
||||
return;
|
||||
m_min_unsat_size = unsat().size();
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
for (auto v : m_shared_bool_vars) {
|
||||
m_rewards[v] = m_ddfw->get_reward_avg(v);
|
||||
m_sls_phase[v] = l_true == m_ddfw->get_model()[v];
|
||||
m_has_new_sls_phase = true;
|
||||
}
|
||||
// import_values_from_sls();
|
||||
}
|
||||
|
||||
void smt_plugin::import_values_from_sls() {
|
||||
IF_VERBOSE(3, verbose_stream() << "import values from sls\n");
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
for (auto const& [t, t_sync] : m_sls2sync_uninterp) {
|
||||
expr_ref val_t = m_context.get_value(t);
|
||||
m_sync_values.set(t_sync->get_id(), m_smt2sync_tr(val_t.get()));
|
||||
}
|
||||
m_has_new_sls_values = true;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void smt_plugin::export_activity_to_smt() {
|
||||
|
||||
}
|
||||
|
||||
void smt_plugin::export_values_to_smt() {
|
||||
|
||||
}
|
||||
|
||||
void smt_plugin::add_shared_term(expr* t) {
|
||||
m_shared_terms.insert(t->get_id());
|
||||
if (is_uninterp(t))
|
||||
add_uninterp(t);
|
||||
}
|
||||
|
||||
void smt_plugin::add_uninterp(expr* smt_t) {
|
||||
auto sync_t = m_smt2sync_tr(smt_t);
|
||||
auto sls_t = m_smt2sls_tr(smt_t);
|
||||
m_sync_uninterp.push_back(sync_t);
|
||||
m_sls_uninterp.push_back(sls_t);
|
||||
m_smt2sync_uninterp.insert(smt_t, sync_t);
|
||||
m_sls2sync_uninterp.insert(sls_t, sync_t);
|
||||
}
|
||||
|
||||
void smt_plugin::add_shared_var(sat::bool_var v) {
|
||||
m_sls_phase.reserve(v + 1);
|
||||
m_sat_phase.reserve(v + 1);
|
||||
m_rewards.reserve(v + 1);
|
||||
m_shared_bool_vars.insert(v);
|
||||
}
|
||||
|
||||
}
|
168
src/ast/sls/sls_smt_plugin.h
Normal file
168
src/ast/sls/sls_smt_plugin.h
Normal file
|
@ -0,0 +1,168 @@
|
|||
/*++
|
||||
Copyright (c) 2024 Microsoft Corporation
|
||||
|
||||
Module Name:
|
||||
|
||||
sls_smt_plugin.h
|
||||
|
||||
Abstract:
|
||||
|
||||
A Stochastic Local Search (SLS) Plugin.
|
||||
|
||||
Author:
|
||||
|
||||
Nikolaj Bjorner (nbjorner) 2024-07-10
|
||||
|
||||
--*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ast/sls/sls_context.h"
|
||||
#include "ast/sls/sat_ddfw.h"
|
||||
#include "util/statistics.h"
|
||||
#include <thread>
|
||||
#include <mutex>
|
||||
|
||||
namespace sls {
|
||||
|
||||
class smt_context {
|
||||
public:
|
||||
virtual ast_manager& get_manager() = 0;
|
||||
virtual params_ref get_params() = 0;
|
||||
virtual void initialize_value(expr* t, expr* v) = 0;
|
||||
virtual void force_phase(sat::literal lit) = 0;
|
||||
virtual void set_has_new_best_phase(bool b) = 0;
|
||||
virtual bool get_best_phase(sat::bool_var v) = 0;
|
||||
virtual expr* bool_var2expr(sat::bool_var v) = 0;
|
||||
virtual void set_finished() = 0;
|
||||
virtual unsigned get_num_bool_vars() const = 0;
|
||||
};
|
||||
|
||||
class smt_plugin : public sat::local_search_plugin, public sat_solver_context {
|
||||
smt_context& ctx;
|
||||
ast_manager& m;
|
||||
ast_manager m_sls;
|
||||
ast_manager m_sync;
|
||||
ast_translation m_smt2sync_tr, m_smt2sls_tr;
|
||||
expr_ref_vector m_sync_uninterp, m_sls_uninterp;
|
||||
expr_ref_vector m_sync_values;
|
||||
sat::ddfw* m_ddfw = nullptr;
|
||||
sls::context m_context;
|
||||
std::atomic<lbool> m_result;
|
||||
std::atomic<bool> m_completed, m_has_units;
|
||||
std::thread m_thread;
|
||||
std::mutex m_mutex;
|
||||
// m is accessed by the main thread
|
||||
// m_slsm is accessed by the sls thread
|
||||
sat::literal_vector m_units;
|
||||
model_ref m_model, m_sls_model;
|
||||
unsigned m_trail_lim = 0;
|
||||
::statistics m_st;
|
||||
bool m_new_clause_added = false;
|
||||
unsigned m_min_unsat_size = UINT_MAX;
|
||||
obj_map<expr, expr*> m_sls2sync_uninterp; // hashtable from sls-uninterp to sync uninterp
|
||||
obj_map<expr, expr*> m_smt2sync_uninterp; // hashtable from external uninterp to sync uninterp
|
||||
std::atomic<bool> m_has_new_sls_values = false;
|
||||
|
||||
uint_set m_shared_bool_vars, m_shared_terms;
|
||||
svector<bool> m_sat_phase;
|
||||
std::atomic<bool> m_has_new_sat_phase = false;
|
||||
|
||||
std::atomic<bool> m_has_new_sls_phase = false;
|
||||
svector<bool> m_sls_phase;
|
||||
|
||||
svector<double> m_rewards;
|
||||
svector<sat::bool_var> m_smt_bool_var2sls_bool_var, m_sls_bool_var2smt_bool_var;
|
||||
|
||||
|
||||
|
||||
bool is_shared(sat::literal lit);
|
||||
void run();
|
||||
void add_shared_term(expr* t);
|
||||
|
||||
void add_uninterp(expr* smt_t);
|
||||
void add_shared_var(sat::bool_var v);
|
||||
|
||||
void import_phase_from_smt();
|
||||
void import_values_from_sls();
|
||||
bool export_phase_to_sls();
|
||||
bool export_units_to_sls();
|
||||
void export_activity_to_smt();
|
||||
void export_values_to_smt();
|
||||
|
||||
|
||||
friend class sat::ddfw;
|
||||
~smt_plugin();
|
||||
|
||||
public:
|
||||
smt_plugin(smt_context& ctx);
|
||||
|
||||
// interface to calling solver:
|
||||
void check(expr_ref_vector const& fmls);
|
||||
void finalize(model_ref& md, ::statistics& st);
|
||||
void updt_params(params_ref& p) {}
|
||||
void collect_statistics(statistics& st);
|
||||
std::ostream& display(std::ostream& out) override;
|
||||
void import_from_sls();
|
||||
bool export_to_sls();
|
||||
bool completed() { return m_completed; }
|
||||
void add_unit(sat::literal lit);
|
||||
|
||||
|
||||
// local_search_plugin:
|
||||
void on_restart() override {
|
||||
if (export_to_sls())
|
||||
m_ddfw->reinit();
|
||||
}
|
||||
|
||||
void on_save_model() override {
|
||||
TRACE("sls", display(tout));
|
||||
while (unsat().empty()) {
|
||||
m_context.check();
|
||||
if (!m_new_clause_added)
|
||||
break;
|
||||
m_ddfw->reinit();
|
||||
m_new_clause_added = false;
|
||||
}
|
||||
//import_from_sls();
|
||||
}
|
||||
|
||||
void on_model(model_ref& mdl) override {
|
||||
IF_VERBOSE(3, verbose_stream() << "on-model " << "\n");
|
||||
m_sls_model = mdl;
|
||||
}
|
||||
|
||||
void init_search() override {}
|
||||
|
||||
void finish_search() override {}
|
||||
|
||||
void on_rescale() override {}
|
||||
|
||||
|
||||
|
||||
// sat_solver_context:
|
||||
vector<sat::clause_info> const& clauses() const override { return m_ddfw->clauses(); }
|
||||
sat::clause_info const& get_clause(unsigned idx) const override { return m_ddfw->get_clause_info(idx); }
|
||||
ptr_iterator<unsigned> get_use_list(sat::literal lit) override { return m_ddfw->use_list(lit); }
|
||||
void flip(sat::bool_var v) override {
|
||||
m_ddfw->flip(v);
|
||||
}
|
||||
double reward(sat::bool_var v) override { return m_ddfw->get_reward(v); }
|
||||
double get_weigth(unsigned clause_idx) override { return m_ddfw->get_clause_info(clause_idx).m_weight; }
|
||||
bool is_true(sat::literal lit) override {
|
||||
return m_ddfw->get_value(lit.var()) != lit.sign();
|
||||
}
|
||||
unsigned num_vars() const override { return m_ddfw->num_vars(); }
|
||||
indexed_uint_set const& unsat() const override { return m_ddfw->unsat_set(); }
|
||||
sat::bool_var add_var() override {
|
||||
return m_ddfw->add_var();
|
||||
}
|
||||
void add_clause(unsigned n, sat::literal const* lits) override {
|
||||
m_ddfw->add(n, lits);
|
||||
m_new_clause_added = true;
|
||||
}
|
||||
void force_restart() override { m_ddfw->force_restart(); }
|
||||
|
||||
|
||||
};
|
||||
}
|
|
@ -1248,7 +1248,7 @@ namespace arith {
|
|||
for (auto ev : m_explanation)
|
||||
set_evidence(ev.ci());
|
||||
|
||||
TRACE("arith",
|
||||
TRACE("arith_conflict",
|
||||
tout << "Lemma - " << (is_conflict ? "conflict" : "propagation") << "\n";
|
||||
for (literal c : m_core) tout << c << ": " << literal2expr(c) << "\n";
|
||||
for (auto p : m_eqs) tout << ctx.bpp(p.first) << " == " << ctx.bpp(p.second) << "\n";);
|
||||
|
@ -1270,6 +1270,7 @@ namespace arith {
|
|||
m_core.push_back(ctx.mk_literal(m.mk_eq(eq.first->get_expr(), eq.second->get_expr())));
|
||||
for (literal& c : m_core)
|
||||
c.neg();
|
||||
DEBUG_CODE(for (literal c : m_core) { SASSERT(s().value(c) != l_true); });
|
||||
|
||||
add_redundant(m_core, explain(ty));
|
||||
}
|
||||
|
@ -1520,10 +1521,13 @@ namespace arith {
|
|||
}
|
||||
for (auto const& ineq : m_nla->literals()) {
|
||||
auto lit = mk_ineq_literal(ineq);
|
||||
if (s().value(lit) == l_true)
|
||||
continue;
|
||||
ctx.mark_relevant(lit);
|
||||
s().set_phase(lit);
|
||||
verbose_stream() << lit << ":= " << s().value(lit) << "\n";
|
||||
// force trichotomy axiom for equality literals
|
||||
if (ineq.cmp() == lp::EQ) {
|
||||
if (ineq.cmp() == lp::EQ && false) {
|
||||
nla::lemma l;
|
||||
l.push_back(ineq);
|
||||
l.push_back(nla::ineq(lp::LT, ineq.term(), ineq.rs()));
|
||||
|
|
|
@ -62,12 +62,17 @@ namespace sls {
|
|||
if (!s.m_has_units)
|
||||
return false;
|
||||
std::lock_guard<std::mutex> lock(s.m_mutex);
|
||||
IF_VERBOSE(3, verbose_stream() << "SMT -> SLS units " << s.m_units << "\n");
|
||||
for (auto lit : s.m_units)
|
||||
if (m_shared_vars.contains(lit.var())) {
|
||||
IF_VERBOSE(10, verbose_stream() << "unit " << lit << "\n");
|
||||
IF_VERBOSE(2, verbose_stream() << "SMT -> SLS units " << s.m_units << "\n");
|
||||
for (auto lit : s.m_units) {
|
||||
if (m_shared_bool_vars.contains(lit.var())) {
|
||||
IF_VERBOSE(10, verbose_stream() << "unit " << lit << "\n");
|
||||
m_ddfw->add(1, &lit);
|
||||
}
|
||||
else {
|
||||
IF_VERBOSE(0, verbose_stream() << "value restriction " << lit << " " << mk_bounded_pp(s.ctx.bool_var2expr(lit.var()), s.ctx.get_manager()) << "\n");
|
||||
}
|
||||
}
|
||||
|
||||
s.m_has_units = false;
|
||||
s.m_units.reset();
|
||||
return true;
|
||||
|
@ -78,7 +83,7 @@ namespace sls {
|
|||
return false;
|
||||
std::lock_guard<std::mutex> lock(s.m_mutex);
|
||||
IF_VERBOSE(3, verbose_stream() << "SMT -> SLS phase\n");
|
||||
for (auto i : m_shared_vars) {
|
||||
for (auto i : m_shared_bool_vars) {
|
||||
if (m_sat_phase[i] != is_true(sat::literal(i, false)))
|
||||
flip(i);
|
||||
m_ddfw->bias(i) = m_sat_phase[i] ? 1 : -1;
|
||||
|
@ -105,7 +110,7 @@ namespace sls {
|
|||
return;
|
||||
m_min_unsat_size = unsat().size();
|
||||
std::lock_guard<std::mutex> lock(s.m_mutex);
|
||||
for (auto v : m_shared_vars) {
|
||||
for (auto v : m_shared_bool_vars) {
|
||||
m_rewards[v] = m_ddfw->get_reward_avg(v);
|
||||
m_sls_phase[v] = l_true == m_ddfw->get_model()[v];
|
||||
m_has_new_sls_phase = true;
|
||||
|
@ -123,7 +128,16 @@ namespace sls {
|
|||
}
|
||||
m_has_new_sls_values = true;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void add_uninterp(expr* smt_t) {
|
||||
auto sync_t = m_smt2sync_tr(smt_t);
|
||||
auto sls_t = m_smt2sls_tr(smt_t);
|
||||
m_sync_uninterp.push_back(sync_t);
|
||||
m_smt2sync_uninterp.insert(smt_t, sync_t);
|
||||
m_sls2sync_uninterp.insert(sls_t, sync_t);
|
||||
}
|
||||
|
||||
public:
|
||||
smt_plugin(ast_manager& m, solver& s, sat::ddfw* d) :
|
||||
|
@ -135,7 +149,7 @@ namespace sls {
|
|||
{
|
||||
}
|
||||
|
||||
uint_set m_shared_vars;
|
||||
uint_set m_shared_bool_vars, m_shared_terms;
|
||||
svector<bool> m_sat_phase;
|
||||
std::atomic<bool> m_has_new_sat_phase = false;
|
||||
|
||||
|
@ -144,19 +158,19 @@ namespace sls {
|
|||
|
||||
svector<double> m_rewards;
|
||||
|
||||
void add_uninterp(expr* smt_t) {
|
||||
auto sync_t = m_smt2sync_tr(smt_t);
|
||||
auto sls_t = m_smt2sls_tr(smt_t);
|
||||
m_sync_uninterp.push_back(sync_t);
|
||||
m_smt2sync_uninterp.insert(smt_t, sync_t);
|
||||
m_sls2sync_uninterp.insert(sls_t, sync_t);
|
||||
|
||||
|
||||
void add_shared_term(expr* t) {
|
||||
m_shared_terms.insert(t->get_id());
|
||||
if (is_uninterp(t))
|
||||
add_uninterp(t);
|
||||
}
|
||||
|
||||
void add_shared_var(sat::bool_var v) {
|
||||
m_sls_phase.reserve(v + 1);
|
||||
m_sat_phase.reserve(v + 1);
|
||||
m_rewards.reserve(v + 1);
|
||||
m_shared_vars.insert(v);
|
||||
m_shared_bool_vars.insert(v);
|
||||
}
|
||||
|
||||
void init_search() override {}
|
||||
|
@ -250,7 +264,7 @@ namespace sls {
|
|||
IF_VERBOSE(3, verbose_stream() << "new SMT -> SLS phase\n");
|
||||
s.s().set_has_new_best_phase(false);
|
||||
std::lock_guard<std::mutex> lock(s.m_mutex);
|
||||
for (auto v : m_shared_vars)
|
||||
for (auto v : m_shared_bool_vars)
|
||||
m_sat_phase[v] = s.s().get_best_phase(v);
|
||||
}
|
||||
|
||||
|
@ -258,6 +272,22 @@ namespace sls {
|
|||
// TODO
|
||||
}
|
||||
|
||||
// determine if unit literal restricts values of shared subterms.
|
||||
bool is_value_restriction(sat::literal lit) {
|
||||
auto e = s.ctx.bool_var2expr(lit.var());
|
||||
expr* t = nullptr;
|
||||
if (!e)
|
||||
return false;
|
||||
bv_util bv(s.ctx.get_manager());
|
||||
if (bv.is_bit2bool(e, t) && m_shared_terms.contains(t->get_id())) {
|
||||
verbose_stream() << "shared bit2bool " << mk_bounded_pp(e, s.ctx.get_manager()) << "\n";
|
||||
return true;
|
||||
}
|
||||
|
||||
// if arith.is_le(e, s, t) && t is a numeral, s is shared-term....
|
||||
return false;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
void solver::finalize() {
|
||||
|
@ -287,12 +317,12 @@ namespace sls {
|
|||
return;
|
||||
for (; m_trail_lim < s().init_trail_size(); ++m_trail_lim) {
|
||||
auto lit = s().trail_literal(m_trail_lim);
|
||||
if (!m_smt_plugin->m_shared_vars.contains(lit.var()))
|
||||
continue;
|
||||
IF_VERBOSE(10, verbose_stream() << "push unit " << lit << " " << mk_bounded_pp(ctx.literal2expr(lit), m) << "\n");
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
m_units.push_back(lit);
|
||||
m_has_units = true;
|
||||
if (m_smt_plugin->is_value_restriction(lit) ||
|
||||
m_smt_plugin->m_shared_bool_vars.contains(lit.var())) {
|
||||
std::lock_guard<std::mutex> lock(m_mutex);
|
||||
m_units.push_back(lit);
|
||||
m_has_units = true;
|
||||
}
|
||||
}
|
||||
if (s().has_new_best_phase())
|
||||
m_smt_plugin->import_phase_from_smt();
|
||||
|
@ -329,15 +359,14 @@ namespace sls {
|
|||
for (auto lit : clause)
|
||||
m_smt_plugin->add_shared_var(lit.var());
|
||||
}
|
||||
for (auto v : m_smt_plugin->m_shared_vars) {
|
||||
for (auto v : m_smt_plugin->m_shared_bool_vars) {
|
||||
expr* e = ctx.bool_var2expr(v);
|
||||
if (!e)
|
||||
continue;
|
||||
m_smt_plugin->register_atom(v, tr(e));
|
||||
for (auto t : subterms::all(expr_ref(e, m))) {
|
||||
if (is_uninterp(t))
|
||||
m_smt_plugin->add_uninterp(t);
|
||||
}
|
||||
for (auto t : subterms::all(expr_ref(e, m)))
|
||||
m_smt_plugin->add_shared_term(e);
|
||||
|
||||
}
|
||||
|
||||
m_thread = std::thread([this]() { run_local_search_async(); });
|
||||
|
|
|
@ -66,6 +66,7 @@ z3_add_component(smt
|
|||
theory_pb.cpp
|
||||
theory_recfun.cpp
|
||||
theory_seq.cpp
|
||||
theory_sls.cpp
|
||||
theory_special_relations.cpp
|
||||
theory_str.cpp
|
||||
theory_str_mc.cpp
|
||||
|
|
|
@ -37,6 +37,7 @@ Revision History:
|
|||
#include "smt/uses_theory.h"
|
||||
#include "smt/theory_special_relations.h"
|
||||
#include "smt/theory_polymorphism.h"
|
||||
#include "smt/theory_sls.h"
|
||||
#include "smt/smt_for_each_relevant_expr.h"
|
||||
#include "smt/smt_model_generator.h"
|
||||
#include "smt/smt_model_checker.h"
|
||||
|
@ -3506,6 +3507,10 @@ namespace smt {
|
|||
if (r == l_true && get_cancel_flag()) {
|
||||
r = l_undef;
|
||||
}
|
||||
if (r == l_undef && get_cancel_flag() && has_sls_model()) {
|
||||
m.limit().reset_cancel();
|
||||
r = l_true;
|
||||
}
|
||||
if (r == l_true && gparams::get_value("model_validate") == "true") {
|
||||
recfun::util u(m);
|
||||
if (u.get_rec_funs().empty() && m_proto_model) {
|
||||
|
@ -3581,6 +3586,20 @@ namespace smt {
|
|||
return r;
|
||||
}
|
||||
|
||||
bool context::has_sls_model() {
|
||||
if (!m_fparams.m_sls_enable)
|
||||
return false;
|
||||
auto tid = m.get_family_id("sls");
|
||||
auto p = m_theories.get_plugin(tid);
|
||||
if (!p)
|
||||
return false;
|
||||
auto mdl = dynamic_cast<theory_sls*>(p)->get_model();
|
||||
if (!mdl)
|
||||
return false;
|
||||
m_model = mdl;
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
\brief Setup the logical context based on the current set of
|
||||
asserted formulas and execute the check command.
|
||||
|
|
|
@ -619,6 +619,9 @@ namespace smt {
|
|||
friend class set_var_theory_trail;
|
||||
void set_var_theory(bool_var v, theory_id tid);
|
||||
|
||||
|
||||
bool has_sls_model();
|
||||
|
||||
// -----------------------------------
|
||||
//
|
||||
// Backtracking support
|
||||
|
|
|
@ -35,6 +35,7 @@ Revision History:
|
|||
#include "smt/theory_seq.h"
|
||||
#include "smt/theory_char.h"
|
||||
#include "smt/theory_special_relations.h"
|
||||
#include "smt/theory_sls.h"
|
||||
#include "smt/theory_pb.h"
|
||||
#include "smt/theory_fpa.h"
|
||||
#include "smt/theory_str.h"
|
||||
|
@ -67,6 +68,7 @@ namespace smt {
|
|||
case CFG_AUTO: setup_auto_config(); break;
|
||||
}
|
||||
setup_card();
|
||||
setup_sls();
|
||||
}
|
||||
|
||||
void setup::setup_default() {
|
||||
|
@ -766,6 +768,11 @@ namespace smt {
|
|||
m_context.register_plugin(alloc(theory_pb, m_context));
|
||||
}
|
||||
|
||||
void setup::setup_sls() {
|
||||
if (m_params.m_sls_enable)
|
||||
m_context.register_plugin(alloc(theory_sls, m_context));
|
||||
}
|
||||
|
||||
void setup::setup_fpa() {
|
||||
setup_bv();
|
||||
m_context.register_plugin(alloc(theory_fpa, m_context));
|
||||
|
|
|
@ -103,6 +103,7 @@ namespace smt {
|
|||
void setup_seq();
|
||||
void setup_char();
|
||||
void setup_card();
|
||||
void setup_sls();
|
||||
void setup_i_arith();
|
||||
void setup_mi_arith();
|
||||
void setup_lra_arith();
|
||||
|
|
|
@ -3254,7 +3254,7 @@ public:
|
|||
tout << "@" << ctx().get_scope_level() << (is_conflict ? " conflict":" lemma");
|
||||
for (auto const& p : m_params) tout << " " << p;
|
||||
tout << "\n";
|
||||
display_evidence(tout, m_explanation););
|
||||
display_evidence(tout << core << " ", m_explanation););
|
||||
for (auto ev : m_explanation)
|
||||
set_evidence(ev.ci(), m_core, m_eqs);
|
||||
|
||||
|
|
136
src/smt/theory_sls.cpp
Normal file
136
src/smt/theory_sls.cpp
Normal file
|
@ -0,0 +1,136 @@
|
|||
/*++
|
||||
Copyright (c) 2020 Microsoft Corporation
|
||||
|
||||
Module Name:
|
||||
|
||||
theory_sls
|
||||
|
||||
Abstract:
|
||||
|
||||
Interface to Concurrent SLS solver
|
||||
|
||||
Author:
|
||||
|
||||
Nikolaj Bjorner (nbjorner) 2024-02-21
|
||||
|
||||
|
||||
--*/
|
||||
|
||||
#include "smt/theory_sls.h"
|
||||
#include "smt/smt_context.h"
|
||||
#include "ast/sls/sls_context.h"
|
||||
#include "ast/for_each_expr.h"
|
||||
|
||||
namespace smt {
|
||||
|
||||
#ifdef SINGLE_THREAD
|
||||
|
||||
|
||||
#else
|
||||
theory_sls::theory_sls(smt::context& ctx):
|
||||
theory(ctx, ctx.get_manager().mk_family_id("sls"))
|
||||
{}
|
||||
|
||||
theory_sls::~theory_sls() {
|
||||
finalize();
|
||||
}
|
||||
|
||||
params_ref theory_sls::get_params() {
|
||||
return ctx.get_params();
|
||||
}
|
||||
|
||||
void theory_sls::initialize_value(expr* t, expr* v) {
|
||||
ctx.user_propagate_initialize_value(t, v);
|
||||
}
|
||||
|
||||
void theory_sls::force_phase(sat::literal lit) {
|
||||
ctx.force_phase(lit);
|
||||
}
|
||||
|
||||
void theory_sls::set_has_new_best_phase(bool b) {
|
||||
|
||||
}
|
||||
|
||||
bool theory_sls::get_best_phase(sat::bool_var v) {
|
||||
return false;
|
||||
}
|
||||
|
||||
expr* theory_sls::bool_var2expr(sat::bool_var v) {
|
||||
return ctx.bool_var2expr(v);
|
||||
}
|
||||
|
||||
void theory_sls::set_finished() {
|
||||
m.limit().cancel();
|
||||
}
|
||||
|
||||
unsigned theory_sls::get_num_bool_vars() const {
|
||||
return ctx.get_num_bool_vars();
|
||||
}
|
||||
|
||||
void theory_sls::finalize() {
|
||||
if (!m_smt_plugin)
|
||||
return;
|
||||
|
||||
m_smt_plugin->finalize(m_model, m_st);
|
||||
m_model = nullptr;
|
||||
m_smt_plugin = nullptr;
|
||||
}
|
||||
|
||||
void theory_sls::propagate() {
|
||||
if (m_smt_plugin && !m_checking) {
|
||||
expr_ref_vector fmls(m);
|
||||
for (unsigned i = 0; i < ctx.get_num_asserted_formulas(); ++i)
|
||||
fmls.push_back(ctx.get_asserted_formula(i));
|
||||
m_checking = true;
|
||||
m_smt_plugin->check(fmls);
|
||||
return;
|
||||
}
|
||||
if (!m_smt_plugin)
|
||||
return;
|
||||
if (!m_smt_plugin->completed())
|
||||
return;
|
||||
m_smt_plugin->finalize(m_model, m_st);
|
||||
m_smt_plugin = nullptr;
|
||||
}
|
||||
|
||||
void theory_sls::pop_scope_eh(unsigned n) {
|
||||
if (!m_smt_plugin)
|
||||
return;
|
||||
|
||||
unsigned scope_lvl = ctx.get_scope_level();
|
||||
if (ctx.get_search_level() == scope_lvl - n) {
|
||||
auto& lits = ctx.assigned_literals();
|
||||
for (; m_trail_lim < lits.size() && ctx.get_assign_level(lits[m_trail_lim]) == scope_lvl; ++m_trail_lim) {
|
||||
auto lit = lits[m_trail_lim];
|
||||
m_smt_plugin->add_unit(lit);
|
||||
}
|
||||
}
|
||||
#if 0
|
||||
if (ctx.has_new_best_phase())
|
||||
m_smt_plugin->import_phase_from_smt();
|
||||
|
||||
#endif
|
||||
|
||||
m_smt_plugin->import_from_sls();
|
||||
}
|
||||
|
||||
|
||||
void theory_sls::init() {
|
||||
if (m_smt_plugin)
|
||||
finalize();
|
||||
m_smt_plugin = alloc(sls::smt_plugin, *this);
|
||||
m_checking = false;
|
||||
}
|
||||
|
||||
void theory_sls::collect_statistics(::statistics& st) const {
|
||||
st.copy(m_st);
|
||||
}
|
||||
|
||||
void theory_sls::display(std::ostream& out) const {
|
||||
out << "theory-sls\n";
|
||||
}
|
||||
|
||||
|
||||
|
||||
#endif
|
||||
}
|
79
src/smt/theory_sls.h
Normal file
79
src/smt/theory_sls.h
Normal file
|
@ -0,0 +1,79 @@
|
|||
/*++
|
||||
Copyright (c) 2020 Microsoft Corporation
|
||||
|
||||
Module Name:
|
||||
|
||||
theory_sls
|
||||
|
||||
Abstract:
|
||||
|
||||
Interface to Concurrent SLS solver
|
||||
|
||||
Author:
|
||||
|
||||
Nikolaj Bjorner (nbjorner) 2024-02-21
|
||||
|
||||
--*/
|
||||
#pragma once
|
||||
|
||||
|
||||
#include "util/rlimit.h"
|
||||
#include "ast/sls/sat_ddfw.h"
|
||||
#include "smt/smt_theory.h"
|
||||
#include "model/model.h"
|
||||
|
||||
|
||||
#ifdef SINGLE_THREAD
|
||||
|
||||
|
||||
namespace sls {
|
||||
|
||||
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
#include "ast/sls/sls_smt_plugin.h"
|
||||
class context;
|
||||
|
||||
namespace smt {
|
||||
class theory_sls : public smt::theory, public sls::smt_context {
|
||||
model_ref m_model;
|
||||
sls::smt_plugin* m_smt_plugin = nullptr;
|
||||
unsigned m_trail_lim = 0;
|
||||
bool m_checking = false;
|
||||
::statistics m_st;
|
||||
|
||||
void finalize();
|
||||
|
||||
public:
|
||||
theory_sls(context& ctx);
|
||||
~theory_sls() override;
|
||||
model_ref get_model() { return m_model; }
|
||||
char const* get_name() const override { return "sls"; }
|
||||
void init() override;
|
||||
void pop_scope_eh(unsigned n) override;
|
||||
smt::theory* mk_fresh(context* new_ctx) override { return alloc(theory_sls, *new_ctx); }
|
||||
void collect_statistics(::statistics& st) const override;
|
||||
void propagate() override;
|
||||
void display(std::ostream& out) const override;
|
||||
bool internalize_atom(app * atom, bool gate_ctx) override { return false; }
|
||||
bool internalize_term(app* term) override { return false; }
|
||||
void new_eq_eh(theory_var v1, theory_var v2) override {}
|
||||
void new_diseq_eh(theory_var v1, theory_var v2) override {}
|
||||
|
||||
ast_manager& get_manager() override { return m; }
|
||||
params_ref get_params() override;
|
||||
void initialize_value(expr* t, expr* v) override;
|
||||
void force_phase(sat::literal lit) override;
|
||||
void set_has_new_best_phase(bool b) override;
|
||||
bool get_best_phase(sat::bool_var v) override;
|
||||
expr* bool_var2expr(sat::bool_var v) override;
|
||||
void set_finished() override;
|
||||
unsigned get_num_bool_vars() const override;
|
||||
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
#endif
|
Loading…
Add table
Add a link
Reference in a new issue