diff --git a/src/smt/smt_parallel.cpp b/src/smt/smt_parallel.cpp index c05e03a50..72bda6c0a 100644 --- a/src/smt/smt_parallel.cpp +++ b/src/smt/smt_parallel.cpp @@ -98,8 +98,8 @@ namespace smt { // TODO: can share lemmas here, such as new units and not(and(unsat_core)), binary clauses, etc. // TODO: remember assumptions used in core so that they get used for the final core. IF_VERBOSE(0, verbose_stream() << "Worker " << id << " found unsat cube\n"); - b.share_lemma(l2g, mk_not(mk_and(unsat_core))); - // share_units(); + b.collect_clause(l2g, id, mk_not(mk_and(unsat_core))); + share_units(l2g); break; } } @@ -118,55 +118,53 @@ namespace smt { ctx->set_random_seed(id + m_smt_params.m_random_seed); } - void parallel::worker::share_units() { - // obj_hashtable unit_set; - // expr_ref_vector unit_trail(ctx.m); - // unsigned_vector unit_lim; - // for (unsigned i = 0; i < num_threads; ++i) unit_lim.push_back(0); - - // // we just want to share lemmas and have a way of remembering how they are shared -- this is the next step - // // (this needs to be reworked) - // std::function collect_units = [&,this]() { - // //return; -- has overhead - // for (unsigned i = 0; i < num_threads; ++i) { - // context& pctx = *pctxs[i]; - // pctx.pop_to_base_lvl(); - // ast_translation tr(pctx.m, ctx.m); - // unsigned sz = pctx.assigned_literals().size(); - // for (unsigned j = unit_lim[i]; j < sz; ++j) { - // literal lit = pctx.assigned_literals()[j]; - // //IF_VERBOSE(0, verbose_stream() << "(smt.thread " << i << " :unit " << lit << " " << pctx.is_relevant(lit.var()) << ")\n";); - // if (!pctx.is_relevant(lit.var())) - // continue; - // expr_ref e(pctx.bool_var2expr(lit.var()), pctx.m); - // if (lit.sign()) e = pctx.m.mk_not(e); - // expr_ref ce(tr(e.get()), ctx.m); - // if (!unit_set.contains(ce)) { - // unit_set.insert(ce); - // unit_trail.push_back(ce); - // } - // } - // } - - // unsigned sz = unit_trail.size(); - // for (unsigned i = 0; i < num_threads; ++i) { - // context& pctx = *pctxs[i]; - // ast_translation tr(ctx.m, pctx.m); - // for (unsigned j = unit_lim[i]; j < sz; ++j) { - // expr_ref src(ctx.m), dst(pctx.m); - // dst = tr(unit_trail.get(j)); - // pctx.assert_expr(dst); // Assert that the conjunction of the assumptions in this unsat core is not satisfiable — prune it from future search - // } - // unit_lim[i] = pctx.assigned_literals().size(); - // } - // IF_VERBOSE(1, verbose_stream() << "(smt.thread :units " << sz << ")\n"); - // }; + void parallel::worker::share_units(ast_translation& l2g) { + // Collect new units learned locally by this worker and send to batch manager + unsigned sz = ctx->assigned_literals().size(); + for (unsigned j = shared_clause_limit; j < sz; ++j) { // iterate only over new literals since last sync -- QUESTION: I THINK THIS IS BUGGY BECAUSE THE SHARED CLAUSE LIMIT IS ONLY UPDATED (FOR ALL CLAUSE TYPES) WHEN WE GATHER NEW SHARED UNITS + literal lit = ctx->assigned_literals()[j]; + expr_ref e(ctx->bool_var2expr(lit.var()), ctx->m); // turn literal into a Boolean expression + if (lit.sign()) + e = ctx->m.mk_not(e); // negate if literal is negative + b.collect_clause(l2g, id, e); + } } - void parallel::batch_manager::share_lemma(ast_translation& l2g, expr* lemma) { + void parallel::batch_manager::collect_clause(ast_translation& l2g, unsigned source_worker_id, expr* clause) { std::scoped_lock lock(mux); - expr_ref g_lemma(l2g(lemma), l2g.to()); - p.ctx.assert_expr(g_lemma); // QUESTION: where does this get shared with the local thread contexts? -- doesn't right now, we will build the scaffolding for this later! + expr* g_clause = l2g(clause); + if (!shared_clause_set.contains(g_clause)) { + shared_clause_set.insert(g_clause); + SharedClause sc{source_worker_id, g_clause}; + shared_clause_trail.push_back(sc); + } + } + + // QUESTION -- WHERE SHOULD WE CALL THIS? + void parallel::worker::collect_shared_clauses(ast_translation& g2l) { + expr_ref_vector new_clauses = b.return_shared_clauses(g2l, shared_clause_limit, id); // get new clauses from the batch manager + // iterate over new clauses and assert them in the local context + for (expr* e : new_clauses) { + expr_ref local_clause(e, g2l.to()); // e was already translated to the local context in the batch manager!! + ctx->assert_expr(local_clause); // assert the clause in the local context + IF_VERBOSE(0, verbose_stream() << "Worker " << id << " asserting shared clause: " << mk_bounded_pp(local_clause, m, 3) << "\n"); + } + } + + // get new clauses from the batch manager and assert them in the local context + expr_ref_vector parallel::batch_manager::return_shared_clauses(ast_translation& g2l, unsigned& worker_limit, unsigned worker_id) { + expr_ref_vector result(g2l.to()); + { + std::scoped_lock lock(mux); + for (unsigned i = worker_limit; i < shared_clause_trail.size(); ++i) { + if (shared_clause_trail[i].source_worker_id == worker_id) + continue; // skip clauses from the requesting worker + expr_ref local_clause(g2l(shared_clause_trail[i].clause), g2l.to()); + result.push_back(local_clause); + } + worker_limit = shared_clause_trail.size(); // update the worker limit to the end of the current trail + } + return result; } lbool parallel::worker::check_cube(expr_ref_vector const& cube) { diff --git a/src/smt/smt_parallel.h b/src/smt/smt_parallel.h index 43380c27f..a52943763 100644 --- a/src/smt/smt_parallel.h +++ b/src/smt/smt_parallel.h @@ -27,6 +27,11 @@ namespace smt { context& ctx; unsigned num_threads; + struct SharedClause { + unsigned source_worker_id; + expr* clause; + }; + class batch_manager { enum state { is_running, @@ -45,12 +50,14 @@ namespace smt { unsigned m_max_batch_size = 10; unsigned m_exception_code = 0; std::string m_exception_msg; + std::vector shared_clause_trail; // store all shared clauses with worker IDs + obj_hashtable shared_clause_set; // for duplicate filtering on per-thread clause expressions // called from batch manager to cancel other workers if we've reached a verdict void cancel_workers() { IF_VERBOSE(0, verbose_stream() << "Canceling workers\n"); for (auto& w : p.m_workers) - w->cancel(); + w->cancel(); } public: @@ -76,7 +83,8 @@ namespace smt { // void return_cubes(ast_translation& l2g, vectorconst& cubes, expr_ref_vector const& split_atoms); void report_assumption_used(ast_translation& l2g, expr* assumption); - void share_lemma(ast_translation& l2g, expr* lemma); + void collect_clause(ast_translation& l2g, unsigned source_worker_id, expr* e); + expr_ref_vector return_shared_clauses(ast_translation& g2l, unsigned& worker_limit, unsigned worker_id); lbool get_result() const; }; @@ -90,12 +98,15 @@ namespace smt { scoped_ptr ctx; unsigned m_max_conflicts = 100; unsigned m_num_shared_units = 0; - void share_units(); + unsigned shared_clause_limit = 0; // remembers the index into shared_clause_trail marking the boundary between "old" and "new" clauses to share + void share_units(ast_translation& l2g); lbool check_cube(expr_ref_vector const& cube); public: worker(unsigned id, parallel& p, expr_ref_vector const& _asms); void run(); expr_ref_vector get_split_atoms(); + void collect_shared_clauses(ast_translation& g2l); + void cancel() { IF_VERBOSE(0, verbose_stream() << "Worker " << id << " canceling\n"); m.limit().cancel(); @@ -122,7 +133,6 @@ namespace smt { m_batch_manager(ctx.m, *this) {} lbool operator()(expr_ref_vector const& asms); - }; }