From 9fab72b3efc154a4814ce199dd95d23b44b26f50 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Fri, 31 Jan 2020 22:20:25 -0800 Subject: [PATCH] fix build Signed-off-by: Nikolaj Bjorner --- src/smt/smt_context.h | 19 +++---- src/smt/smt_parallel.cpp | 112 ++++++++++++++++++++++++++++----------- src/smt/smt_parallel.h | 9 +--- 3 files changed, 88 insertions(+), 52 deletions(-) diff --git a/src/smt/smt_context.h b/src/smt/smt_context.h index b2826e2de..29b84f5c3 100644 --- a/src/smt/smt_context.h +++ b/src/smt/smt_context.h @@ -112,7 +112,7 @@ namespace smt { unsigned m_final_check_idx; // circular counter used for implementing fairness bool m_is_auxiliary; // used to prevent unwanted information from being logged. - parallel* m_par; + class parallel* m_par; unsigned m_par_index; // ----------------------------------- @@ -413,26 +413,17 @@ namespace smt { return js.get_kind() == b_justification::JUSTIFICATION && js.get_justification()->get_from_theory() == th_id; } - - void set_par(unsigned idx, parallel* p) { m_par = p; m_par_index = idx; } - void set_random_seed(unsigned s) { m_random.set_seed(s); } int get_random_value() { return m_random(); } bool is_searching() const { return m_searching; } - svector const & get_activity_vector() const { - return m_activity; - } + svector const & get_activity_vector() const { return m_activity; } - double get_activity(bool_var v) const { - return m_activity[v]; - } + double get_activity(bool_var v) const { return m_activity[v]; } - void set_activity(bool_var v, double act) { - m_activity[v] = act; - } + void set_activity(bool_var v, double act) { m_activity[v] = act; } void activity_changed(bool_var v, bool increased) { if (increased) { @@ -1617,6 +1608,8 @@ namespace smt { return m_unsat_core.get(idx); } + expr_ref_vector const& unsat_core() const { return m_unsat_core; } + void get_levels(ptr_vector const& vars, unsigned_vector& depth); expr_ref_vector get_trail(); diff --git a/src/smt/smt_parallel.cpp b/src/smt/smt_parallel.cpp index b36a33002..60676d016 100644 --- a/src/smt/smt_parallel.cpp +++ b/src/smt/smt_parallel.cpp @@ -15,29 +15,13 @@ Author: --*/ -#include "smt/smt_parallel.h" #include "util/scoped_ptr_vector.h" +#include "ast/ast_util.h" +#include "ast/ast_translation.h" +#include "smt/smt_parallel.h" +#include "smt/smt_lookahead.h" namespace smt { - - void parallel::add_unit(context& pctx, expr* e) { - std::lock_guard lock(m_mux); - ast_translation tr(pctx.m, ctx.m); - expr_ref u (tr(e), ctx.m); - if (!m_unit_set.contains(u)) { - m_unit_trail.push_back(u); - m_unit_set.insert(u); - } - } - - void parallel::get_units(unsigned idx, context& pctx) { - std::lock_guard lock(m_mux); - ast_translation tr(ctx.m, pctx.m); - for (unsigned i = m_unit_lim[idx]; i < m_unit_trail.size(); ++i) { - expr_ref u (tr(m_unit_trail.get(i)), pctx.m); - pctx.assert_expr(u); - } - } lbool parallel::operator()(expr_ref_vector const& asms) { @@ -58,6 +42,9 @@ namespace smt { std::string ex_msg; par_exception_kind ex_kind = DEFAULT_EX; unsigned error_code = 0; + bool done = false; + unsigned num_rounds = 0; + unsigned max_conflicts = 400; for (unsigned i = 0; i < num_threads; ++i) { ast_manager* new_m = alloc(ast_manager, m, true); @@ -68,18 +55,74 @@ namespace smt { new_ctx.set_random_seed(i + ctx.get_fparams().m_random_seed); ast_translation tr(*new_m, m); pasms.push_back(tr(asms)); - m_unit_lim.push_back(0); - new_ctx.set_par(i, this); } + std::function cube = [&](context& ctx, expr_ref_vector& lasms, expr_ref& c) { + lookahead lh(ctx); + c = lh.choose(); + if (c) lasms.push_back(c); + }; + + obj_hashtable unit_set; + expr_ref_vector unit_trail(ctx.m); + unsigned_vector unit_lim; + for (unsigned i = 0; i < num_threads; ++i) unit_lim.push_back(0); + + std::function collect_units = [&,this]() { + for (unsigned i = 0; i < num_threads; ++i) { + context& pctx = *pctxs[i]; + pctx.pop_to_base_lvl(); + ast_translation tr(pctx.m, ctx.m); + unsigned sz = pctx.assigned_literals().size(); + for (unsigned j = unit_lim[i]; j < sz; ++j) { + literal lit = pctx.assigned_literals()[j]; + expr_ref e(pctx.bool_var2expr(lit.var()), pctx.m); + if (lit.sign()) e = pctx.m.mk_not(e); + expr_ref ce(tr(e.get()), ctx.m); + if (!unit_set.contains(ce)) { + unit_set.insert(ce); + unit_trail.push_back(ce); + } + } + } + + for (unsigned i = 0; i < num_threads; ++i) { + context& pctx = *pctxs[i]; + ast_translation tr(ctx.m, pctx.m); + unsigned sz = unit_trail.size(); + for (unsigned j = unit_lim[i]; j < sz; ++j) { + expr_ref src(ctx.m), dst(pctx.m); + dst = tr(unit_trail.get(j)); + pctx.assert_expr(dst); + } + unit_lim[i] = sz; + } + }; + std::mutex mux; auto worker_thread = [&](int i) { try { IF_VERBOSE(0, verbose_stream() << "thread " << i << "\n";); context& pctx = *pctxs[i]; ast_manager& pm = *pms[i]; + expr_ref_vector lasms(pasms[i]); + expr_ref c(pm); + + pctx.get_fparams().m_max_conflicts = max_conflicts; + if (num_iterations > 0) { + cube(pctx, lasms, c); + } + lbool r = pctx.check(lasms.size(), lasms.c_ptr()); + + if (r == l_undef && pctx.m_num_conflicts >= max_conflicts) { + return; + } + + if (r == l_false && pctx.unsat_core().contains(c)) { + pctx.assert_expr(mk_not(mk_and(pctx.unsat_core()))); + return; + } - lbool r = pctx.check(pasms[i].size(), pasms[i].c_ptr()); bool first = false; { std::lock_guard lock(mux); @@ -87,6 +130,7 @@ namespace smt { finished_id = i; first = true; result = r; + done = true; } } if (!first) return; @@ -106,14 +150,20 @@ namespace smt { } }; - vector threads(num_threads); - for (unsigned i = 0; i < num_threads; ++i) { - threads[i] = std::thread([&, i]() { worker_thread(i); }); - } - for (auto & th : threads) { - th.join(); - } + while (!done) { + vector threads(num_threads); + for (unsigned i = 0; i < num_threads; ++i) { + threads[i] = std::thread([&, i]() { worker_thread(i); }); + } + for (auto & th : threads) { + th.join(); + } + if (done) break; + collect_units(); + ++num_rounds; + max_conflicts *= 2; + } for (context* c : pctxs) { c->collect_statistics(ctx.m_aux_stats); @@ -137,7 +187,7 @@ namespace smt { } break; case l_false: - for (expr* e : pctx.m_unsat_core) + for (expr* e : pctx.unsat_core()) ctx.m_unsat_core.push_back(tr(e)); break; default: diff --git a/src/smt/smt_parallel.h b/src/smt/smt_parallel.h index 65f66bfc4..07b04019d 100644 --- a/src/smt/smt_parallel.h +++ b/src/smt/smt_parallel.h @@ -24,18 +24,11 @@ namespace smt { class parallel { context& ctx; - expr_ref_vector m_unit_trail; - obj_hashtable m_unit_set; - unsigned_vector m_unit_lim; - std::mutex m_mux; public: - parallel(context& ctx): ctx(ctx), m_unit_trail(ctx.m) {} + parallel(context& ctx): ctx(ctx) {} lbool operator()(expr_ref_vector const& asms); - void add_unit(context& ctx, expr* e); - - void get_units(unsigned idx, context& pctx); }; }