From ce53e06e29433cdc511656691d79c9d6b7634a31 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sun, 21 Sep 2025 10:11:04 +0300 Subject: [PATCH] Par (#7945) * port parallel Signed-off-by: Nikolaj Bjorner * updates Signed-off-by: Nikolaj Bjorner * update smt-parallel Signed-off-by: Nikolaj Bjorner * cleanup Signed-off-by: Nikolaj Bjorner * neat Signed-off-by: Nikolaj Bjorner * configuration parameter renaming Signed-off-by: Nikolaj Bjorner * config parameters Signed-off-by: Nikolaj Bjorner --------- Signed-off-by: Nikolaj Bjorner --- src/ast/simplifiers/dependent_expr_state.h | 74 ++- src/smt/smt_context.cpp | 5 + src/smt/smt_context.h | 8 + src/smt/smt_internalizer.cpp | 14 + src/smt/smt_parallel.cpp | 714 ++++++++++++++------- src/smt/smt_parallel.h | 154 ++++- src/util/search_tree.h | 265 ++++++++ 7 files changed, 1006 insertions(+), 228 deletions(-) create mode 100644 src/util/search_tree.h diff --git a/src/ast/simplifiers/dependent_expr_state.h b/src/ast/simplifiers/dependent_expr_state.h index e187f19c6..047dc4652 100644 --- a/src/ast/simplifiers/dependent_expr_state.h +++ b/src/ast/simplifiers/dependent_expr_state.h @@ -33,6 +33,7 @@ Author: #include "util/statistics.h" #include "util/params.h" #include "util/z3_exception.h" +#include "ast/ast_util.h" #include "ast/converters/model_converter.h" #include "ast/simplifiers/dependent_expr.h" #include "ast/simplifiers/model_reconstruction_trail.h" @@ -113,9 +114,80 @@ public: model_reconstruction_trail& model_trail() override { throw default_exception("unexpected access to model reconstruction"); } bool updated() override { return false; } void reset_updated() override {} - }; + +struct base_dependent_expr_state : public dependent_expr_state { + ast_manager& m; + model_reconstruction_trail m_reconstruction_trail; + bool m_updated = false; + bool m_inconsistent = false; + vector m_fmls; + base_dependent_expr_state(ast_manager& m) :dependent_expr_state(m), m(m), m_reconstruction_trail(m, m_trail) {} + unsigned 
qtail() const override { return m_fmls.size(); } + dependent_expr const& operator[](unsigned i) override { return m_fmls[i]; } + void update(unsigned i, dependent_expr const& j) override { + SASSERT(j.fml()); + check_false(j.fml()); + m_fmls[i] = j; + m_updated = true; + } + void add(dependent_expr const& j) override { m_updated = true; check_false(j.fml()); m_fmls.push_back(j); } + bool inconsistent() override { return m_inconsistent; } + bool updated() override { return m_updated; } + void reset_updated() override { m_updated = false; } + model_reconstruction_trail& model_trail() override { return m_reconstruction_trail; } + std::ostream& display(std::ostream& out) const override { + unsigned i = 0; + for (auto const& d : m_fmls) { + if (i > 0 && i == qhead()) + out << "---- head ---\n"; + out << d << "\n"; + ++i; + } + m_reconstruction_trail.display(out); + return out; + } + void check_false(expr* f) { + if (m.is_false(f)) + m_inconsistent = true; + } + void replay(unsigned qhead, expr_ref_vector& assumptions) { + m_reconstruction_trail.replay(qhead, assumptions, *this); + } + void flatten_suffix() override { + expr_mark seen; + unsigned j = qhead(); + expr_ref_vector pinned(m); + for (unsigned i = qhead(); i < qtail(); ++i) { + expr* f = m_fmls[i].fml(), * g = nullptr; + pinned.push_back(f); + if (seen.is_marked(f)) + continue; + seen.mark(f, true); + if (m.is_true(f)) + continue; + if (m.is_and(f)) { + auto* d = m_fmls[i].dep(); + for (expr* arg : *to_app(f)) + add(dependent_expr(m, arg, nullptr, d)); + continue; + } + if (m.is_not(f, g) && m.is_or(g)) { + auto* d = m_fmls[i].dep(); + for (expr* arg : *to_app(g)) + add(dependent_expr(m, mk_not(m, arg), nullptr, d)); + continue; + } + if (i != j) + m_fmls[j] = m_fmls[i]; + ++j; + } + m_fmls.shrink(j); + } +}; + + inline std::ostream& operator<<(std::ostream& out, dependent_expr_state& st) { return st.display(out); } diff --git a/src/smt/smt_context.cpp b/src/smt/smt_context.cpp index af460d549..01432cabf 100644 
--- a/src/smt/smt_context.cpp +++ b/src/smt/smt_context.cpp @@ -4751,6 +4751,11 @@ namespace smt { } mdl = m_model.get(); } + if (m_fmls && mdl) { + auto convert = m_fmls->model_trail().get_model_converter(); + if (convert) + (*convert)(mdl); + } } void context::get_levels(ptr_vector const& vars, unsigned_vector& depth) { diff --git a/src/smt/smt_context.h b/src/smt/smt_context.h index 2fbc1d705..09a358e0e 100644 --- a/src/smt/smt_context.h +++ b/src/smt/smt_context.h @@ -19,6 +19,7 @@ Revision History: #pragma once #include "ast/quantifier_stat.h" +#include "ast/simplifiers/dependent_expr_state.h" #include "smt/smt_clause.h" #include "smt/smt_setup.h" #include "smt/smt_enode.h" @@ -132,6 +133,11 @@ namespace smt { bool m_internalizing_assertions = false; lbool m_internal_completed = l_undef; + scoped_ptr m_simplifier; + scoped_ptr m_fmls; + + svector m_lit_scores[2]; + // ----------------------------------- // @@ -1292,6 +1298,8 @@ namespace smt { virtual bool resolve_conflict(); + void add_scores(unsigned n, literal const *lits); + // ----------------------------------- // diff --git a/src/smt/smt_internalizer.cpp b/src/smt/smt_internalizer.cpp index 9aa6d68f4..7f0fe1e9e 100644 --- a/src/smt/smt_internalizer.cpp +++ b/src/smt/smt_internalizer.cpp @@ -933,6 +933,10 @@ namespace smt { m_activity.reserve(v+1); m_bool_var2expr.reserve(v+1); m_bool_var2expr[v] = n; + m_lit_scores[0].reserve(v + 1); + m_lit_scores[1].reserve(v + 1); + m_lit_scores[0][v] = m_lit_scores[1][v] = 0.0; + literal l(v, false); literal not_l(v, true); unsigned aux = std::max(l.index(), not_l.index()) + 1; @@ -960,6 +964,15 @@ namespace smt { SASSERT(check_bool_var_vector_sizes()); return v; } + + void context::add_scores(unsigned n, literal const *lits) { + for (unsigned i = 0; i < n; ++i) { + auto lit = lits[i]; + unsigned v = lit.var(); // unique key per literal + m_lit_scores[lit.sign()][v] += 1.0 / n; + } + } + void context::undo_mk_bool_var() { SASSERT(!m_b_internalized_stack.empty()); @@ 
-1419,6 +1432,7 @@ namespace smt { break; case CLS_LEARNED: dump_lemma(num_lits, lits); + add_scores(num_lits, lits); break; default: break; diff --git a/src/smt/smt_parallel.cpp b/src/smt/smt_parallel.cpp index 4941e4df9..cbae4a3ef 100644 --- a/src/smt/smt_parallel.cpp +++ b/src/smt/smt_parallel.cpp @@ -12,262 +12,528 @@ Abstract: Author: nbjorner 2020-01-31 + Ilana Shapiro 2025 --*/ - #include "util/scoped_ptr_vector.h" #include "ast/ast_util.h" #include "ast/ast_pp.h" #include "ast/ast_ll_pp.h" #include "ast/ast_translation.h" +#include "ast/simplifiers/then_simplifier.h" #include "smt/smt_parallel.h" #include "smt/smt_lookahead.h" +#include "solver/solver_preprocess.h" + +#include +#include + +class bounded_pp_exprs { + expr_ref_vector const &es; + +public: + bounded_pp_exprs(expr_ref_vector const &es) : es(es) {} + + std::ostream &display(std::ostream &out) const { + for (auto e : es) + out << mk_bounded_pp(e, es.get_manager()) << "\n"; + return out; + } +}; + +inline std::ostream &operator<<(std::ostream &out, bounded_pp_exprs const &pp) { + return pp.display(out); +} #ifdef SINGLE_THREAD namespace smt { - - lbool parallel::operator()(expr_ref_vector const& asms) { + + lbool parallel::operator()(expr_ref_vector const &asms) { return l_undef; } -} +} // namespace smt #else #include +#define LOG_WORKER(lvl, s) IF_VERBOSE(lvl, verbose_stream() << "Worker " << id << s) + namespace smt { - - lbool parallel::operator()(expr_ref_vector const& asms) { - - lbool result = l_undef; - unsigned num_threads = std::min((unsigned) std::thread::hardware_concurrency(), ctx.get_fparams().m_threads); - flet _nt(ctx.m_fparams.m_threads, 1); - unsigned thread_max_conflicts = ctx.get_fparams().m_threads_max_conflicts; - unsigned max_conflicts = ctx.get_fparams().m_max_conflicts; - - // try first sequential with a low conflict budget to make super easy problems cheap - unsigned max_c = std::min(thread_max_conflicts, 40u); - flet _mc(ctx.get_fparams().m_max_conflicts, max_c); - 
result = ctx.check(asms.size(), asms.data()); - if (result != l_undef || ctx.m_num_conflicts < max_c) { - return result; - } - - enum par_exception_kind { - DEFAULT_EX, - ERROR_EX - }; - - vector smt_params; - scoped_ptr_vector pms; - scoped_ptr_vector pctxs; - vector pasms; - - ast_manager& m = ctx.m; - scoped_limits sl(m.limit()); - unsigned finished_id = UINT_MAX; - std::string ex_msg; - par_exception_kind ex_kind = DEFAULT_EX; - unsigned error_code = 0; - bool done = false; - unsigned num_rounds = 0; - if (m.has_trace_stream()) - throw default_exception("trace streams have to be off in parallel mode"); - - - params_ref params = ctx.get_params(); - for (unsigned i = 0; i < num_threads; ++i) { - smt_params.push_back(ctx.get_fparams()); - smt_params.back().m_preprocess = false; - } - - for (unsigned i = 0; i < num_threads; ++i) { - ast_manager* new_m = alloc(ast_manager, m, true); - pms.push_back(new_m); - pctxs.push_back(alloc(context, *new_m, smt_params[i], params)); - context& new_ctx = *pctxs.back(); - context::copy(ctx, new_ctx, true); - new_ctx.set_random_seed(i + ctx.get_fparams().m_random_seed); - ast_translation tr(m, *new_m); - pasms.push_back(tr(asms)); - sl.push_child(&(new_m->limit())); - } - - auto cube = [](context& ctx, expr_ref_vector& lasms, expr_ref& c) { - lookahead lh(ctx); - c = lh.choose(); - if (c) { - if ((ctx.get_random_value() % 2) == 0) - c = c.get_manager().mk_not(c); - lasms.push_back(c); - } - }; - - obj_hashtable unit_set; - expr_ref_vector unit_trail(ctx.m); - unsigned_vector unit_lim; - for (unsigned i = 0; i < num_threads; ++i) unit_lim.push_back(0); - - std::function collect_units = [&,this]() { - //return; -- has overhead - for (unsigned i = 0; i < num_threads; ++i) { - context& pctx = *pctxs[i]; - pctx.pop_to_base_lvl(); - ast_translation tr(pctx.m, ctx.m); - unsigned sz = pctx.assigned_literals().size(); - for (unsigned j = unit_lim[i]; j < sz; ++j) { - literal lit = pctx.assigned_literals()[j]; - //IF_VERBOSE(0, 
verbose_stream() << "(smt.thread " << i << " :unit " << lit << " " << pctx.is_relevant(lit.var()) << ")\n";); - if (!pctx.is_relevant(lit.var())) - continue; - expr_ref e(pctx.bool_var2expr(lit.var()), pctx.m); - if (lit.sign()) e = pctx.m.mk_not(e); - expr_ref ce(tr(e.get()), ctx.m); - if (!unit_set.contains(ce)) { - unit_set.insert(ce); - unit_trail.push_back(ce); - } - } - } - - unsigned sz = unit_trail.size(); - for (unsigned i = 0; i < num_threads; ++i) { - context& pctx = *pctxs[i]; - ast_translation tr(ctx.m, pctx.m); - for (unsigned j = unit_lim[i]; j < sz; ++j) { - expr_ref src(ctx.m), dst(pctx.m); - dst = tr(unit_trail.get(j)); - pctx.assert_expr(dst); - } - unit_lim[i] = pctx.assigned_literals().size(); - } - IF_VERBOSE(1, verbose_stream() << "(smt.thread :units " << sz << ")\n"); - }; - - std::mutex mux; - - auto worker_thread = [&](int i) { - try { - context& pctx = *pctxs[i]; - ast_manager& pm = *pms[i]; - expr_ref_vector lasms(pasms[i]); - expr_ref c(pm); - - pctx.get_fparams().m_max_conflicts = std::min(thread_max_conflicts, max_conflicts); - if (num_rounds > 0 && (num_rounds % pctx.get_fparams().m_threads_cube_frequency) == 0) - cube(pctx, lasms, c); - IF_VERBOSE(1, verbose_stream() << "(smt.thread " << i; - if (num_rounds > 0) verbose_stream() << " :round " << num_rounds; - if (c) verbose_stream() << " :cube " << mk_bounded_pp(c, pm, 3); - verbose_stream() << ")\n";); - lbool r = pctx.check(lasms.size(), lasms.data()); - - if (r == l_undef && pctx.m_num_conflicts >= max_conflicts) - ; // no-op - else if (r == l_undef && pctx.m_num_conflicts >= thread_max_conflicts) - return; - else if (r == l_false && pctx.unsat_core().contains(c)) { - IF_VERBOSE(1, verbose_stream() << "(smt.thread " << i << " :learn " << mk_bounded_pp(c, pm, 3) << ")"); - pctx.assert_expr(mk_not(mk_and(pctx.unsat_core()))); - return; - } - - - bool first = false; - { - std::lock_guard lock(mux); - if (finished_id == UINT_MAX) { - finished_id = i; - first = true; - result = r; - 
done = true; - } - if (!first && r != l_undef && result == l_undef) { - finished_id = i; - result = r; - } - else if (!first) return; - } - - for (ast_manager* m : pms) { - if (m != &pm) m->limit().cancel(); - } - - } - catch (z3_error & err) { - if (finished_id == UINT_MAX) { - error_code = err.error_code(); - ex_kind = ERROR_EX; - done = true; - } - } - catch (z3_exception & ex) { - if (finished_id == UINT_MAX) { - ex_msg = ex.what(); - ex_kind = DEFAULT_EX; - done = true; - } - } - catch (...) { - if (finished_id == UINT_MAX) { - ex_msg = "unknown exception"; - ex_kind = ERROR_EX; - done = true; - } - } - }; - - // for debugging: num_threads = 1; + void parallel::worker::run() { + search_tree::node *node = nullptr; + expr_ref_vector cube(m); while (true) { - vector threads(num_threads); - for (unsigned i = 0; i < num_threads; ++i) { - threads[i] = std::thread([&, i]() { worker_thread(i); }); - } - for (auto & th : threads) { - th.join(); - } - if (done) break; - collect_units(); - ++num_rounds; - max_conflicts = (max_conflicts < thread_max_conflicts) ? 0 : (max_conflicts - thread_max_conflicts); - thread_max_conflicts *= 2; + if (!b.get_cube(m_g2l, id, cube, node)) { + LOG_WORKER(1, " no more cubes\n"); + return; + } + collect_shared_clauses(m_g2l); + + check_cube_start: + LOG_WORKER(1, " CUBE SIZE IN MAIN LOOP: " << cube.size() << "\n"); + lbool r = check_cube(cube); + + if (!m.inc()) { + b.set_exception("context cancelled"); + return; + } + + switch (r) { + case l_undef: { + update_max_thread_conflicts(); + LOG_WORKER(1, " found undef cube\n"); + // return unprocessed cubes to the batch manager + // add a split literal to the batch manager. + // optionally process other cubes and delay sending back unprocessed cubes to batch manager. 
+ if (m_config.m_max_cube_depth <= cube.size()) + goto check_cube_start; + + auto atom = get_split_atom(); + if (!atom) + goto check_cube_start; + b.split(m_l2g, id, node, atom); + simplify(); + break; + } + case l_true: { + LOG_WORKER(1, " found sat cube\n"); + model_ref mdl; + ctx->get_model(mdl); + b.set_sat(m_l2g, *mdl); + return; + } + case l_false: { + expr_ref_vector const &unsat_core = ctx->unsat_core(); + LOG_WORKER(2, " unsat core:\n"; + for (auto c : unsat_core) verbose_stream() << mk_bounded_pp(c, m, 3) << "\n"); + // If the unsat core only contains external assumptions, + // unsatisfiability does not depend on the current cube and the entire problem is unsat. + if (all_of(unsat_core, [&](expr *e) { return asms.contains(e); })) { + LOG_WORKER(1, " determined formula unsat\n"); + b.set_unsat(m_l2g, unsat_core); + return; + } + // report assumptions used in unsat core, so they can be used in final core + for (expr *e : unsat_core) + if (asms.contains(e)) + b.report_assumption_used(m_l2g, e); + + LOG_WORKER(1, " found unsat cube\n"); + b.backtrack(m_l2g, unsat_core, node); + break; + } + } + if (m_config.m_share_units) + share_units(m_l2g); + } + } + + parallel::worker::worker(unsigned id, parallel &p, expr_ref_vector const &_asms) + : id(id), p(p), b(p.m_batch_manager), m_smt_params(p.ctx.get_fparams()), asms(m), m_g2l(p.ctx.m, m), + m_l2g(m, p.ctx.m), m_search_tree(expr_ref(m)) { + for (auto e : _asms) + asms.push_back(m_g2l(e)); + LOG_WORKER(1, " created with " << asms.size() << " assumptions\n"); + m_smt_params.m_preprocess = false; + ctx = alloc(context, m, m_smt_params, p.ctx.get_params()); + context::copy(p.ctx, *ctx, true); + ctx->set_random_seed(id + m_smt_params.m_random_seed); + // don't share initial units + ctx->pop_to_base_lvl(); + m_num_shared_units = ctx->assigned_literals().size(); + m_num_initial_atoms = ctx->get_num_bool_vars(); + } + + void parallel::worker::share_units(ast_translation &l2g) { + // Collect new units learned locally by 
this worker and send to batch manager + ctx->pop_to_base_lvl(); + unsigned sz = ctx->assigned_literals().size(); + for (unsigned j = m_num_shared_units; j < sz; ++j) { // iterate only over new literals since last sync + literal lit = ctx->assigned_literals()[j]; + if (!ctx->is_relevant(lit.var()) && m_config.m_share_units_relevant_only) + continue; + + if (m_config.m_share_units_initial_only && lit.var() >= m_num_initial_atoms) { + LOG_WORKER(2, " Skipping non-initial unit: " << lit.var() << "\n"); + continue; // skip non-iniial atoms if configured to do so + } + + expr_ref e(ctx->bool_var2expr(lit.var()), ctx->m); // turn literal into a Boolean expression + if (m.is_and(e) || m.is_or(e)) + continue; + + if (lit.sign()) + e = m.mk_not(e); // negate if literal is negative + b.collect_clause(l2g, id, e); + } + m_num_shared_units = sz; + } + + void parallel::worker::simplify() { + if (!m.inc()) + return; + // first attempt: one-shot simplification of the context. + // a precise schedule of repeated simplification is TBD. + // also, the in-processing simplifier should be applied to + // a current set of irredundant clauses that may be reduced by + // unit propagation. By including the units we are effectively + // repeating unit propagation, but potentially not subsumption or + // Boolean simplifications that a solver could perform (smt_context doesnt really) + // Integration of inprocssing simplifcation here or in sat/smt solver could + // be based on taking the current clause set instead of the asserted formulas. + if (!m_config.m_inprocessing) + return; + if (m_config.m_inprocessing_delay > 0) { + --m_config.m_inprocessing_delay; + return; + } + ctx->pop_to_base_lvl(); + if (ctx->m_base_lvl > 0) + return; // simplification only at base level + m_config.m_inprocessing = false; // initial strategy is to immediately disable inprocessing for future calls. 
+ dependent_expr_simplifier *s = ctx->m_simplifier.get(); + if (!s) { + // create a simplifier if none exists + // initialize it to a default pre-processing simplifier. + ctx->m_fmls = alloc(base_dependent_expr_state, m); + auto then_s = alloc(then_simplifier, m, ctx->get_params(), *ctx->m_fmls); + s = then_s; + ctx->m_simplifier = s; + init_preprocess(m, ctx->get_params(), *then_s, *ctx->m_fmls); } - for (context* c : pctxs) { - c->collect_statistics(ctx.m_aux_stats); + dependent_expr_state &fmls = *ctx->m_fmls.get(); + // extract assertions from ctx. + // it is possible to track proof objects here if wanted. + // feed them to the simplifier + ptr_vector assertions; + expr_ref_vector units(m); + ctx->get_assertions(assertions); + ctx->get_units(units); + for (expr *e : assertions) + fmls.add(dependent_expr(m, e, nullptr, nullptr)); + for (expr *e : units) + fmls.add(dependent_expr(m, e, nullptr, nullptr)); + + // run in-processing on the assertions + s->reduce(); + + scoped_ptr new_ctx = alloc(context, m, m_smt_params, p.ctx.get_params()); + // extract simplified assertions from the simplifier + // create a new context with the simplified assertions + // update ctx with the new context. 
+ for (unsigned i = 0; i < fmls.qtail(); ++i) { + auto const &de = fmls[i]; + new_ctx->assert_expr(de.fml()); } - if (finished_id == UINT_MAX) { - switch (ex_kind) { - case ERROR_EX: throw z3_error(error_code); - default: throw default_exception(std::move(ex_msg)); - } - } + asserted_formulas &src_af = ctx->m_asserted_formulas; + asserted_formulas &dst_af = new_ctx->m_asserted_formulas; + src_af.get_macro_manager().copy_to(dst_af.get_macro_manager()); + new_ctx->copy_user_propagator(*ctx, true); + ctx = new_ctx.detach(); + ctx->setup_context(true); + ctx->internalize_assertions(); + auto old_atoms = m_num_initial_atoms; + m_num_shared_units = ctx->assigned_literals().size(); + m_num_initial_atoms = ctx->get_num_bool_vars(); + LOG_WORKER(1, " inprocess " << old_atoms << " -> " << m_num_initial_atoms << "\n"); + } - model_ref mdl; - context& pctx = *pctxs[finished_id]; - ast_translation tr(*pms[finished_id], m); - switch (result) { - case l_true: - pctx.get_model(mdl); - if (mdl) - ctx.set_model(mdl->translate(tr)); - break; - case l_false: - ctx.m_unsat_core.reset(); - for (expr* e : pctx.unsat_core()) - ctx.m_unsat_core.push_back(tr(e)); - break; - default: - break; - } + void parallel::worker::collect_statistics(::statistics &st) const { + ctx->collect_statistics(st); + } + void parallel::worker::cancel() { + LOG_WORKER(1, " canceling\n"); + m.limit().cancel(); + } + + void parallel::batch_manager::backtrack(ast_translation &l2g, expr_ref_vector const &core, + search_tree::node *node) { + std::scoped_lock lock(mux); + IF_VERBOSE(1, verbose_stream() << "Batch manager backtracking.\n"); + if (m_state != state::is_running) + return; + vector g_core; + for (auto c : core) { + expr_ref g_c(l2g(c), m); + if (!is_assumption(g_c)) + g_core.push_back(expr_ref(l2g(c), m)); + } + m_search_tree.backtrack(node, g_core); + + IF_VERBOSE(1, m_search_tree.display(verbose_stream() << bounded_pp_exprs(core) << "\n");); + if (m_search_tree.is_closed()) { + m_state = state::is_unsat; 
+ cancel_workers(); + } + } + + void parallel::batch_manager::split(ast_translation &l2g, unsigned source_worker_id, + search_tree::node *node, expr *atom) { + std::scoped_lock lock(mux); + expr_ref lit(m), nlit(m); + lit = l2g(atom); + nlit = mk_not(m, lit); + IF_VERBOSE(1, verbose_stream() << "Batch manager splitting on literal: " << mk_bounded_pp(lit, m, 3) << "\n"); + if (m_state != state::is_running) + return; + // optional heuristic: + // node->get_status() == status::active + // and depth is 'high' enough + // then ignore split, and instead set the status of node to open. + m_search_tree.split(node, lit, nlit); + } + + void parallel::batch_manager::collect_clause(ast_translation &l2g, unsigned source_worker_id, expr *clause) { + std::scoped_lock lock(mux); + expr *g_clause = l2g(clause); + if (!shared_clause_set.contains(g_clause)) { + shared_clause_set.insert(g_clause); + shared_clause sc{source_worker_id, expr_ref(g_clause, m)}; + shared_clause_trail.push_back(sc); + } + } + + void parallel::worker::collect_shared_clauses(ast_translation &g2l) { + expr_ref_vector new_clauses = b.return_shared_clauses(g2l, m_shared_clause_limit, id); + // iterate over new clauses and assert them in the local context + for (expr *e : new_clauses) { + ctx->assert_expr(e); + LOG_WORKER(2, " asserting shared clause: " << mk_bounded_pp(e, m, 3) << "\n"); + } + } + + expr_ref_vector parallel::batch_manager::return_shared_clauses(ast_translation &g2l, unsigned &worker_limit, + unsigned worker_id) { + std::scoped_lock lock(mux); + expr_ref_vector result(g2l.to()); + for (unsigned i = worker_limit; i < shared_clause_trail.size(); ++i) { + if (shared_clause_trail[i].source_worker_id != worker_id) + result.push_back(g2l(shared_clause_trail[i].clause.get())); + } + worker_limit = shared_clause_trail.size(); // update the worker limit to the end of the current trail return result; } -} + lbool parallel::worker::check_cube(expr_ref_vector const &cube) { + for (auto &atom : cube) + 
asms.push_back(atom); + lbool r = l_undef; + + ctx->get_fparams().m_max_conflicts = std::min(m_config.m_threads_max_conflicts, m_config.m_max_conflicts); + IF_VERBOSE(1, verbose_stream() << " Checking cube\n" + << bounded_pp_exprs(cube) + << "with max_conflicts: " << ctx->get_fparams().m_max_conflicts << "\n";); + try { + r = ctx->check(asms.size(), asms.data()); + } catch (z3_error &err) { + b.set_exception(err.error_code()); + } catch (z3_exception &ex) { + b.set_exception(ex.what()); + } catch (...) { + b.set_exception("unknown exception"); + } + asms.shrink(asms.size() - cube.size()); + LOG_WORKER(1, " DONE checking cube " << r << "\n";); + return r; + } + + expr_ref parallel::worker::get_split_atom() { + expr_ref result(m); + double score = 0; + unsigned n = 0; + ctx->pop_to_search_lvl(); + for (bool_var v = 0; v < ctx->get_num_bool_vars(); ++v) { + if (ctx->get_assignment(v) != l_undef) + continue; + expr *e = ctx->bool_var2expr(v); + if (!e) + continue; + + double new_score = ctx->m_lit_scores[0][v] * ctx->m_lit_scores[1][v]; + + ctx->m_lit_scores[0][v] /= 2; + ctx->m_lit_scores[1][v] /= 2; + + if (new_score > score || !result || (new_score == score && m_rand(++n) == 0)) { + score = new_score; + result = e; + } + } + return result; + } + + void parallel::batch_manager::set_sat(ast_translation &l2g, model &m) { + std::scoped_lock lock(mux); + IF_VERBOSE(1, verbose_stream() << "Batch manager setting SAT.\n"); + if (m_state != state::is_running) + return; + m_state = state::is_sat; + p.ctx.set_model(m.translate(l2g)); + cancel_workers(); + } + + void parallel::batch_manager::set_unsat(ast_translation &l2g, expr_ref_vector const &unsat_core) { + std::scoped_lock lock(mux); + IF_VERBOSE(1, verbose_stream() << "Batch manager setting UNSAT.\n"); + if (m_state != state::is_running) + return; + m_state = state::is_unsat; + + // each call to check_sat needs to have a fresh unsat core + SASSERT(p.ctx.m_unsat_core.empty()); + for (expr *e : unsat_core) + 
p.ctx.m_unsat_core.push_back(l2g(e)); + cancel_workers(); + } + + void parallel::batch_manager::set_exception(unsigned error_code) { + std::scoped_lock lock(mux); + IF_VERBOSE(1, verbose_stream() << "Batch manager setting exception code: " << error_code << ".\n"); + if (m_state != state::is_running) + return; + m_state = state::is_exception_code; + m_exception_code = error_code; + cancel_workers(); + } + + void parallel::batch_manager::set_exception(std::string const &msg) { + std::scoped_lock lock(mux); + IF_VERBOSE(1, verbose_stream() << "Batch manager setting exception msg: " << msg << ".\n"); + if (m_state != state::is_running) + return; + m_state = state::is_exception_msg; + m_exception_msg = msg; + cancel_workers(); + } + + void parallel::batch_manager::report_assumption_used(ast_translation &l2g, expr *assumption) { + std::scoped_lock lock(mux); + p.m_assumptions_used.insert(l2g(assumption)); + } + + lbool parallel::batch_manager::get_result() const { + if (m.limit().is_canceled()) + return l_undef; // the main context was cancelled, so we return undef. 
+ switch (m_state) { + case state::is_running: // batch manager is still running, but all threads have processed their cubes, which + // means all cubes were unsat + if (!m_search_tree.is_closed()) + throw default_exception("inconsistent end state"); + if (!p.m_assumptions_used.empty()) { + // collect unsat core from assumptions used, if any --> case when all cubes were unsat, but depend on + // nonempty asms, so we need to add these asms to final unsat core + SASSERT(p.ctx.m_unsat_core.empty()); + for (auto a : p.m_assumptions_used) + p.ctx.m_unsat_core.push_back(a); + } + return l_false; + case state::is_unsat: + return l_false; + case state::is_sat: + return l_true; + case state::is_exception_msg: + throw default_exception(m_exception_msg.c_str()); + case state::is_exception_code: + throw z3_error(m_exception_code); + default: + UNREACHABLE(); + return l_undef; + } + } + + bool parallel::batch_manager::get_cube(ast_translation &g2l, unsigned id, expr_ref_vector &cube, node *&n) { + cube.reset(); + std::unique_lock lock(mux); + if (m_search_tree.is_closed()) { + IF_VERBOSE(1, verbose_stream() << "all done\n";); + return false; + } + if (m_state != state::is_running) { + IF_VERBOSE(1, verbose_stream() << "aborting get_cube\n";); + return false; + } + node *t = m_search_tree.activate_node(n); + if (!t) + t = m_search_tree.find_active_node(); + if (!t) + return false; + IF_VERBOSE(1, m_search_tree.display(verbose_stream()); verbose_stream() << "\n";); + n = t; + while (t) { + if (cube_config::literal_is_null(t->get_literal())) + break; + expr_ref lit(g2l.to()); + lit = g2l(t->get_literal().get()); + cube.push_back(lit); + t = t->parent(); + } + return true; + } + + void parallel::batch_manager::initialize() { + m_state = state::is_running; + m_search_tree.reset(); + } + + void parallel::batch_manager::collect_statistics(::statistics &st) const { + st.update("parallel-num_cubes", m_stats.m_num_cubes); + st.update("parallel-max-cube-size", m_stats.m_max_cube_depth); + 
} + + lbool parallel::operator()(expr_ref_vector const &asms) { + ast_manager &m = ctx.m; + + if (m.has_trace_stream()) + throw default_exception("trace streams have to be off in parallel mode"); + + struct scoped_clear { + parallel &p; + scoped_clear(parallel &p) : p(p) {} + ~scoped_clear() { + p.m_workers.reset(); + p.m_assumptions_used.reset(); + p.m_assumptions.reset(); + } + }; + scoped_clear clear(*this); + + m_batch_manager.initialize(); + m_workers.reset(); + for (auto e : asms) + m_assumptions.insert(e); + scoped_limits sl(m.limit()); + flet _nt(ctx.m_fparams.m_threads, 1); + SASSERT(num_threads > 1); + for (unsigned i = 0; i < num_threads; ++i) + m_workers.push_back(alloc(worker, i, *this, asms)); + + for (auto w : m_workers) + sl.push_child(&(w->limit())); + + // Launch threads + vector threads(num_threads); + for (unsigned i = 0; i < num_threads; ++i) { + threads[i] = std::thread([&, i]() { m_workers[i]->run(); }); + } + + // Wait for all threads to finish + for (auto &th : threads) + th.join(); + + for (auto w : m_workers) + w->collect_statistics(ctx.m_aux_stats); + m_batch_manager.collect_statistics(ctx.m_aux_stats); + + return m_batch_manager.get_result(); + } + +} // namespace smt #endif diff --git a/src/smt/smt_parallel.h b/src/smt/smt_parallel.h index 07b04019d..da9e38897 100644 --- a/src/smt/smt_parallel.h +++ b/src/smt/smt_parallel.h @@ -11,7 +11,7 @@ Abstract: Author: - nbjorner 2020-01-31 + Ilana 2025 Revision History: @@ -19,16 +19,164 @@ Revision History: #pragma once #include "smt/smt_context.h" +#include "util/search_tree.h" +#include +#include + namespace smt { + struct cube_config { + using literal = expr_ref; + static bool literal_is_null(expr_ref const& l) { return l == nullptr; } + static std::ostream& display_literal(std::ostream& out, expr_ref const& l) { return out << mk_bounded_pp(l, l.get_manager()); } + }; + class parallel { context& ctx; + unsigned num_threads; + + struct shared_clause { + unsigned source_worker_id; + expr_ref 
clause;
+        };
+
+        // Coordinator shared by the workers: it records the overall verdict (m_state),
+        // owns a search tree of cubes, and keeps the clauses handed over for sharing.
+        // NOTE(review): several declarations in this hunk (`search_tree::node`,
+        // `search_tree::tree`, `vector`, `obj_hashtable`, `scoped_ptr`,
+        // `scoped_ptr_vector`) carry no template arguments in this copy of the patch.
+        // That looks like an extraction artifact -- confirm against the upstream header.
+        class batch_manager {
+
+            // Verdict reached so far. The two exception states presumably correspond to
+            // the two set_exception overloads below -- confirm in the .cpp file.
+            enum state {
+                is_running,
+                is_sat,
+                is_unsat,
+                is_exception_msg,
+                is_exception_code
+            };
+
+            // Counters; presumably reported through collect_statistics.
+            struct stats {
+                unsigned m_max_cube_depth = 0;
+                unsigned m_num_cubes = 0;
+            };
+
+
+            ast_manager& m;
+            parallel& p;   // owning parallel object; gives access to its workers and assumptions
+            // NOTE(review): presumably guards the mutable state below; the locking itself
+            // happens in the out-of-line definitions, which are not visible here.
+            std::mutex mux;
+            state m_state = state::is_running;
+            stats m_stats;
+            using node = search_tree::node;
+            search_tree::tree m_search_tree;
+
+            unsigned m_exception_code = 0;
+            std::string m_exception_msg;
+            vector shared_clause_trail; // store all shared clauses with worker IDs
+            obj_hashtable shared_clause_set; // for duplicate filtering on per-thread clause expressions
+
+            // called from batch manager to cancel other workers if we've reached a verdict
+            void cancel_workers() {
+                IF_VERBOSE(1, verbose_stream() << "Canceling workers\n");
+                for (auto& w : p.m_workers)
+                    w->cancel();
+            }
+
+            void init_parameters_state();
+
+            // True iff e is a member of the owning parallel object's assumption set.
+            bool is_assumption(expr* e) const {
+                return p.m_assumptions.contains(e);
+            }
+
+        public:
+            // The search tree is created with a null expr_ref as its "null literal".
+            batch_manager(ast_manager& m, parallel& p) : m(m), p(p), m_search_tree(expr_ref(m)) { }
+
+            void initialize();
+
+            // Verdict reporting. The ast_translation arguments are named l2g / g2l;
+            // presumably "worker-local to global" and "global to worker-local" -- confirm.
+            void set_unsat(ast_translation& l2g, expr_ref_vector const& unsat_core);
+            void set_sat(ast_translation& l2g, model& m);
+            void set_exception(std::string const& msg);
+            void set_exception(unsigned error_code);
+            void collect_statistics(::statistics& st) const;
+
+            // Cube management on top of m_search_tree (definitions are out of line).
+            bool get_cube(ast_translation& g2l, unsigned id, expr_ref_vector& cube, node*& n);
+            void backtrack(ast_translation& l2g, expr_ref_vector const& core, node* n);
+            void split(ast_translation& l2g, unsigned id, node* n, expr* atom);
+
+            // Clause and assumption exchange between workers (definitions are out of line).
+            void report_assumption_used(ast_translation& l2g, expr* assumption);
+            void collect_clause(ast_translation& l2g, unsigned source_worker_id, expr* clause);
+            expr_ref_vector return_shared_clauses(ast_translation& g2l, unsigned& worker_limit, unsigned worker_id);
+
+            lbool get_result() const;
+        };
+
+        // One solver thread. Each worker holds its own ast_manager (by value) together
+        // with a pair of translations, m_g2l and m_l2g.
+        class worker {
+            // Tuning knobs with their default values.
+            // NOTE(review): presumably overridden from solver parameters in the
+            // constructor -- confirm in the .cpp file.
+            struct config {
+                unsigned m_threads_max_conflicts = 1000;
+                bool m_share_units = true;
+                bool m_share_units_relevant_only = true;
+                bool m_share_units_initial_only = true;
+                double m_max_conflict_mul = 1.5;    // growth factor used by update_max_thread_conflicts
+                bool m_cube_initial_only = true;
+                bool m_inprocessing = true;
+                unsigned m_inprocessing_delay = 1;
+                unsigned m_max_cube_depth = 20;
+                unsigned m_max_conflicts = UINT_MAX;
+            };
+
+            using node = search_tree::node;
+
+            unsigned id; // unique identifier for the worker
+            parallel& p;
+            batch_manager& b;
+            ast_manager m;   // worker-owned manager (a value, not a reference)
+            expr_ref_vector asms;
+            smt_params m_smt_params;
+            config m_config;
+            random_gen m_rand;
+            scoped_ptr ctx;
+            ast_translation m_g2l, m_l2g;
+            search_tree::tree m_search_tree;
+
+            unsigned m_num_shared_units = 0;
+            unsigned m_num_initial_atoms = 0;
+            unsigned m_shared_clause_limit = 0; // remembers the index into shared_clause_trail marking the boundary between "old" and "new" clauses to share
+
+            expr_ref get_split_atom();
+
+            lbool check_cube(expr_ref_vector const& cube);
+            void share_units(ast_translation& l2g);
+
+            // Multiplies the per-thread conflict budget by m_max_conflict_mul and
+            // truncates the product back to unsigned.
+            void update_max_thread_conflicts() {
+                m_config.m_threads_max_conflicts = (unsigned)(m_config.m_max_conflict_mul * m_config.m_threads_max_conflicts);
+            } // allow for backoff scheme of conflicts within the thread for cube timeouts.
+
+            void simplify();
+
+        public:
+            worker(unsigned id, parallel& p, expr_ref_vector const& _asms);
+            void run();
+
+            void collect_shared_clauses(ast_translation& g2l);
+
+            void cancel();
+            void collect_statistics(::statistics& st) const;
+
+            // Resource limit of this worker's own ast_manager.
+            reslimit& limit() {
+                return m.limit();
+            }
+
+        };
+
+        obj_hashtable m_assumptions_used; // assumptions used in unsat cores, to be used in final core
+        obj_hashtable m_assumptions; // all assumptions
+        batch_manager m_batch_manager;
+        scoped_ptr_vector m_workers;
+
     public:
-        parallel(context& ctx): ctx(ctx) {}
+        // The thread count is the smaller of the hardware concurrency and the
+        // configured m_threads parameter.
+        // NOTE(review): std::thread::hardware_concurrency() may return 0 when the value
+        // is not computable, which makes num_threads 0 -- confirm that operator() copes
+        // with that, or clamp to at least 1.
+        parallel(context& ctx) :
+            ctx(ctx),
+            num_threads(std::min(
+                (unsigned)std::thread::hardware_concurrency(),
+                ctx.get_fparams().m_threads)),
+            m_batch_manager(ctx.m, *this) {}
 
         lbool operator()(expr_ref_vector const& asms);
-
     };
 }
diff --git a/src/util/search_tree.h b/src/util/search_tree.h
new file mode 100644
index 000000000..c2bae663c
--- /dev/null
+++ b/src/util/search_tree.h
@@ -0,0 +1,265 @@
+/*++
+Copyright (c) 2025 Microsoft Corporation
+
+Module Name:
+
+    search_tree.h
+
+Abstract:
+
+    A binary search tree for managing the search space of a DPLL(T) solver.
+    It supports splitting on atoms, backtracking on conflicts, and activating nodes.
+
+    Nodes can be in one of three states: open, closed, or active.
+    - Closed nodes are fully explored (both children are closed).
+    - Active nodes have no children and are currently being explored.
+    - Open nodes either have children that are open or are leaves.
+
+    A node can be split if it is active. After splitting, it becomes open and has two open children.
+
+    Backtracking on a conflict closes all nodes below the last node whose atom is in the conflict set.
+
+    Activation searches an open node closest to a seed node.
+
+Author:
+
+    Ilana Shapiro   2025-9-06
+
+--*/
+
+#pragma once
+
+#include "util/util.h"
+#include "util/vector.h"
+
+namespace search_tree {
+
+    // Exploration state of a node; the life cycle is described in the file header.
+    enum class status { open, closed, active };
+
+    /**
+       \brief A node of the search tree.
+
+       Config must provide:
+       - a type `literal`,
+       - `static bool literal_is_null(literal const&)`,
+       - `static ... display_literal(std::ostream&, literal const&)`.
+
+       A node owns its two children and releases them in its destructor.
+       It does not own its parent.
+    */
+    template<typename Config>
+    class node {
+        typedef typename Config::literal literal;
+        literal m_literal;
+        node* m_left = nullptr, * m_right = nullptr, * m_parent = nullptr;
+        status m_status;
+    public:
+        // New nodes start out open and without children.
+        node(literal const& l, node* parent) :
+            m_literal(l), m_parent(parent), m_status(status::open) {}
+        ~node() {
+            dealloc(m_left);
+            dealloc(m_right);
+        }
+
+        status get_status() const { return m_status; }
+        void set_status(status s) { m_status = s; }
+        literal const& get_literal() const { return m_literal; }
+        // Uses the same Config predicate as the assertions in split(); the previous
+        // spelling `Config::is_null` did not match any other use in this file.
+        bool literal_is_null() const { return Config::literal_is_null(m_literal); }
+
+        // Give an active node two children labeled a and b; the node becomes open.
+        // The call is a no-op when the node is not active.
+        void split(literal const& a, literal const& b) {
+            SASSERT(!Config::literal_is_null(a));
+            SASSERT(!Config::literal_is_null(b));
+            if (m_status != status::active)
+                return;
+            SASSERT(!m_left);
+            SASSERT(!m_right);
+            m_left = alloc(node, a, this);
+            m_right = alloc(node, b, this);
+            m_status = status::open;
+        }
+
+        node* left() const { return m_left; }
+        node* right() const { return m_right; }
+        node* parent() const { return m_parent; }
+
+        // Return an active node in the subtree rooted here, or nullptr if there is none.
+        // The left subtree is searched before the right one.
+        // As a side effect, an open node whose two children are both closed is closed.
+        node* find_active_node() {
+            if (m_status == status::active)
+                return this;
+            if (m_status != status::open)
+                return nullptr;
+            node* nodes[2] = { m_left, m_right };
+            for (unsigned i = 0; i < 2; ++i) {
+                auto res = nodes[i] ? nodes[i]->find_active_node() : nullptr;
+                if (res)
+                    return res;
+            }
+            // An open leaf has no children (split() creates such leaves), so the
+            // children must be checked for null before their status is read.
+            if (m_left && m_right &&
+                m_left->get_status() == status::closed && m_right->get_status() == status::closed)
+                m_status = status::closed;
+            return nullptr;
+        }
+
+        // Print this node and its subtree, one node per line, indenting children by
+        // two further columns. The status is shown as (o), (c) or (a).
+        void display(std::ostream& out, unsigned indent) const {
+            for (unsigned i = 0; i < indent; ++i)
+                out << " ";
+            Config::display_literal(out, m_literal);
+            out << (get_status() == status::open ? " (o)" : get_status() == status::closed ? " (c)" : " (a)");
+            out << "\n";
+            if (m_left)
+                m_left->display(out, indent + 2);
+            if (m_right)
+                m_right->display(out, indent + 2);
+        }
+    };
+
+    /**
+       \brief Search tree over nodes labeled by Config::literal.
+
+       The root carries the "null literal" passed to the constructor and starts out
+       active. Besides the requirements listed on node, backtrack() compares literals
+       with operator==.
+    */
+    template<typename Config>
+    class tree {
+        typedef typename Config::literal literal;
+        scoped_ptr<node<Config>> m_root = nullptr;
+        literal m_null_literal;
+        random_gen m_rand;
+
+        // return an active node in the subtree rooted at n, or nullptr if there is none
+        // close nodes that are fully explored (whose children are all closed)
+        // An open leaf that is reached is turned active and returned; the order in
+        // which the two children are tried is chosen with m_rand.
+        node<Config>* activate_from_root(node<Config>* n) {
+            if (!n)
+                return nullptr;
+            if (n->get_status() != status::open)
+                return nullptr;
+            auto left = n->left();
+            auto right = n->right();
+            if (!left && !right) {
+                n->set_status(status::active);
+                return n;
+            }
+            node<Config>* nodes[2] = { left, right };
+            unsigned index = m_rand(2);
+            auto child = activate_from_root(nodes[index]);
+            if (child)
+                return child;
+            child = activate_from_root(nodes[1 - index]);
+            if (child)
+                return child;
+            if (left && right && left->get_status() == status::closed && right->get_status() == status::closed)
+                n->set_status(status::closed);
+            return nullptr;
+        }
+
+        // Close n together with its whole subtree, then close every ancestor that is
+        // open and whose two children are both closed.
+        void close_node(node<Config>* n) {
+            if (!n)
+                return;
+            if (n->get_status() == status::closed)
+                return;
+            n->set_status(status::closed);
+            close_node(n->left());
+            close_node(n->right());
+            while (n) {
+                auto p = n->parent();
+                if (!p)
+                    return;
+                if (p->get_status() != status::open)
+                    return;
+                // p is the parent of n, so it was split and has both children.
+                if (p->left()->get_status() != status::closed)
+                    return;
+                if (p->right()->get_status() != status::closed)
+                    return;
+                p->set_status(status::closed);
+                n = p;
+            }
+        }
+
+    public:
+
+        tree(literal const& null_literal) : m_null_literal(null_literal) {
+            reset();
+        }
+
+        void set_seed(unsigned seed) {
+            m_rand.set_seed(seed);
+        }
+
+        // Replace the tree by a single active root labeled with the null literal.
+        void reset() {
+            m_root = alloc(node<Config>, m_null_literal, nullptr);
+            m_root->set_status(status::active);
+        }
+
+        // Split current node if it is active.
+        // After the call, n is open and has two children.
+        void split(node<Config>* n, literal const& a, literal const& b) {
+            n->split(a, b);
+        }
+
+        // conflict is given by a set of literals.
+        // they are a subset of literals on the path from root to n
+        // An empty conflict closes the whole tree. Otherwise the first node on the
+        // path from n towards the root whose literal occurs in conflict is closed.
+        void backtrack(node<Config>* n, vector<literal> const& conflict) {
+            if (conflict.empty()) {
+                close_node(m_root.get());
+                m_root->set_status(status::closed);
+                return;
+            }
+            SASSERT(n != m_root.get());
+            // all literals in conflict are on the path from root to n
+            // remove assumptions from conflict to ensure this.
+            DEBUG_CODE(
+                auto on_path = [&](literal const& a) {
+                    node<Config>* p = n;
+                    while (p) {
+                        if (p->get_literal() == a)
+                            return true;
+                        p = p->parent();
+                    }
+                    return false;
+                };
+                SASSERT(all_of(conflict, [&](auto const& a) { return on_path(a); }));
+            );
+
+            while (n) {
+                if (any_of(conflict, [&](auto const& a) { return a == n->get_literal(); })) {
+                    close_node(n);
+                    return;
+                }
+                n = n->parent();
+            }
+            UNREACHABLE();
+        }
+
+        // return an active node in the tree, or nullptr if there is none
+        // first check if there is a node to activate under n,
+        // if not, go up the tree and try to activate a sibling subtree
+        // A null n means "start at the root"; an already active root is returned as is.
+        node<Config>* activate_node(node<Config>* n) {
+            if (!n) {
+                if (m_root->get_status() == status::active)
+                    return m_root.get();
+                n = m_root.get();
+            }
+            auto res = activate_from_root(n);
+            if (res)
+                return res;
+
+            auto p = n->parent();
+            while (p) {
+                // Both subtrees of p are exhausted: close p and continue one level up.
+                if (p->left() && p->left()->get_status() == status::closed &&
+                    p->right() && p->right()->get_status() == status::closed) {
+                    p->set_status(status::closed);
+                    n = p;
+                    p = n->parent();
+                    continue;
+                }
+                if (n == p->left()) {
+                    res = activate_from_root(p->right());
+                    if (res)
+                        return res;
+                }
+                else {
+                    VERIFY(n == p->right());
+                    res = activate_from_root(p->left());
+                    if (res)
+                        return res;
+                }
+                n = p;
+                p = n->parent();
+            }
+            return nullptr;
+        }
+
+        // Return a node that is already active, or nullptr; unlike activate_node this
+        // never turns an open leaf active.
+        node<Config>* find_active_node() {
+            return m_root->find_active_node();
+        }
+
+        bool is_closed() const {
+            return m_root->get_status() == status::closed;
+        }
+
+        std::ostream& display(std::ostream& out) const {
+            m_root->display(out, 0);
+            return out;
+        }
+
+    };
+}
\ No newline at end of file