mirror of
https://github.com/Z3Prover/z3
synced 2025-10-01 13:39:28 +00:00
Par (#7945)
* port parallel Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com> * updates Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com> * update smt-parallel Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com> * cleanup Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com> * neat Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com> * configuration parameter renaming Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com> * config parameters Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com> --------- Signed-off-by: Nikolaj Bjorner <nbjorner@microsoft.com>
This commit is contained in:
parent
2b5b985492
commit
ce53e06e29
7 changed files with 1006 additions and 228 deletions
|
@ -33,6 +33,7 @@ Author:
|
|||
#include "util/statistics.h"
|
||||
#include "util/params.h"
|
||||
#include "util/z3_exception.h"
|
||||
#include "ast/ast_util.h"
|
||||
#include "ast/converters/model_converter.h"
|
||||
#include "ast/simplifiers/dependent_expr.h"
|
||||
#include "ast/simplifiers/model_reconstruction_trail.h"
|
||||
|
@ -113,9 +114,80 @@ public:
|
|||
model_reconstruction_trail& model_trail() override { throw default_exception("unexpected access to model reconstruction"); }
|
||||
bool updated() override { return false; }
|
||||
void reset_updated() override {}
|
||||
|
||||
};
|
||||
|
||||
|
||||
struct base_dependent_expr_state : public dependent_expr_state {
|
||||
ast_manager& m;
|
||||
model_reconstruction_trail m_reconstruction_trail;
|
||||
bool m_updated = false;
|
||||
bool m_inconsistent = false;
|
||||
vector<dependent_expr> m_fmls;
|
||||
base_dependent_expr_state(ast_manager& m) :dependent_expr_state(m), m(m), m_reconstruction_trail(m, m_trail) {}
|
||||
unsigned qtail() const override { return m_fmls.size(); }
|
||||
dependent_expr const& operator[](unsigned i) override { return m_fmls[i]; }
|
||||
void update(unsigned i, dependent_expr const& j) override {
|
||||
SASSERT(j.fml());
|
||||
check_false(j.fml());
|
||||
m_fmls[i] = j;
|
||||
m_updated = true;
|
||||
}
|
||||
void add(dependent_expr const& j) override { m_updated = true; check_false(j.fml()); m_fmls.push_back(j); }
|
||||
bool inconsistent() override { return m_inconsistent; }
|
||||
bool updated() override { return m_updated; }
|
||||
void reset_updated() override { m_updated = false; }
|
||||
model_reconstruction_trail& model_trail() override { return m_reconstruction_trail; }
|
||||
std::ostream& display(std::ostream& out) const override {
|
||||
unsigned i = 0;
|
||||
for (auto const& d : m_fmls) {
|
||||
if (i > 0 && i == qhead())
|
||||
out << "---- head ---\n";
|
||||
out << d << "\n";
|
||||
++i;
|
||||
}
|
||||
m_reconstruction_trail.display(out);
|
||||
return out;
|
||||
}
|
||||
void check_false(expr* f) {
|
||||
if (m.is_false(f))
|
||||
m_inconsistent = true;
|
||||
}
|
||||
void replay(unsigned qhead, expr_ref_vector& assumptions) {
|
||||
m_reconstruction_trail.replay(qhead, assumptions, *this);
|
||||
}
|
||||
void flatten_suffix() override {
|
||||
expr_mark seen;
|
||||
unsigned j = qhead();
|
||||
expr_ref_vector pinned(m);
|
||||
for (unsigned i = qhead(); i < qtail(); ++i) {
|
||||
expr* f = m_fmls[i].fml(), * g = nullptr;
|
||||
pinned.push_back(f);
|
||||
if (seen.is_marked(f))
|
||||
continue;
|
||||
seen.mark(f, true);
|
||||
if (m.is_true(f))
|
||||
continue;
|
||||
if (m.is_and(f)) {
|
||||
auto* d = m_fmls[i].dep();
|
||||
for (expr* arg : *to_app(f))
|
||||
add(dependent_expr(m, arg, nullptr, d));
|
||||
continue;
|
||||
}
|
||||
if (m.is_not(f, g) && m.is_or(g)) {
|
||||
auto* d = m_fmls[i].dep();
|
||||
for (expr* arg : *to_app(g))
|
||||
add(dependent_expr(m, mk_not(m, arg), nullptr, d));
|
||||
continue;
|
||||
}
|
||||
if (i != j)
|
||||
m_fmls[j] = m_fmls[i];
|
||||
++j;
|
||||
}
|
||||
m_fmls.shrink(j);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& out, dependent_expr_state& st) {
|
||||
return st.display(out);
|
||||
}
|
||||
|
|
|
@ -4751,6 +4751,11 @@ namespace smt {
|
|||
}
|
||||
mdl = m_model.get();
|
||||
}
|
||||
if (m_fmls && mdl) {
|
||||
auto convert = m_fmls->model_trail().get_model_converter();
|
||||
if (convert)
|
||||
(*convert)(mdl);
|
||||
}
|
||||
}
|
||||
|
||||
void context::get_levels(ptr_vector<expr> const& vars, unsigned_vector& depth) {
|
||||
|
|
|
@ -19,6 +19,7 @@ Revision History:
|
|||
#pragma once
|
||||
|
||||
#include "ast/quantifier_stat.h"
|
||||
#include "ast/simplifiers/dependent_expr_state.h"
|
||||
#include "smt/smt_clause.h"
|
||||
#include "smt/smt_setup.h"
|
||||
#include "smt/smt_enode.h"
|
||||
|
@ -132,6 +133,11 @@ namespace smt {
|
|||
bool m_internalizing_assertions = false;
|
||||
lbool m_internal_completed = l_undef;
|
||||
|
||||
scoped_ptr<dependent_expr_simplifier> m_simplifier;
|
||||
scoped_ptr<base_dependent_expr_state> m_fmls;
|
||||
|
||||
svector<double> m_lit_scores[2];
|
||||
|
||||
|
||||
// -----------------------------------
|
||||
//
|
||||
|
@ -1292,6 +1298,8 @@ namespace smt {
|
|||
|
||||
virtual bool resolve_conflict();
|
||||
|
||||
void add_scores(unsigned n, literal const *lits);
|
||||
|
||||
|
||||
// -----------------------------------
|
||||
//
|
||||
|
|
|
@ -933,6 +933,10 @@ namespace smt {
|
|||
m_activity.reserve(v+1);
|
||||
m_bool_var2expr.reserve(v+1);
|
||||
m_bool_var2expr[v] = n;
|
||||
m_lit_scores[0].reserve(v + 1);
|
||||
m_lit_scores[1].reserve(v + 1);
|
||||
m_lit_scores[0][v] = m_lit_scores[1][v] = 0.0;
|
||||
|
||||
literal l(v, false);
|
||||
literal not_l(v, true);
|
||||
unsigned aux = std::max(l.index(), not_l.index()) + 1;
|
||||
|
@ -961,6 +965,15 @@ namespace smt {
|
|||
return v;
|
||||
}
|
||||
|
||||
void context::add_scores(unsigned n, literal const *lits) {
|
||||
for (unsigned i = 0; i < n; ++i) {
|
||||
auto lit = lits[i];
|
||||
unsigned v = lit.var(); // unique key per literal
|
||||
m_lit_scores[lit.sign()][v] += 1.0 / n;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void context::undo_mk_bool_var() {
|
||||
SASSERT(!m_b_internalized_stack.empty());
|
||||
m_stats.m_num_del_bool_var++;
|
||||
|
@ -1419,6 +1432,7 @@ namespace smt {
|
|||
break;
|
||||
case CLS_LEARNED:
|
||||
dump_lemma(num_lits, lits);
|
||||
add_scores(num_lits, lits);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
|
|
|
@ -12,17 +12,39 @@ Abstract:
|
|||
Author:
|
||||
|
||||
nbjorner 2020-01-31
|
||||
Ilana Shapiro 2025
|
||||
|
||||
--*/
|
||||
|
||||
|
||||
#include "util/scoped_ptr_vector.h"
|
||||
#include "ast/ast_util.h"
|
||||
#include "ast/ast_pp.h"
|
||||
#include "ast/ast_ll_pp.h"
|
||||
#include "ast/ast_translation.h"
|
||||
#include "ast/simplifiers/then_simplifier.h"
|
||||
#include "smt/smt_parallel.h"
|
||||
#include "smt/smt_lookahead.h"
|
||||
#include "solver/solver_preprocess.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <mutex>
|
||||
|
||||
class bounded_pp_exprs {
|
||||
expr_ref_vector const &es;
|
||||
|
||||
public:
|
||||
bounded_pp_exprs(expr_ref_vector const &es) : es(es) {}
|
||||
|
||||
std::ostream &display(std::ostream &out) const {
|
||||
for (auto e : es)
|
||||
out << mk_bounded_pp(e, es.get_manager()) << "\n";
|
||||
return out;
|
||||
}
|
||||
};
|
||||
|
||||
inline std::ostream &operator<<(std::ostream &out, bounded_pp_exprs const &pp) {
|
||||
return pp.display(out);
|
||||
}
|
||||
|
||||
#ifdef SINGLE_THREAD
|
||||
|
||||
|
@ -31,243 +53,487 @@ namespace smt {
|
|||
lbool parallel::operator()(expr_ref_vector const &asms) {
|
||||
return l_undef;
|
||||
}
|
||||
}
|
||||
} // namespace smt
|
||||
|
||||
#else
|
||||
|
||||
#include <thread>
|
||||
|
||||
#define LOG_WORKER(lvl, s) IF_VERBOSE(lvl, verbose_stream() << "Worker " << id << s)
|
||||
|
||||
namespace smt {
|
||||
|
||||
lbool parallel::operator()(expr_ref_vector const& asms) {
|
||||
void parallel::worker::run() {
|
||||
search_tree::node<cube_config> *node = nullptr;
|
||||
expr_ref_vector cube(m);
|
||||
while (true) {
|
||||
|
||||
lbool result = l_undef;
|
||||
unsigned num_threads = std::min((unsigned) std::thread::hardware_concurrency(), ctx.get_fparams().m_threads);
|
||||
flet<unsigned> _nt(ctx.m_fparams.m_threads, 1);
|
||||
unsigned thread_max_conflicts = ctx.get_fparams().m_threads_max_conflicts;
|
||||
unsigned max_conflicts = ctx.get_fparams().m_max_conflicts;
|
||||
if (!b.get_cube(m_g2l, id, cube, node)) {
|
||||
LOG_WORKER(1, " no more cubes\n");
|
||||
return;
|
||||
}
|
||||
collect_shared_clauses(m_g2l);
|
||||
|
||||
// try first sequential with a low conflict budget to make super easy problems cheap
|
||||
unsigned max_c = std::min(thread_max_conflicts, 40u);
|
||||
flet<unsigned> _mc(ctx.get_fparams().m_max_conflicts, max_c);
|
||||
result = ctx.check(asms.size(), asms.data());
|
||||
if (result != l_undef || ctx.m_num_conflicts < max_c) {
|
||||
check_cube_start:
|
||||
LOG_WORKER(1, " CUBE SIZE IN MAIN LOOP: " << cube.size() << "\n");
|
||||
lbool r = check_cube(cube);
|
||||
|
||||
if (!m.inc()) {
|
||||
b.set_exception("context cancelled");
|
||||
return;
|
||||
}
|
||||
|
||||
switch (r) {
|
||||
case l_undef: {
|
||||
update_max_thread_conflicts();
|
||||
LOG_WORKER(1, " found undef cube\n");
|
||||
// return unprocessed cubes to the batch manager
|
||||
// add a split literal to the batch manager.
|
||||
// optionally process other cubes and delay sending back unprocessed cubes to batch manager.
|
||||
if (m_config.m_max_cube_depth <= cube.size())
|
||||
goto check_cube_start;
|
||||
|
||||
auto atom = get_split_atom();
|
||||
if (!atom)
|
||||
goto check_cube_start;
|
||||
b.split(m_l2g, id, node, atom);
|
||||
simplify();
|
||||
break;
|
||||
}
|
||||
case l_true: {
|
||||
LOG_WORKER(1, " found sat cube\n");
|
||||
model_ref mdl;
|
||||
ctx->get_model(mdl);
|
||||
b.set_sat(m_l2g, *mdl);
|
||||
return;
|
||||
}
|
||||
case l_false: {
|
||||
expr_ref_vector const &unsat_core = ctx->unsat_core();
|
||||
LOG_WORKER(2, " unsat core:\n";
|
||||
for (auto c : unsat_core) verbose_stream() << mk_bounded_pp(c, m, 3) << "\n");
|
||||
// If the unsat core only contains external assumptions,
|
||||
// unsatisfiability does not depend on the current cube and the entire problem is unsat.
|
||||
if (all_of(unsat_core, [&](expr *e) { return asms.contains(e); })) {
|
||||
LOG_WORKER(1, " determined formula unsat\n");
|
||||
b.set_unsat(m_l2g, unsat_core);
|
||||
return;
|
||||
}
|
||||
// report assumptions used in unsat core, so they can be used in final core
|
||||
for (expr *e : unsat_core)
|
||||
if (asms.contains(e))
|
||||
b.report_assumption_used(m_l2g, e);
|
||||
|
||||
LOG_WORKER(1, " found unsat cube\n");
|
||||
b.backtrack(m_l2g, unsat_core, node);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (m_config.m_share_units)
|
||||
share_units(m_l2g);
|
||||
}
|
||||
}
|
||||
|
||||
parallel::worker::worker(unsigned id, parallel &p, expr_ref_vector const &_asms)
|
||||
: id(id), p(p), b(p.m_batch_manager), m_smt_params(p.ctx.get_fparams()), asms(m), m_g2l(p.ctx.m, m),
|
||||
m_l2g(m, p.ctx.m), m_search_tree(expr_ref(m)) {
|
||||
for (auto e : _asms)
|
||||
asms.push_back(m_g2l(e));
|
||||
LOG_WORKER(1, " created with " << asms.size() << " assumptions\n");
|
||||
m_smt_params.m_preprocess = false;
|
||||
ctx = alloc(context, m, m_smt_params, p.ctx.get_params());
|
||||
context::copy(p.ctx, *ctx, true);
|
||||
ctx->set_random_seed(id + m_smt_params.m_random_seed);
|
||||
// don't share initial units
|
||||
ctx->pop_to_base_lvl();
|
||||
m_num_shared_units = ctx->assigned_literals().size();
|
||||
m_num_initial_atoms = ctx->get_num_bool_vars();
|
||||
}
|
||||
|
||||
void parallel::worker::share_units(ast_translation &l2g) {
|
||||
// Collect new units learned locally by this worker and send to batch manager
|
||||
ctx->pop_to_base_lvl();
|
||||
unsigned sz = ctx->assigned_literals().size();
|
||||
for (unsigned j = m_num_shared_units; j < sz; ++j) { // iterate only over new literals since last sync
|
||||
literal lit = ctx->assigned_literals()[j];
|
||||
if (!ctx->is_relevant(lit.var()) && m_config.m_share_units_relevant_only)
|
||||
continue;
|
||||
|
||||
if (m_config.m_share_units_initial_only && lit.var() >= m_num_initial_atoms) {
|
||||
LOG_WORKER(2, " Skipping non-initial unit: " << lit.var() << "\n");
|
||||
continue; // skip non-iniial atoms if configured to do so
|
||||
}
|
||||
|
||||
expr_ref e(ctx->bool_var2expr(lit.var()), ctx->m); // turn literal into a Boolean expression
|
||||
if (m.is_and(e) || m.is_or(e))
|
||||
continue;
|
||||
|
||||
if (lit.sign())
|
||||
e = m.mk_not(e); // negate if literal is negative
|
||||
b.collect_clause(l2g, id, e);
|
||||
}
|
||||
m_num_shared_units = sz;
|
||||
}
|
||||
|
||||
void parallel::worker::simplify() {
|
||||
if (!m.inc())
|
||||
return;
|
||||
// first attempt: one-shot simplification of the context.
|
||||
// a precise schedule of repeated simplification is TBD.
|
||||
// also, the in-processing simplifier should be applied to
|
||||
// a current set of irredundant clauses that may be reduced by
|
||||
// unit propagation. By including the units we are effectively
|
||||
// repeating unit propagation, but potentially not subsumption or
|
||||
// Boolean simplifications that a solver could perform (smt_context doesnt really)
|
||||
// Integration of inprocssing simplifcation here or in sat/smt solver could
|
||||
// be based on taking the current clause set instead of the asserted formulas.
|
||||
if (!m_config.m_inprocessing)
|
||||
return;
|
||||
if (m_config.m_inprocessing_delay > 0) {
|
||||
--m_config.m_inprocessing_delay;
|
||||
return;
|
||||
}
|
||||
ctx->pop_to_base_lvl();
|
||||
if (ctx->m_base_lvl > 0)
|
||||
return; // simplification only at base level
|
||||
m_config.m_inprocessing = false; // initial strategy is to immediately disable inprocessing for future calls.
|
||||
dependent_expr_simplifier *s = ctx->m_simplifier.get();
|
||||
if (!s) {
|
||||
// create a simplifier if none exists
|
||||
// initialize it to a default pre-processing simplifier.
|
||||
ctx->m_fmls = alloc(base_dependent_expr_state, m);
|
||||
auto then_s = alloc(then_simplifier, m, ctx->get_params(), *ctx->m_fmls);
|
||||
s = then_s;
|
||||
ctx->m_simplifier = s;
|
||||
init_preprocess(m, ctx->get_params(), *then_s, *ctx->m_fmls);
|
||||
}
|
||||
|
||||
dependent_expr_state &fmls = *ctx->m_fmls.get();
|
||||
// extract assertions from ctx.
|
||||
// it is possible to track proof objects here if wanted.
|
||||
// feed them to the simplifier
|
||||
ptr_vector<expr> assertions;
|
||||
expr_ref_vector units(m);
|
||||
ctx->get_assertions(assertions);
|
||||
ctx->get_units(units);
|
||||
for (expr *e : assertions)
|
||||
fmls.add(dependent_expr(m, e, nullptr, nullptr));
|
||||
for (expr *e : units)
|
||||
fmls.add(dependent_expr(m, e, nullptr, nullptr));
|
||||
|
||||
// run in-processing on the assertions
|
||||
s->reduce();
|
||||
|
||||
scoped_ptr<context> new_ctx = alloc(context, m, m_smt_params, p.ctx.get_params());
|
||||
// extract simplified assertions from the simplifier
|
||||
// create a new context with the simplified assertions
|
||||
// update ctx with the new context.
|
||||
for (unsigned i = 0; i < fmls.qtail(); ++i) {
|
||||
auto const &de = fmls[i];
|
||||
new_ctx->assert_expr(de.fml());
|
||||
}
|
||||
|
||||
asserted_formulas &src_af = ctx->m_asserted_formulas;
|
||||
asserted_formulas &dst_af = new_ctx->m_asserted_formulas;
|
||||
src_af.get_macro_manager().copy_to(dst_af.get_macro_manager());
|
||||
new_ctx->copy_user_propagator(*ctx, true);
|
||||
ctx = new_ctx.detach();
|
||||
ctx->setup_context(true);
|
||||
ctx->internalize_assertions();
|
||||
auto old_atoms = m_num_initial_atoms;
|
||||
m_num_shared_units = ctx->assigned_literals().size();
|
||||
m_num_initial_atoms = ctx->get_num_bool_vars();
|
||||
LOG_WORKER(1, " inprocess " << old_atoms << " -> " << m_num_initial_atoms << "\n");
|
||||
}
|
||||
|
||||
void parallel::worker::collect_statistics(::statistics &st) const {
|
||||
ctx->collect_statistics(st);
|
||||
}
|
||||
|
||||
void parallel::worker::cancel() {
|
||||
LOG_WORKER(1, " canceling\n");
|
||||
m.limit().cancel();
|
||||
}
|
||||
|
||||
void parallel::batch_manager::backtrack(ast_translation &l2g, expr_ref_vector const &core,
|
||||
search_tree::node<cube_config> *node) {
|
||||
std::scoped_lock lock(mux);
|
||||
IF_VERBOSE(1, verbose_stream() << "Batch manager backtracking.\n");
|
||||
if (m_state != state::is_running)
|
||||
return;
|
||||
vector<cube_config::literal> g_core;
|
||||
for (auto c : core) {
|
||||
expr_ref g_c(l2g(c), m);
|
||||
if (!is_assumption(g_c))
|
||||
g_core.push_back(expr_ref(l2g(c), m));
|
||||
}
|
||||
m_search_tree.backtrack(node, g_core);
|
||||
|
||||
IF_VERBOSE(1, m_search_tree.display(verbose_stream() << bounded_pp_exprs(core) << "\n"););
|
||||
if (m_search_tree.is_closed()) {
|
||||
m_state = state::is_unsat;
|
||||
cancel_workers();
|
||||
}
|
||||
}
|
||||
|
||||
void parallel::batch_manager::split(ast_translation &l2g, unsigned source_worker_id,
|
||||
search_tree::node<cube_config> *node, expr *atom) {
|
||||
std::scoped_lock lock(mux);
|
||||
expr_ref lit(m), nlit(m);
|
||||
lit = l2g(atom);
|
||||
nlit = mk_not(m, lit);
|
||||
IF_VERBOSE(1, verbose_stream() << "Batch manager splitting on literal: " << mk_bounded_pp(lit, m, 3) << "\n");
|
||||
if (m_state != state::is_running)
|
||||
return;
|
||||
// optional heuristic:
|
||||
// node->get_status() == status::active
|
||||
// and depth is 'high' enough
|
||||
// then ignore split, and instead set the status of node to open.
|
||||
m_search_tree.split(node, lit, nlit);
|
||||
}
|
||||
|
||||
void parallel::batch_manager::collect_clause(ast_translation &l2g, unsigned source_worker_id, expr *clause) {
|
||||
std::scoped_lock lock(mux);
|
||||
expr *g_clause = l2g(clause);
|
||||
if (!shared_clause_set.contains(g_clause)) {
|
||||
shared_clause_set.insert(g_clause);
|
||||
shared_clause sc{source_worker_id, expr_ref(g_clause, m)};
|
||||
shared_clause_trail.push_back(sc);
|
||||
}
|
||||
}
|
||||
|
||||
void parallel::worker::collect_shared_clauses(ast_translation &g2l) {
|
||||
expr_ref_vector new_clauses = b.return_shared_clauses(g2l, m_shared_clause_limit, id);
|
||||
// iterate over new clauses and assert them in the local context
|
||||
for (expr *e : new_clauses) {
|
||||
ctx->assert_expr(e);
|
||||
LOG_WORKER(2, " asserting shared clause: " << mk_bounded_pp(e, m, 3) << "\n");
|
||||
}
|
||||
}
|
||||
|
||||
expr_ref_vector parallel::batch_manager::return_shared_clauses(ast_translation &g2l, unsigned &worker_limit,
|
||||
unsigned worker_id) {
|
||||
std::scoped_lock lock(mux);
|
||||
expr_ref_vector result(g2l.to());
|
||||
for (unsigned i = worker_limit; i < shared_clause_trail.size(); ++i) {
|
||||
if (shared_clause_trail[i].source_worker_id != worker_id)
|
||||
result.push_back(g2l(shared_clause_trail[i].clause.get()));
|
||||
}
|
||||
worker_limit = shared_clause_trail.size(); // update the worker limit to the end of the current trail
|
||||
return result;
|
||||
}
|
||||
|
||||
enum par_exception_kind {
|
||||
DEFAULT_EX,
|
||||
ERROR_EX
|
||||
};
|
||||
lbool parallel::worker::check_cube(expr_ref_vector const &cube) {
|
||||
for (auto &atom : cube)
|
||||
asms.push_back(atom);
|
||||
lbool r = l_undef;
|
||||
|
||||
vector<smt_params> smt_params;
|
||||
scoped_ptr_vector<ast_manager> pms;
|
||||
scoped_ptr_vector<context> pctxs;
|
||||
vector<expr_ref_vector> pasms;
|
||||
ctx->get_fparams().m_max_conflicts = std::min(m_config.m_threads_max_conflicts, m_config.m_max_conflicts);
|
||||
IF_VERBOSE(1, verbose_stream() << " Checking cube\n"
|
||||
<< bounded_pp_exprs(cube)
|
||||
<< "with max_conflicts: " << ctx->get_fparams().m_max_conflicts << "\n";);
|
||||
try {
|
||||
r = ctx->check(asms.size(), asms.data());
|
||||
} catch (z3_error &err) {
|
||||
b.set_exception(err.error_code());
|
||||
} catch (z3_exception &ex) {
|
||||
b.set_exception(ex.what());
|
||||
} catch (...) {
|
||||
b.set_exception("unknown exception");
|
||||
}
|
||||
asms.shrink(asms.size() - cube.size());
|
||||
LOG_WORKER(1, " DONE checking cube " << r << "\n";);
|
||||
return r;
|
||||
}
|
||||
|
||||
expr_ref parallel::worker::get_split_atom() {
|
||||
expr_ref result(m);
|
||||
double score = 0;
|
||||
unsigned n = 0;
|
||||
ctx->pop_to_search_lvl();
|
||||
for (bool_var v = 0; v < ctx->get_num_bool_vars(); ++v) {
|
||||
if (ctx->get_assignment(v) != l_undef)
|
||||
continue;
|
||||
expr *e = ctx->bool_var2expr(v);
|
||||
if (!e)
|
||||
continue;
|
||||
|
||||
double new_score = ctx->m_lit_scores[0][v] * ctx->m_lit_scores[1][v];
|
||||
|
||||
ctx->m_lit_scores[0][v] /= 2;
|
||||
ctx->m_lit_scores[1][v] /= 2;
|
||||
|
||||
if (new_score > score || !result || (new_score == score && m_rand(++n) == 0)) {
|
||||
score = new_score;
|
||||
result = e;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void parallel::batch_manager::set_sat(ast_translation &l2g, model &m) {
|
||||
std::scoped_lock lock(mux);
|
||||
IF_VERBOSE(1, verbose_stream() << "Batch manager setting SAT.\n");
|
||||
if (m_state != state::is_running)
|
||||
return;
|
||||
m_state = state::is_sat;
|
||||
p.ctx.set_model(m.translate(l2g));
|
||||
cancel_workers();
|
||||
}
|
||||
|
||||
void parallel::batch_manager::set_unsat(ast_translation &l2g, expr_ref_vector const &unsat_core) {
|
||||
std::scoped_lock lock(mux);
|
||||
IF_VERBOSE(1, verbose_stream() << "Batch manager setting UNSAT.\n");
|
||||
if (m_state != state::is_running)
|
||||
return;
|
||||
m_state = state::is_unsat;
|
||||
|
||||
// each call to check_sat needs to have a fresh unsat core
|
||||
SASSERT(p.ctx.m_unsat_core.empty());
|
||||
for (expr *e : unsat_core)
|
||||
p.ctx.m_unsat_core.push_back(l2g(e));
|
||||
cancel_workers();
|
||||
}
|
||||
|
||||
void parallel::batch_manager::set_exception(unsigned error_code) {
|
||||
std::scoped_lock lock(mux);
|
||||
IF_VERBOSE(1, verbose_stream() << "Batch manager setting exception code: " << error_code << ".\n");
|
||||
if (m_state != state::is_running)
|
||||
return;
|
||||
m_state = state::is_exception_code;
|
||||
m_exception_code = error_code;
|
||||
cancel_workers();
|
||||
}
|
||||
|
||||
void parallel::batch_manager::set_exception(std::string const &msg) {
|
||||
std::scoped_lock lock(mux);
|
||||
IF_VERBOSE(1, verbose_stream() << "Batch manager setting exception msg: " << msg << ".\n");
|
||||
if (m_state != state::is_running)
|
||||
return;
|
||||
m_state = state::is_exception_msg;
|
||||
m_exception_msg = msg;
|
||||
cancel_workers();
|
||||
}
|
||||
|
||||
void parallel::batch_manager::report_assumption_used(ast_translation &l2g, expr *assumption) {
|
||||
std::scoped_lock lock(mux);
|
||||
p.m_assumptions_used.insert(l2g(assumption));
|
||||
}
|
||||
|
||||
lbool parallel::batch_manager::get_result() const {
|
||||
if (m.limit().is_canceled())
|
||||
return l_undef; // the main context was cancelled, so we return undef.
|
||||
switch (m_state) {
|
||||
case state::is_running: // batch manager is still running, but all threads have processed their cubes, which
|
||||
// means all cubes were unsat
|
||||
if (!m_search_tree.is_closed())
|
||||
throw default_exception("inconsistent end state");
|
||||
if (!p.m_assumptions_used.empty()) {
|
||||
// collect unsat core from assumptions used, if any --> case when all cubes were unsat, but depend on
|
||||
// nonempty asms, so we need to add these asms to final unsat core
|
||||
SASSERT(p.ctx.m_unsat_core.empty());
|
||||
for (auto a : p.m_assumptions_used)
|
||||
p.ctx.m_unsat_core.push_back(a);
|
||||
}
|
||||
return l_false;
|
||||
case state::is_unsat:
|
||||
return l_false;
|
||||
case state::is_sat:
|
||||
return l_true;
|
||||
case state::is_exception_msg:
|
||||
throw default_exception(m_exception_msg.c_str());
|
||||
case state::is_exception_code:
|
||||
throw z3_error(m_exception_code);
|
||||
default:
|
||||
UNREACHABLE();
|
||||
return l_undef;
|
||||
}
|
||||
}
|
||||
|
||||
bool parallel::batch_manager::get_cube(ast_translation &g2l, unsigned id, expr_ref_vector &cube, node *&n) {
|
||||
cube.reset();
|
||||
std::unique_lock<std::mutex> lock(mux);
|
||||
if (m_search_tree.is_closed()) {
|
||||
IF_VERBOSE(1, verbose_stream() << "all done\n";);
|
||||
return false;
|
||||
}
|
||||
if (m_state != state::is_running) {
|
||||
IF_VERBOSE(1, verbose_stream() << "aborting get_cube\n";);
|
||||
return false;
|
||||
}
|
||||
node *t = m_search_tree.activate_node(n);
|
||||
if (!t)
|
||||
t = m_search_tree.find_active_node();
|
||||
if (!t)
|
||||
return false;
|
||||
IF_VERBOSE(1, m_search_tree.display(verbose_stream()); verbose_stream() << "\n";);
|
||||
n = t;
|
||||
while (t) {
|
||||
if (cube_config::literal_is_null(t->get_literal()))
|
||||
break;
|
||||
expr_ref lit(g2l.to());
|
||||
lit = g2l(t->get_literal().get());
|
||||
cube.push_back(lit);
|
||||
t = t->parent();
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void parallel::batch_manager::initialize() {
|
||||
m_state = state::is_running;
|
||||
m_search_tree.reset();
|
||||
}
|
||||
|
||||
void parallel::batch_manager::collect_statistics(::statistics &st) const {
|
||||
st.update("parallel-num_cubes", m_stats.m_num_cubes);
|
||||
st.update("parallel-max-cube-size", m_stats.m_max_cube_depth);
|
||||
}
|
||||
|
||||
lbool parallel::operator()(expr_ref_vector const &asms) {
|
||||
ast_manager &m = ctx.m;
|
||||
scoped_limits sl(m.limit());
|
||||
unsigned finished_id = UINT_MAX;
|
||||
std::string ex_msg;
|
||||
par_exception_kind ex_kind = DEFAULT_EX;
|
||||
unsigned error_code = 0;
|
||||
bool done = false;
|
||||
unsigned num_rounds = 0;
|
||||
|
||||
if (m.has_trace_stream())
|
||||
throw default_exception("trace streams have to be off in parallel mode");
|
||||
|
||||
|
||||
params_ref params = ctx.get_params();
|
||||
for (unsigned i = 0; i < num_threads; ++i) {
|
||||
smt_params.push_back(ctx.get_fparams());
|
||||
smt_params.back().m_preprocess = false;
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < num_threads; ++i) {
|
||||
ast_manager* new_m = alloc(ast_manager, m, true);
|
||||
pms.push_back(new_m);
|
||||
pctxs.push_back(alloc(context, *new_m, smt_params[i], params));
|
||||
context& new_ctx = *pctxs.back();
|
||||
context::copy(ctx, new_ctx, true);
|
||||
new_ctx.set_random_seed(i + ctx.get_fparams().m_random_seed);
|
||||
ast_translation tr(m, *new_m);
|
||||
pasms.push_back(tr(asms));
|
||||
sl.push_child(&(new_m->limit()));
|
||||
}
|
||||
|
||||
auto cube = [](context& ctx, expr_ref_vector& lasms, expr_ref& c) {
|
||||
lookahead lh(ctx);
|
||||
c = lh.choose();
|
||||
if (c) {
|
||||
if ((ctx.get_random_value() % 2) == 0)
|
||||
c = c.get_manager().mk_not(c);
|
||||
lasms.push_back(c);
|
||||
struct scoped_clear {
|
||||
parallel &p;
|
||||
scoped_clear(parallel &p) : p(p) {}
|
||||
~scoped_clear() {
|
||||
p.m_workers.reset();
|
||||
p.m_assumptions_used.reset();
|
||||
p.m_assumptions.reset();
|
||||
}
|
||||
};
|
||||
scoped_clear clear(*this);
|
||||
|
||||
obj_hashtable<expr> unit_set;
|
||||
expr_ref_vector unit_trail(ctx.m);
|
||||
unsigned_vector unit_lim;
|
||||
for (unsigned i = 0; i < num_threads; ++i) unit_lim.push_back(0);
|
||||
m_batch_manager.initialize();
|
||||
m_workers.reset();
|
||||
for (auto e : asms)
|
||||
m_assumptions.insert(e);
|
||||
scoped_limits sl(m.limit());
|
||||
flet<unsigned> _nt(ctx.m_fparams.m_threads, 1);
|
||||
SASSERT(num_threads > 1);
|
||||
for (unsigned i = 0; i < num_threads; ++i)
|
||||
m_workers.push_back(alloc(worker, i, *this, asms));
|
||||
|
||||
std::function<void(void)> collect_units = [&,this]() {
|
||||
//return; -- has overhead
|
||||
for (unsigned i = 0; i < num_threads; ++i) {
|
||||
context& pctx = *pctxs[i];
|
||||
pctx.pop_to_base_lvl();
|
||||
ast_translation tr(pctx.m, ctx.m);
|
||||
unsigned sz = pctx.assigned_literals().size();
|
||||
for (unsigned j = unit_lim[i]; j < sz; ++j) {
|
||||
literal lit = pctx.assigned_literals()[j];
|
||||
//IF_VERBOSE(0, verbose_stream() << "(smt.thread " << i << " :unit " << lit << " " << pctx.is_relevant(lit.var()) << ")\n";);
|
||||
if (!pctx.is_relevant(lit.var()))
|
||||
continue;
|
||||
expr_ref e(pctx.bool_var2expr(lit.var()), pctx.m);
|
||||
if (lit.sign()) e = pctx.m.mk_not(e);
|
||||
expr_ref ce(tr(e.get()), ctx.m);
|
||||
if (!unit_set.contains(ce)) {
|
||||
unit_set.insert(ce);
|
||||
unit_trail.push_back(ce);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (auto w : m_workers)
|
||||
sl.push_child(&(w->limit()));
|
||||
|
||||
unsigned sz = unit_trail.size();
|
||||
for (unsigned i = 0; i < num_threads; ++i) {
|
||||
context& pctx = *pctxs[i];
|
||||
ast_translation tr(ctx.m, pctx.m);
|
||||
for (unsigned j = unit_lim[i]; j < sz; ++j) {
|
||||
expr_ref src(ctx.m), dst(pctx.m);
|
||||
dst = tr(unit_trail.get(j));
|
||||
pctx.assert_expr(dst);
|
||||
}
|
||||
unit_lim[i] = pctx.assigned_literals().size();
|
||||
}
|
||||
IF_VERBOSE(1, verbose_stream() << "(smt.thread :units " << sz << ")\n");
|
||||
};
|
||||
|
||||
std::mutex mux;
|
||||
|
||||
auto worker_thread = [&](int i) {
|
||||
try {
|
||||
context& pctx = *pctxs[i];
|
||||
ast_manager& pm = *pms[i];
|
||||
expr_ref_vector lasms(pasms[i]);
|
||||
expr_ref c(pm);
|
||||
|
||||
pctx.get_fparams().m_max_conflicts = std::min(thread_max_conflicts, max_conflicts);
|
||||
if (num_rounds > 0 && (num_rounds % pctx.get_fparams().m_threads_cube_frequency) == 0)
|
||||
cube(pctx, lasms, c);
|
||||
IF_VERBOSE(1, verbose_stream() << "(smt.thread " << i;
|
||||
if (num_rounds > 0) verbose_stream() << " :round " << num_rounds;
|
||||
if (c) verbose_stream() << " :cube " << mk_bounded_pp(c, pm, 3);
|
||||
verbose_stream() << ")\n";);
|
||||
lbool r = pctx.check(lasms.size(), lasms.data());
|
||||
|
||||
if (r == l_undef && pctx.m_num_conflicts >= max_conflicts)
|
||||
; // no-op
|
||||
else if (r == l_undef && pctx.m_num_conflicts >= thread_max_conflicts)
|
||||
return;
|
||||
else if (r == l_false && pctx.unsat_core().contains(c)) {
|
||||
IF_VERBOSE(1, verbose_stream() << "(smt.thread " << i << " :learn " << mk_bounded_pp(c, pm, 3) << ")");
|
||||
pctx.assert_expr(mk_not(mk_and(pctx.unsat_core())));
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
bool first = false;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mux);
|
||||
if (finished_id == UINT_MAX) {
|
||||
finished_id = i;
|
||||
first = true;
|
||||
result = r;
|
||||
done = true;
|
||||
}
|
||||
if (!first && r != l_undef && result == l_undef) {
|
||||
finished_id = i;
|
||||
result = r;
|
||||
}
|
||||
else if (!first) return;
|
||||
}
|
||||
|
||||
for (ast_manager* m : pms) {
|
||||
if (m != &pm) m->limit().cancel();
|
||||
}
|
||||
|
||||
}
|
||||
catch (z3_error & err) {
|
||||
if (finished_id == UINT_MAX) {
|
||||
error_code = err.error_code();
|
||||
ex_kind = ERROR_EX;
|
||||
done = true;
|
||||
}
|
||||
}
|
||||
catch (z3_exception & ex) {
|
||||
if (finished_id == UINT_MAX) {
|
||||
ex_msg = ex.what();
|
||||
ex_kind = DEFAULT_EX;
|
||||
done = true;
|
||||
}
|
||||
}
|
||||
catch (...) {
|
||||
if (finished_id == UINT_MAX) {
|
||||
ex_msg = "unknown exception";
|
||||
ex_kind = ERROR_EX;
|
||||
done = true;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// for debugging: num_threads = 1;
|
||||
|
||||
while (true) {
|
||||
// Launch threads
|
||||
vector<std::thread> threads(num_threads);
|
||||
for (unsigned i = 0; i < num_threads; ++i) {
|
||||
threads[i] = std::thread([&, i]() { worker_thread(i); });
|
||||
threads[i] = std::thread([&, i]() { m_workers[i]->run(); });
|
||||
}
|
||||
for (auto & th : threads) {
|
||||
|
||||
// Wait for all threads to finish
|
||||
for (auto &th : threads)
|
||||
th.join();
|
||||
}
|
||||
if (done) break;
|
||||
|
||||
collect_units();
|
||||
++num_rounds;
|
||||
max_conflicts = (max_conflicts < thread_max_conflicts) ? 0 : (max_conflicts - thread_max_conflicts);
|
||||
thread_max_conflicts *= 2;
|
||||
for (auto w : m_workers)
|
||||
w->collect_statistics(ctx.m_aux_stats);
|
||||
m_batch_manager.collect_statistics(ctx.m_aux_stats);
|
||||
|
||||
return m_batch_manager.get_result();
|
||||
}
|
||||
|
||||
for (context* c : pctxs) {
|
||||
c->collect_statistics(ctx.m_aux_stats);
|
||||
}
|
||||
|
||||
if (finished_id == UINT_MAX) {
|
||||
switch (ex_kind) {
|
||||
case ERROR_EX: throw z3_error(error_code);
|
||||
default: throw default_exception(std::move(ex_msg));
|
||||
}
|
||||
}
|
||||
|
||||
model_ref mdl;
|
||||
context& pctx = *pctxs[finished_id];
|
||||
ast_translation tr(*pms[finished_id], m);
|
||||
switch (result) {
|
||||
case l_true:
|
||||
pctx.get_model(mdl);
|
||||
if (mdl)
|
||||
ctx.set_model(mdl->translate(tr));
|
||||
break;
|
||||
case l_false:
|
||||
ctx.m_unsat_core.reset();
|
||||
for (expr* e : pctx.unsat_core())
|
||||
ctx.m_unsat_core.push_back(tr(e));
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
}
|
||||
} // namespace smt
|
||||
#endif
|
||||
|
|
|
@ -11,7 +11,7 @@ Abstract:
|
|||
|
||||
Author:
|
||||
|
||||
nbjorner 2020-01-31
|
||||
Ilana 2025
|
||||
|
||||
Revision History:
|
||||
|
||||
|
@ -19,16 +19,164 @@ Revision History:
|
|||
#pragma once
|
||||
|
||||
#include "smt/smt_context.h"
|
||||
#include "util/search_tree.h"
|
||||
#include <thread>
|
||||
#include <mutex>
|
||||
|
||||
|
||||
namespace smt {
|
||||
|
||||
struct cube_config {
|
||||
using literal = expr_ref;
|
||||
static bool literal_is_null(expr_ref const& l) { return l == nullptr; }
|
||||
static std::ostream& display_literal(std::ostream& out, expr_ref const& l) { return out << mk_bounded_pp(l, l.get_manager()); }
|
||||
};
|
||||
|
||||
class parallel {
|
||||
context& ctx;
|
||||
public:
|
||||
parallel(context& ctx): ctx(ctx) {}
|
||||
unsigned num_threads;
|
||||
|
||||
lbool operator()(expr_ref_vector const& asms);
|
||||
struct shared_clause {
|
||||
unsigned source_worker_id;
|
||||
expr_ref clause;
|
||||
};
|
||||
|
||||
class batch_manager {
|
||||
|
||||
enum state {
|
||||
is_running,
|
||||
is_sat,
|
||||
is_unsat,
|
||||
is_exception_msg,
|
||||
is_exception_code
|
||||
};
|
||||
|
||||
struct stats {
|
||||
unsigned m_max_cube_depth = 0;
|
||||
unsigned m_num_cubes = 0;
|
||||
};
|
||||
|
||||
|
||||
ast_manager& m;
|
||||
parallel& p;
|
||||
std::mutex mux;
|
||||
state m_state = state::is_running;
|
||||
stats m_stats;
|
||||
using node = search_tree::node<cube_config>;
|
||||
search_tree::tree<cube_config> m_search_tree;
|
||||
|
||||
unsigned m_exception_code = 0;
|
||||
std::string m_exception_msg;
|
||||
vector<shared_clause> shared_clause_trail; // store all shared clauses with worker IDs
|
||||
obj_hashtable<expr> shared_clause_set; // for duplicate filtering on per-thread clause expressions
|
||||
|
||||
// called from batch manager to cancel other workers if we've reached a verdict
|
||||
void cancel_workers() {
|
||||
IF_VERBOSE(1, verbose_stream() << "Canceling workers\n");
|
||||
for (auto& w : p.m_workers)
|
||||
w->cancel();
|
||||
}
|
||||
|
||||
void init_parameters_state();
|
||||
|
||||
bool is_assumption(expr* e) const {
|
||||
return p.m_assumptions.contains(e);
|
||||
}
|
||||
|
||||
public:
|
||||
batch_manager(ast_manager& m, parallel& p) : m(m), p(p), m_search_tree(expr_ref(m)) { }
|
||||
|
||||
void initialize();
|
||||
|
||||
void set_unsat(ast_translation& l2g, expr_ref_vector const& unsat_core);
|
||||
void set_sat(ast_translation& l2g, model& m);
|
||||
void set_exception(std::string const& msg);
|
||||
void set_exception(unsigned error_code);
|
||||
void collect_statistics(::statistics& st) const;
|
||||
|
||||
bool get_cube(ast_translation& g2l, unsigned id, expr_ref_vector& cube, node*& n);
|
||||
void backtrack(ast_translation& l2g, expr_ref_vector const& core, node* n);
|
||||
void split(ast_translation& l2g, unsigned id, node* n, expr* atom);
|
||||
|
||||
void report_assumption_used(ast_translation& l2g, expr* assumption);
|
||||
void collect_clause(ast_translation& l2g, unsigned source_worker_id, expr* clause);
|
||||
expr_ref_vector return_shared_clauses(ast_translation& g2l, unsigned& worker_limit, unsigned worker_id);
|
||||
|
||||
lbool get_result() const;
|
||||
};
|
||||
|
||||
class worker {
|
||||
struct config {
|
||||
unsigned m_threads_max_conflicts = 1000;
|
||||
bool m_share_units = true;
|
||||
bool m_share_units_relevant_only = true;
|
||||
bool m_share_units_initial_only = true;
|
||||
double m_max_conflict_mul = 1.5;
|
||||
bool m_cube_initial_only = true;
|
||||
bool m_inprocessing = true;
|
||||
unsigned m_inprocessing_delay = 1;
|
||||
unsigned m_max_cube_depth = 20;
|
||||
unsigned m_max_conflicts = UINT_MAX;
|
||||
};
|
||||
|
||||
using node = search_tree::node<cube_config>;
|
||||
|
||||
unsigned id; // unique identifier for the worker
|
||||
parallel& p;
|
||||
batch_manager& b;
|
||||
ast_manager m;
|
||||
expr_ref_vector asms;
|
||||
smt_params m_smt_params;
|
||||
config m_config;
|
||||
random_gen m_rand;
|
||||
scoped_ptr<context> ctx;
|
||||
ast_translation m_g2l, m_l2g;
|
||||
search_tree::tree<cube_config> m_search_tree;
|
||||
|
||||
unsigned m_num_shared_units = 0;
|
||||
unsigned m_num_initial_atoms = 0;
|
||||
unsigned m_shared_clause_limit = 0; // remembers the index into shared_clause_trail marking the boundary between "old" and "new" clauses to share
|
||||
|
||||
expr_ref get_split_atom();
|
||||
|
||||
lbool check_cube(expr_ref_vector const& cube);
|
||||
void share_units(ast_translation& l2g);
|
||||
|
||||
void update_max_thread_conflicts() {
|
||||
m_config.m_threads_max_conflicts = (unsigned)(m_config.m_max_conflict_mul * m_config.m_threads_max_conflicts);
|
||||
} // allow for backoff scheme of conflicts within the thread for cube timeouts.
|
||||
|
||||
void simplify();
|
||||
|
||||
public:
|
||||
worker(unsigned id, parallel& p, expr_ref_vector const& _asms);
|
||||
void run();
|
||||
|
||||
void collect_shared_clauses(ast_translation& g2l);
|
||||
|
||||
void cancel();
|
||||
void collect_statistics(::statistics& st) const;
|
||||
|
||||
reslimit& limit() {
|
||||
return m.limit();
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
obj_hashtable<expr> m_assumptions_used; // assumptions used in unsat cores, to be used in final core
|
||||
obj_hashtable<expr> m_assumptions; // all assumptions
|
||||
batch_manager m_batch_manager;
|
||||
scoped_ptr_vector<worker> m_workers;
|
||||
|
||||
public:
|
||||
parallel(context& ctx) :
|
||||
ctx(ctx),
|
||||
num_threads(std::min(
|
||||
(unsigned)std::thread::hardware_concurrency(),
|
||||
ctx.get_fparams().m_threads)),
|
||||
m_batch_manager(ctx.m, *this) {}
|
||||
|
||||
lbool operator()(expr_ref_vector const& asms);
|
||||
};
|
||||
|
||||
}
|
||||
|
|
265
src/util/search_tree.h
Normal file
265
src/util/search_tree.h
Normal file
|
@ -0,0 +1,265 @@
|
|||
/*++
|
||||
Copyright (c) 2025 Microsoft Corporation
|
||||
|
||||
Module Name:
|
||||
|
||||
search_tree.h
|
||||
|
||||
Abstract:
|
||||
|
||||
A binary search tree for managing the search space of a DPLL(T) solver.
|
||||
It supports splitting on atoms, backtracking on conflicts, and activating nodes.
|
||||
|
||||
Nodes can be in one of three states: open, closed, or active.
|
||||
- Closed nodes are fully explored (both children are closed).
|
||||
- Active nodes have no children and are currently being explored.
|
||||
- Open nodes either have children that are open or are leaves.
|
||||
|
||||
A node can be split if it is active. After splitting, it becomes open and has two open children.
|
||||
|
||||
Backtracking on a conflict closes all nodes below the last node whose atom is in the conflict set.
|
||||
|
||||
Activation searches an open node closest to a seed node.
|
||||
|
||||
Author:
|
||||
|
||||
Ilana Shapiro 2025-9-06
|
||||
|
||||
--*/
|
||||
|
||||
#include "util/util.h"
|
||||
#include "util/vector.h"
|
||||
#pragma once
|
||||
|
||||
namespace search_tree {
|
||||
|
||||
enum class status { open, closed, active };
|
||||
|
||||
template<typename Config>
|
||||
class node {
|
||||
typedef typename Config::literal literal;
|
||||
literal m_literal;
|
||||
node* m_left = nullptr, * m_right = nullptr, * m_parent = nullptr;
|
||||
status m_status;
|
||||
public:
|
||||
node(literal const& l, node* parent) :
|
||||
m_literal(l), m_parent(parent), m_status(status::open) {}
|
||||
~node() {
|
||||
dealloc(m_left);
|
||||
dealloc(m_right);
|
||||
}
|
||||
|
||||
status get_status() const { return m_status; }
|
||||
void set_status(status s) { m_status = s; }
|
||||
literal const& get_literal() const { return m_literal; }
|
||||
bool literal_is_null() const { return Config::is_null(m_literal); }
|
||||
void split(literal const& a, literal const& b) {
|
||||
SASSERT(!Config::literal_is_null(a));
|
||||
SASSERT(!Config::literal_is_null(b));
|
||||
if (m_status != status::active)
|
||||
return;
|
||||
SASSERT(!m_left);
|
||||
SASSERT(!m_right);
|
||||
m_left = alloc(node<Config>, a, this);
|
||||
m_right = alloc(node<Config>, b, this);
|
||||
m_status = status::open;
|
||||
}
|
||||
|
||||
node* left() const { return m_left; }
|
||||
node* right() const { return m_right; }
|
||||
node* parent() const { return m_parent; }
|
||||
|
||||
node* find_active_node() {
|
||||
if (m_status == status::active)
|
||||
return this;
|
||||
if (m_status != status::open)
|
||||
return nullptr;
|
||||
node* nodes[2] = { m_left, m_right };
|
||||
for (unsigned i = 0; i < 2; ++i) {
|
||||
auto res = nodes[i] ? nodes[i]->find_active_node() : nullptr;
|
||||
if (res)
|
||||
return res;
|
||||
}
|
||||
if (m_left->get_status() == status::closed && m_right->get_status() == status::closed)
|
||||
m_status = status::closed;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void display(std::ostream& out, unsigned indent) const {
|
||||
for (unsigned i = 0; i < indent; ++i)
|
||||
out << " ";
|
||||
Config::display_literal(out, m_literal);
|
||||
out << (get_status() == status::open ? " (o)" : get_status() == status::closed ? " (c)" : " (a)");
|
||||
out << "\n";
|
||||
if (m_left)
|
||||
m_left->display(out, indent + 2);
|
||||
if (m_right)
|
||||
m_right->display(out, indent + 2);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename Config>
|
||||
class tree {
|
||||
typedef typename Config::literal literal;
|
||||
scoped_ptr<node<Config>> m_root = nullptr;
|
||||
literal m_null_literal;
|
||||
random_gen m_rand;
|
||||
|
||||
// return an active node in the subtree rooted at n, or nullptr if there is none
|
||||
// close nodes that are fully explored (whose children are all closed)
|
||||
node<Config>* activate_from_root(node<Config>* n) {
|
||||
if (!n)
|
||||
return nullptr;
|
||||
if (n->get_status() != status::open)
|
||||
return nullptr;
|
||||
auto left = n->left();
|
||||
auto right = n->right();
|
||||
if (!left && !right) {
|
||||
n->set_status(status::active);
|
||||
return n;
|
||||
}
|
||||
node<Config>* nodes[2] = { left, right };
|
||||
unsigned index = m_rand(2);
|
||||
auto child = activate_from_root(nodes[index]);
|
||||
if (child)
|
||||
return child;
|
||||
child = activate_from_root(nodes[1 - index]);
|
||||
if (child)
|
||||
return child;
|
||||
if (left && right && left->get_status() == status::closed && right->get_status() == status::closed)
|
||||
n->set_status(status::closed);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void close_node(node<Config>* n) {
|
||||
if (!n)
|
||||
return;
|
||||
if (n->get_status() == status::closed)
|
||||
return;
|
||||
n->set_status(status::closed);
|
||||
close_node(n->left());
|
||||
close_node(n->right());
|
||||
while (n) {
|
||||
auto p = n->parent();
|
||||
if (!p)
|
||||
return;
|
||||
if (p->get_status() != status::open)
|
||||
return;
|
||||
if (p->left()->get_status() != status::closed)
|
||||
return;
|
||||
if (p->right()->get_status() != status::closed)
|
||||
return;
|
||||
p->set_status(status::closed);
|
||||
n = p;
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
tree(literal const& null_literal) : m_null_literal(null_literal) {
|
||||
reset();
|
||||
}
|
||||
|
||||
void set_seed(unsigned seed) {
|
||||
m_rand.set_seed(seed);
|
||||
}
|
||||
|
||||
void reset() {
|
||||
m_root = alloc(node<Config>, m_null_literal, nullptr);
|
||||
m_root->set_status(status::active);
|
||||
}
|
||||
|
||||
// Split current node if it is active.
|
||||
// After the call, n is open and has two children.
|
||||
void split(node<Config>* n, literal const& a, literal const& b) {
|
||||
n->split(a, b);
|
||||
}
|
||||
|
||||
// conflict is given by a set of literals.
|
||||
// they are a subset of literals on the path from root to n
|
||||
void backtrack(node<Config>* n, vector<literal> const& conflict) {
|
||||
if (conflict.empty()) {
|
||||
close_node(m_root.get());
|
||||
m_root->set_status(status::closed);
|
||||
return;
|
||||
}
|
||||
SASSERT(n != m_root.get());
|
||||
// all literals in conflict are on the path from root to n
|
||||
// remove assumptions from conflict to ensure this.
|
||||
DEBUG_CODE(
|
||||
auto on_path = [&](literal const& a) {
|
||||
node<Config>* p = n;
|
||||
while (p) {
|
||||
if (p->get_literal() == a)
|
||||
return true;
|
||||
p = p->parent();
|
||||
}
|
||||
return false;
|
||||
};
|
||||
SASSERT(all_of(conflict, [&](auto const& a) { return on_path(a); }));
|
||||
);
|
||||
|
||||
while (n) {
|
||||
if (any_of(conflict, [&](auto const& a) { return a == n->get_literal(); })) {
|
||||
close_node(n);
|
||||
return;
|
||||
}
|
||||
n = n->parent();
|
||||
}
|
||||
UNREACHABLE();
|
||||
}
|
||||
|
||||
// return an active node in the tree, or nullptr if there is none
|
||||
// first check if there is a node to activate under n,
|
||||
// if not, go up the tree and try to activate a sibling subtree
|
||||
node<Config>* activate_node(node<Config>* n) {
|
||||
if (!n) {
|
||||
if (m_root->get_status() == status::active)
|
||||
return m_root.get();
|
||||
n = m_root.get();
|
||||
}
|
||||
auto res = activate_from_root(n);
|
||||
if (res)
|
||||
return res;
|
||||
|
||||
auto p = n->parent();
|
||||
while (p) {
|
||||
if (p->left() && p->left()->get_status() == status::closed &&
|
||||
p->right() && p->right()->get_status() == status::closed) {
|
||||
p->set_status(status::closed);
|
||||
n = p;
|
||||
p = n->parent();
|
||||
continue;
|
||||
}
|
||||
if (n == p->left()) {
|
||||
res = activate_from_root(p->right());
|
||||
if (res)
|
||||
return res;
|
||||
}
|
||||
else {
|
||||
VERIFY(n == p->right());
|
||||
res = activate_from_root(p->left());
|
||||
if (res)
|
||||
return res;
|
||||
}
|
||||
n = p;
|
||||
p = n->parent();
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
node<Config>* find_active_node() {
|
||||
return m_root->find_active_node();
|
||||
}
|
||||
|
||||
bool is_closed() const {
|
||||
return m_root->get_status() == status::closed;
|
||||
}
|
||||
|
||||
std::ostream& display(std::ostream& out) const {
|
||||
m_root->display(out, 0);
|
||||
return out;
|
||||
}
|
||||
|
||||
};
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue