3
0
Fork 0
mirror of https://github.com/Z3Prover/z3 synced 2025-04-06 01:24:08 +00:00

sls updates

- add SINGLE_THREAD mode
- add interface to retrieve "best" model so far
This commit is contained in:
Nikolaj Bjorner 2024-04-13 16:42:26 +02:00
parent 43dd6a5436
commit 2682c2ef2b
5 changed files with 128 additions and 66 deletions

View file

@ -61,6 +61,17 @@ namespace bv {
}
}
void sls::set_model() {
if (!m_set_model)
return;
if (m_repair_roots.size() >= m_min_repair_size)
return;
m_min_repair_size = m_repair_roots.size();
IF_VERBOSE(2, verbose_stream() << "(sls-update-model :num-unsat " << m_min_repair_size << ")\n");
m_set_model(*get_model());
}
void sls::init_repair_goal(app* t) {
m_eval.init_eval(t);
}
@ -94,6 +105,9 @@ namespace bv {
if (m_to_repair.empty())
return;
// refresh the best model so far to a callback
set_model();
// add fresh units, if any
bool new_assertion = false;
while (m_get_unit) {
@ -130,7 +144,7 @@ namespace bv {
return m_rand() % 2 == 0;
};
m_eval.init_eval(m_terms.assertions(), eval);
init_repair();
init_repair();
// m_engine_init = false;
}
@ -295,10 +309,12 @@ namespace bv {
model_ref mdl = alloc(model, m);
auto& terms = m_eval.sort_assertions(m_terms.assertions());
for (expr* e : terms) {
#if 0
if (!m_eval.re_eval_is_correct(to_app(e))) {
verbose_stream() << "missed evaluation #" << e->get_id() << " " << mk_bounded_pp(e, m) << "\n";
m_eval.display_value(verbose_stream(), e) << "\n";
}
#endif
if (!is_uninterp_const(e))
continue;

View file

@ -54,10 +54,13 @@ namespace bv {
bool m_engine_model = false;
bool m_engine_init = false;
std::function<expr_ref()> m_get_unit;
std::function<void(model& mdl)> m_set_model;
unsigned m_min_repair_size = UINT_MAX;
std::pair<bool, app*> next_to_repair();
void init_repair_goal(app* e);
void set_model();
void try_repair_down(app* e);
void try_repair_up(app* e);
void set_repair_down(expr* e) { m_repair_down = e->get_id(); }
@ -96,6 +99,11 @@ namespace bv {
*/
void init_unit(std::function<expr_ref()> get_unit) { m_get_unit = get_unit; }
/**
* Add callback to set model
*/
void set_model(std::function<void(model& mdl)> f) { m_set_model = f; }
/**
* Run (bounded) local search to find feasible assignments.
*/

View file

@ -1069,7 +1069,7 @@ namespace intblast {
if (e->get_family_id() != bv.get_family_id())
return false;
for (euf::enode* arg : euf::enode_args(n))
dep.add(n, arg->get_root());
dep.add(n, arg);
return true;
}

View file

@ -22,31 +22,37 @@ Author:
namespace sls {
#ifdef SINGLE_THREAD
solver::solver(euf::solver& ctx) :
th_euf_solver(ctx, symbol("sls"), ctx.get_manager().mk_family_id("sls"))
{}
#else
solver::solver(euf::solver& ctx):
th_euf_solver(ctx, symbol("sls"), ctx.get_manager().mk_family_id("sls")),
m_units(m) {}
th_euf_solver(ctx, symbol("sls"), ctx.get_manager().mk_family_id("sls"))
{}
solver::~solver() {
finalize();
}
void solver::finalize() {
if (!m_completed && m_bvsls) {
m_bvsls->cancel();
if (!m_completed && m_sls) {
m_sls->cancel();
m_thread.join();
m_bvsls->collect_statistics(m_st);
m_bvsls = nullptr;
m_sls->collect_statistics(m_st);
m_sls = nullptr;
m_shared = nullptr;
m_slsm = nullptr;
m_units = nullptr;
}
}
sat::check_result solver::check() {
return sat::check_result::CR_DONE;
}
void solver::simplify() {
}
bool solver::unit_propagate() {
force_push();
sample_local_search();
@ -66,10 +72,6 @@ namespace sls {
return false;
}
void solver::push_core() {
}
void solver::pop_core(unsigned n) {
for (; m_trail_lim < s().init_trail_size(); ++m_trail_lim) {
auto lit = s().trail_literal(m_trail_lim);
@ -77,60 +79,63 @@ namespace sls {
if (is_unit(e)) {
// IF_VERBOSE(1, verbose_stream() << "add unit " << mk_pp(e, m) << "\n");
std::lock_guard<std::mutex> lock(m_mutex);
m_units.push_back(e);
ast_translation tr(m, *m_shared);
m_units->push_back(tr(e.get()));
m_has_units = true;
}
}
}
void solver::init_search() {
init_local_search();
}
}
void solver::init_local_search() {
if (m_bvsls) {
m_bvsls->cancel();
void solver::init_search() {
if (m_sls) {
m_sls->cancel();
m_thread.join();
m_result = l_undef;
m_completed = false;
m_has_units = false;
m_model = nullptr;
m_units.reset();
m_units = nullptr;
}
// set up state for local search solver here
m_m = alloc(ast_manager, m);
ast_translation tr(m, *m_m);
m_shared = alloc(ast_manager);
m_slsm = alloc(ast_manager);
m_units = alloc(expr_ref_vector, *m_shared);
ast_translation tr(m, *m_slsm);
params_ref p;
m_completed = false;
m_result = l_undef;
m_model = nullptr;
m_bvsls = alloc(bv::sls, *m_m, p);
m_sls = alloc(bv::sls, *m_slsm, s().params());
for (expr* a : ctx.get_assertions())
m_bvsls->assert_expr(tr(a));
m_sls->assert_expr(tr(a));
std::function<bool(expr*, unsigned)> eval = [&](expr* e, unsigned r) {
return false;
};
m_bvsls->init();
m_bvsls->init_eval(eval);
m_bvsls->updt_params(s().params());
m_bvsls->init_unit([&]() {
m_sls->init();
m_sls->init_eval(eval);
m_sls->updt_params(s().params());
m_sls->init_unit([&]() {
if (!m_has_units)
return expr_ref(*m_m);
expr_ref e(m);
return expr_ref(*m_slsm);
expr_ref e(*m_slsm);
{
std::lock_guard<std::mutex> lock(m_mutex);
if (m_units.empty())
return expr_ref(*m_m);
e = m_units.back();
m_units.pop_back();
if (m_units->empty())
return expr_ref(*m_slsm);
ast_translation tr(*m_shared, *m_slsm);
e = tr(m_units->back());
m_units->pop_back();
}
ast_translation tr(m, *m_m);
return expr_ref(tr(e.get()), *m_m);
return e;
});
m_sls->set_model([&](model& mdl) {
std::lock_guard<std::mutex> lock(m_mutex);
ast_translation tr(*m_shared, m);
m_model = mdl.translate(tr);
});
m_thread = std::thread([this]() { run_local_search(); });
@ -141,20 +146,21 @@ namespace sls {
return;
m_thread.join();
m_completed = false;
m_bvsls->collect_statistics(m_st);
m_sls->collect_statistics(m_st);
if (m_result == l_true) {
IF_VERBOSE(2, verbose_stream() << "(sat.sls :model-completed)\n";);
auto mdl = m_bvsls->get_model();
ast_translation tr(*m_m, m);
auto mdl = m_sls->get_model();
ast_translation tr(*m_slsm, m);
m_model = mdl->translate(tr);
s().set_canceled();
}
m_bvsls = nullptr;
m_sls = nullptr;
}
void solver::run_local_search() {
m_result = (*m_bvsls)();
m_result = (*m_sls)();
m_completed = true;
}
#endif
}

View file

@ -16,13 +16,45 @@ Author:
--*/
#pragma once
#include <thread>
#include <mutex>
#include "util/rlimit.h"
#include "ast/sls/bv_sls.h"
#include "sat/smt/sat_th.h"
#ifdef SINGLE_THREAD
namespace euf {
class solver;
}
namespace sls {
class solver : public euf::th_euf_solver {
public:
solver(euf::solver& ctx);
sat::literal internalize(expr* e, bool sign, bool root) override { UNREACHABLE(); return sat::null_literal; }
void internalize(expr* e) override { UNREACHABLE(); }
th_solver* clone(euf::solver& ctx) override { return alloc(solver, ctx); }
model_ref get_model() { return model_ref(nullptr); }
bool unit_propagate() override { return false; }
void get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector& r, bool probing) override { UNREACHABLE(); }
sat::check_result check() override { return sat::check_result::CR_DONE;}
std::ostream& display(std::ostream& out) const override { return out; }
std::ostream& display_justification(std::ostream& out, sat::ext_justification_idx idx) const override { UNREACHABLE(); return out; }
std::ostream& display_constraint(std::ostream& out, sat::ext_constraint_idx idx) const override { UNREACHABLE(); return out; }
};
}
#else
#include <thread>
#include <mutex>
namespace euf {
class solver;
}
@ -34,38 +66,36 @@ namespace sls {
std::atomic<bool> m_completed, m_has_units;
std::thread m_thread;
std::mutex m_mutex;
scoped_ptr<ast_manager> m_m;
scoped_ptr<bv::sls> m_bvsls;
// m is accessed by the main thread
// m_slsm is accessed by the sls thread
// m_shared is only accessed at synchronization points
scoped_ptr<ast_manager> m_shared, m_slsm;
scoped_ptr<bv::sls> m_sls;
scoped_ptr<expr_ref_vector> m_units;
model_ref m_model;
unsigned m_trail_lim = 0;
expr_ref_vector m_units;
statistics m_st;
void run_local_search();
void init_local_search();
void sample_local_search();
bool is_unit(expr*);
public:
solver(euf::solver& ctx);
~solver();
void simplify() override;
void init_search() override;
model_ref get_model() { return m_model; }
void push_core() override;
void init_search() override;
void push_core() override {}
void pop_core(unsigned n) override;
th_solver* clone(euf::solver& ctx) override { return alloc(solver, ctx); }
void collect_statistics(statistics& st) const override { st.copy(m_st); }
void finalize() override;
bool unit_propagate() override;
sat::literal internalize(expr* e, bool sign, bool root) override { UNREACHABLE(); return sat::null_literal; }
void internalize(expr* e) override { UNREACHABLE(); }
th_solver* clone(euf::solver& ctx) override { return alloc(solver, ctx); }
void collect_statistics(statistics& st) const override { st.copy(m_st); }
model_ref get_model() { return m_model; }
void finalize() override;
bool unit_propagate() override;
void get_antecedents(sat::literal l, sat::ext_justification_idx idx, sat::literal_vector & r, bool probing) override { UNREACHABLE(); }
sat::check_result check() override;
std::ostream & display(std::ostream & out) const override { return out; }
@ -75,3 +105,5 @@ namespace sls {
};
}
#endif