diff --git a/src/smt/smt_context.h b/src/smt/smt_context.h index 08ee9800f..8e548e678 100644 --- a/src/smt/smt_context.h +++ b/src/smt/smt_context.h @@ -51,6 +51,7 @@ Revision History: #include "solver/progress_callback.h" #include "solver/assertions/asserted_formulas.h" #include "smt/priority_queue.h" +#include "util/dlist.h" #include // there is a significant space overhead with allocating 1000+ contexts in @@ -191,6 +192,13 @@ namespace smt { svector m_bdata; //!< mapping bool_var -> data svector m_activity; updatable_priority_queue::priority_queue m_pq_scores; + + struct lit_node : dll_base { + literal lit; + lit_node(literal l) : lit(l) { init(this); } + }; + lit_node* m_dll_lits; + svector> m_lit_scores; clause_vector m_aux_clauses; @@ -908,6 +916,8 @@ namespace smt { void add_or_rel_watches(app * n); + void add_implies_rel_watches(app* n); + void add_ite_rel_watches(app * n); void mk_not_cnstr(app * n); @@ -916,6 +926,8 @@ namespace smt { void mk_or_cnstr(app * n); + void mk_implies_cnstr(app* n); + void mk_iff_cnstr(app * n, bool sign); void mk_ite_cnstr(app * n); diff --git a/src/smt/smt_internalizer.cpp b/src/smt/smt_internalizer.cpp index fdccea4e7..ecb56e516 100644 --- a/src/smt/smt_internalizer.cpp +++ b/src/smt/smt_internalizer.cpp @@ -696,6 +696,10 @@ namespace smt { mk_or_cnstr(to_app(n)); add_or_rel_watches(to_app(n)); break; + case OP_IMPLIES: + mk_implies_cnstr(to_app(n)); + add_implies_rel_watches(to_app(n)); + break; case OP_EQ: if (m.is_iff(n)) mk_iff_cnstr(to_app(n), false); @@ -711,8 +715,7 @@ namespace smt { mk_iff_cnstr(to_app(n), true); break; case OP_DISTINCT: - case OP_IMPLIES: - throw default_exception("formula has not been simplified"); + throw default_exception(std::string("formula has not been simplified") + " : " + mk_pp(n, m)); case OP_OEQ: UNREACHABLE(); default: @@ -1712,6 +1715,14 @@ namespace smt { } } + void context::add_implies_rel_watches(app* n) { + if (relevancy()) { + relevancy_eh* eh = m_relevancy_propagator->mk_implies_relevancy_eh(n); + add_rel_watch(~get_literal(n->get_arg(0)), eh); + add_rel_watch(get_literal(n->get_arg(1)), eh); + } + } + void context::add_ite_rel_watches(app * n) { if (relevancy()) { relevancy_eh * eh = m_relevancy_propagator->mk_ite_relevancy_eh(n); @@ -1758,9 +1769,24 @@ namespace smt { mk_gate_clause(buffer.size(), buffer.data()); } + void context::mk_implies_cnstr(app* n) { + literal l = get_literal(n); + literal_buffer buffer; + buffer.push_back(~l); + auto arg1 = n->get_arg(0); + literal l_arg1 = get_literal(arg1); + mk_gate_clause(l, l_arg1); + buffer.push_back(~l_arg1); + auto arg2 = n->get_arg(1); + literal l_arg2 = get_literal(arg2); + mk_gate_clause(l, ~l_arg2); + buffer.push_back(l_arg2); + mk_gate_clause(buffer.size(), buffer.data()); + } + void context::mk_iff_cnstr(app * n, bool sign) { if (n->get_num_args() != 2) - throw default_exception("formula has not been simplified"); + throw default_exception(std::string("formula has not been simplified") + " : " + mk_pp(n, m)); literal l = get_literal(n); literal l1 = get_literal(n->get_arg(0)); literal l2 = get_literal(n->get_arg(1)); diff --git a/src/smt/smt_lookahead.cpp b/src/smt/smt_lookahead.cpp index 221c2d0ea..eb4f96320 100644 --- a/src/smt/smt_lookahead.cpp +++ b/src/smt/smt_lookahead.cpp @@ -72,9 +72,14 @@ namespace smt { svector vars; for (bool_var v = 0; v < static_cast(sz); ++v) { expr* b = ctx.bool_var2expr(v); - if (b && ctx.get_assignment(v) == l_undef) { - vars.push_back(v); - } + if (!b) + continue; + if (ctx.get_assignment(v) != l_undef) + continue; + if (m.is_and(b) || m.is_or(b) || m.is_not(b) || m.is_ite(b) || m.is_implies(b) || m.is_iff(b) || m.is_xor(b)) + continue; // do not choose connectives + vars.push_back(v); + } compare comp(ctx); std::sort(vars.begin(), vars.end(), comp); diff --git a/src/smt/smt_parallel.cpp b/src/smt/smt_parallel.cpp index feea6fc17..ce8b699aa 100644 --- a/src/smt/smt_parallel.cpp +++ b/src/smt/smt_parallel.cpp @@ -36,6 +36,7 @@ namespace smt { #else #include +#include namespace smt { @@ -77,13 +78,16 @@ namespace smt { throw default_exception("trace streams have to be off in parallel mode"); + params_ref params = ctx.get_params(); for (unsigned i = 0; i < num_threads; ++i) { smt_params.push_back(ctx.get_fparams()); + smt_params.back().m_preprocess = false; } + for (unsigned i = 0; i < num_threads; ++i) { ast_manager* new_m = alloc(ast_manager, m, true); pms.push_back(new_m); - pctxs.push_back(alloc(context, *new_m, smt_params[i], ctx.get_params())); + pctxs.push_back(alloc(context, *new_m, smt_params[i], params)); context& new_ctx = *pctxs.back(); context::copy(ctx, new_ctx, true); new_ctx.set_random_seed(i + ctx.get_fparams().m_random_seed); @@ -103,8 +107,9 @@ namespace smt { } }; - auto cube_batch_pq = [&](context& ctx, expr_ref_vector& lasms, expr_ref& c) { - unsigned k = 4; // Number of top literals you want + auto cube_pq = [&](context& ctx, expr_ref_vector& lasms, expr_ref& c) { + unsigned k = 3; // Number of top literals you want + ast_manager& m = ctx.get_manager(); // Get the entire fixed-size priority queue (it's not that big) @@ -134,12 +139,11 @@ namespace smt { lasms.push_back(c); }; - auto cube_batch = [&](context& ctx, expr_ref_vector& lasms, expr_ref& c) { - std::vector> candidates; + auto cube_score = [&](context& ctx, expr_ref_vector& lasms, expr_ref& c) { + vector> candidates; unsigned k = 4; // Get top-k scoring literals ast_manager& m = ctx.get_manager(); - // std::cout << ctx.m_bool_var2expr.size() << std::endl; // Prints the size of m_bool_var2expr // Loop over first 100 Boolean vars for (bool_var v = 0; v < 100; ++v) { if (ctx.get_assignment(v) != l_undef) continue; @@ -151,7 +155,7 @@ namespace smt { double score = ctx.get_score(lit); if (score == 0.0) continue; - candidates.emplace_back(expr_ref(e, m), score); + candidates.push_back(std::make_pair(expr_ref(e, m), score)); } // Sort all candidate literals descending by score @@ -179,6 +183,7 @@ namespace smt { for (unsigned i = 0; i < num_threads; ++i) unit_lim.push_back(0); std::function collect_units = [&,this]() { + //return; -- has overhead for (unsigned i = 0; i < num_threads; ++i) { context& pctx = *pctxs[i]; pctx.pop_to_base_lvl(); @@ -203,7 +208,7 @@ namespace smt { for (unsigned j = unit_lim[i]; j < sz; ++j) { expr_ref src(ctx.m), dst(pctx.m); dst = tr(unit_trail.get(j)); - pctx.assert_expr(dst); + pctx.assert_expr(dst); // Assert that the conjunction of the assumptions in this unsat core is not satisfiable — prune it from future search } unit_lim[i] = pctx.assigned_literals().size(); } @@ -213,7 +218,7 @@ namespace smt { std::mutex mux; // Lambda defining the work each SMT thread performs - auto worker_thread = [&](int i) { + auto worker_thread = [&](int i, vector& cube_batch) { try { // Get thread-specific context and AST manager context& pctx = *pctxs[i]; @@ -221,36 +226,68 @@ namespace smt { // Initialize local assumptions and cube expr_ref_vector lasms(pasms[i]); - expr_ref c(pm); - // Set the max conflict limit for this thread - pctx.get_fparams().m_max_conflicts = std::min(thread_max_conflicts, max_conflicts); + vector results; + for (expr_ref_vector& cube : cube_batch) { + expr_ref_vector lasms_copy(lasms); - // Periodically generate cubes based on frequency - if (num_rounds > 0 && (num_rounds % pctx.get_fparams().m_threads_cube_frequency) == 0) - cube_batch(pctx, lasms, c); + if (&cube.get_manager() != &pm) { + std::cerr << "Manager mismatch on cube: " << mk_bounded_pp(mk_and(cube), pm, 3) << "\n"; + UNREACHABLE(); // or throw + } - // Optional verbose logging - IF_VERBOSE(1, verbose_stream() << "(smt.thread " << i; - if (num_rounds > 0) verbose_stream() << " :round " << num_rounds; - if (c) verbose_stream() << " :cube " << mk_bounded_pp(c, pm, 3); - verbose_stream() << ")\n";); + for (expr* cube_lit : cube) { + lasms_copy.push_back(expr_ref(cube_lit, pm)); + } - // Check satisfiability of assumptions - lbool r = pctx.check(lasms.size(), lasms.data()); + // Set the max conflict limit for this thread + pctx.get_fparams().m_max_conflicts = std::min(thread_max_conflicts, max_conflicts); + + // Optional verbose logging + IF_VERBOSE(1, verbose_stream() << "(smt.thread " << i; + if (num_rounds > 0) verbose_stream() << " :round " << num_rounds; + verbose_stream() << " :cube " << mk_bounded_pp(mk_and(cube), pm, 3); + verbose_stream() << ")\n";); + + lbool r = pctx.check(lasms_copy.size(), lasms_copy.data()); + std::cout << "Thread " << i << " finished cube " << mk_bounded_pp(mk_and(cube), pm, 3) << " with result: " << r << "\n"; + results.push_back(r); + } + + lbool r = l_false; + for (lbool res : results) { + if (res == l_true) { + r = l_true; + } else if (res == l_undef) { + if (r == l_false) + r = l_undef; + } + } + + auto cube_intersects_core = [&](expr* cube, const expr_ref_vector &core) { + expr_ref_vector cube_lits(pctx.m); + flatten_and(cube, cube_lits); + for (expr* lit : cube_lits) + if (core.contains(lit)) + return true; + return false; + }; // Handle results based on outcome and conflict count if (r == l_undef && pctx.m_num_conflicts >= max_conflicts) ; // no-op, allow loop to continue else if (r == l_undef && pctx.m_num_conflicts >= thread_max_conflicts) return; // quit thread early - - // If cube was unsat and it's in the core, learn from it - else if (r == l_false && pctx.unsat_core().contains(c)) { - IF_VERBOSE(1, verbose_stream() << "(smt.thread " << i << " :learn " << mk_bounded_pp(c, pm, 3) << ")"); - pctx.assert_expr(mk_not(mk_and(pctx.unsat_core()))); - return; - } + // If cube was unsat and it's in the core, learn from it. i.e. a thread can be UNSAT because the cube c contradicted F. In this case learn the negation of the cube ¬c + // else if (r == l_false) { + // // IF_VERBOSE(1, verbose_stream() << "(smt.thread " << i << " :learn cube batch " << mk_bounded_pp(cube, pm, 3) << ")" << " unsat_core: " << pctx.unsat_core() << ")"); + // for (expr* cube : cube_batch) { // iterate over each cube in the batch + // if (cube_intersects_core(cube, pctx.unsat_core())) { + // // IF_VERBOSE(1, verbose_stream() << "(pruning cube: " << mk_bounded_pp(cube, pm, 3) << " given unsat core: " << pctx.unsat_core() << ")"); + // pctx.assert_expr(mk_not(mk_and(pctx.unsat_core()))); + // } + // } + // } // Begin thread-safe update of shared result state bool first = false; @@ -273,7 +310,6 @@ namespace smt { for (ast_manager* m : pms) { if (m != &pm) m->limit().cancel(); } - } catch (z3_error & err) { if (finished_id == UINT_MAX) { error_code = err.error_code(); @@ -295,14 +331,142 @@ namespace smt { } }; + struct BatchManager { + std::mutex mtx; + vector> batches; + unsigned batch_idx = 0; + unsigned batch_size = 1; + + BatchManager(unsigned batch_size) : batch_size(batch_size) {} + + // translate the next SINGLE batch of batch_size cubes to the thread + vector get_next_batch( + ast_manager &main_ctx_m, + ast_manager &thread_m + ) { + std::lock_guard lock(mtx); + vector cube_batch; // ensure bound to thread manager + if (batch_idx >= batches.size()) return cube_batch; + + vector next_batch = batches[batch_idx]; + + for (const expr_ref_vector& cube : next_batch) { + expr_ref_vector translated_cube_lits(thread_m); + for (expr* lit : cube) { + // Translate each literal to the thread's manager + translated_cube_lits.push_back(translate(lit, main_ctx_m, thread_m)); + } + cube_batch.push_back(translated_cube_lits); + } + + ++batch_idx; + + return cube_batch; + } + + // returns a list (vector) of cubes, where each cube is an expr_ref_vector of literals + vector cube_batch_pq(context& ctx) { + unsigned k = 1; // generates 2^k cubes in the batch + ast_manager& m = ctx.get_manager(); + + auto candidates = ctx.m_pq_scores.get_heap(); + std::sort(candidates.begin(), candidates.end(), + [](const auto& a, const auto& b) { return a.priority > b.priority; }); + + expr_ref_vector top_lits(m); + for (const auto& node : candidates) { + if (ctx.get_assignment(node.key) != l_undef) continue; + + expr* e = ctx.bool_var2expr(node.key); + if (!e) continue; + + top_lits.push_back(expr_ref(e, m)); + if (top_lits.size() >= k) break; + } + + // std::cout << "Top lits:\n"; + // for (unsigned j = 0; j < top_lits.size(); ++j) { + // std::cout << " [" << j << "] " << mk_pp(top_lits[j].get(), m) << "\n"; + // } + + unsigned num_lits = top_lits.size(); + unsigned num_cubes = 1 << num_lits; // 2^num_lits combinations + + vector cube_batch; + + for (unsigned mask = 0; mask < num_cubes; ++mask) { + expr_ref_vector cube_lits(m); + for (unsigned i = 0; i < num_lits; ++i) { + expr_ref lit(top_lits[i].get(), m); + if ((mask >> i) & 1) + cube_lits.push_back(mk_not(lit)); + else + cube_lits.push_back(lit); + } + cube_batch.push_back(cube_lits); + } + + std::cout << "Cubes out:\n"; + for (size_t j = 0; j < cube_batch.size(); ++j) { + std::cout << " [" << j << "]\n"; + for (size_t k = 0; k < cube_batch[j].size(); ++k) { + std::cout << " [" << k << "] " << mk_pp(cube_batch[j][k].get(), m) << "\n"; + } + } + + return cube_batch; + }; + + // returns a vector of new cubes batches. each cube batch is a vector of expr_ref_vector cubes + vector> gen_new_batches(context& main_ctx) { + vector> cube_batches; + + // Get all cubes in the main context's manager + vector all_cubes = cube_batch_pq(main_ctx); + + ast_manager &m = main_ctx.get_manager(); + + // Partition into batches + for (unsigned start = 0; start < all_cubes.size(); start += batch_size) { + vector batch; + + unsigned end = std::min(start + batch_size, all_cubes.size()); + for (unsigned j = start; j < end; ++j) { + batch.push_back(all_cubes[j]); + } + + cube_batches.push_back(batch); + } + batch_idx = 0; // Reset index for next round + return cube_batches; + } + + void check_for_new_batches(context& main_ctx) { + std::lock_guard lock(mtx); + if (batch_idx >= batches.size()) { + batches = gen_new_batches(main_ctx); + } + } + }; + + BatchManager batch_manager(1); + // Thread scheduling loop while (true) { vector threads(num_threads); + batch_manager.check_for_new_batches(ctx); // Launch threads for (unsigned i = 0; i < num_threads; ++i) { // [&, i] is the lambda's capture clause: capture all variables by reference (&) except i, which is captured by value. - threads[i] = std::thread([&, i]() { worker_thread(i); }); + threads[i] = std::thread([&, i]() { + while (!done) { + auto next_batch = batch_manager.get_next_batch(ctx.m, *pms[i]); + if (next_batch.empty()) break; // No more work + + worker_thread(i, next_batch); + } + }); } // Wait for all threads to finish @@ -317,7 +481,7 @@ namespace smt { collect_units(); ++num_rounds; max_conflicts = (max_conflicts < thread_max_conflicts) ? 0 : (max_conflicts - thread_max_conflicts); - thread_max_conflicts *= 2; + thread_max_conflicts *= 2; } // Gather statistics from all solver contexts diff --git a/src/smt/smt_relevancy.cpp b/src/smt/smt_relevancy.cpp index f7ba3dcce..48fa3657d 100644 --- a/src/smt/smt_relevancy.cpp +++ b/src/smt/smt_relevancy.cpp @@ -62,6 +62,13 @@ namespace smt { void operator()(relevancy_propagator & rp) override; }; + class implies_relevancy_eh : public relevancy_eh { + app* m_parent; + public: + implies_relevancy_eh(app* p) :m_parent(p) {} + void operator()(relevancy_propagator& rp) override; + }; + class ite_relevancy_eh : public relevancy_eh { app * m_parent; public: @@ -108,6 +115,11 @@ namespace smt { return mk_relevancy_eh(or_relevancy_eh(n)); } + relevancy_eh* relevancy_propagator::mk_implies_relevancy_eh(app* n) { + SASSERT(get_manager().is_implies(n)); + return mk_relevancy_eh(implies_relevancy_eh(n)); + } + relevancy_eh * relevancy_propagator::mk_and_relevancy_eh(app * n) { SASSERT(get_manager().is_and(n)); return mk_relevancy_eh(and_relevancy_eh(n)); @@ -357,8 +369,38 @@ namespace smt { --j; mark_as_relevant(n->get_arg(j)); } - } + } + void propagate_relevant_implies(app* n) { + SASSERT(get_manager().is_implies(n)); + lbool val = m_context.find_assignment(n); + // If val is l_undef, then the expression + // is a root, and no boolean variable was created for it. + if (val == l_undef) + val = l_true; + switch (val) { + case l_false: + propagate_relevant_app(n); + break; + case l_undef: + break; + case l_true: { + expr* true_arg = nullptr; + auto arg0 = n->get_arg(0); + auto arg1 = n->get_arg(1); + if (m_context.find_assignment(arg0) == l_false) { + if (!is_relevant_core(arg0)) + mark_as_relevant(arg0); + return; + } + if (m_context.find_assignment(arg1) == l_true) { + if (!is_relevant_core(arg1)) + mark_as_relevant(arg1); + return; + } + } + } + } /** \brief Propagate relevancy for an or-application. */ @@ -470,6 +512,9 @@ namespace smt { case OP_AND: propagate_relevant_and(to_app(n)); break; + case OP_IMPLIES: + propagate_relevant_implies(to_app(n)); + break; case OP_ITE: propagate_relevant_ite(to_app(n)); break; @@ -505,6 +550,8 @@ namespace smt { propagate_relevant_or(to_app(n)); else if (m.is_and(n)) propagate_relevant_and(to_app(n)); + else if (m.is_implies(n)) + propagate_relevant_implies(to_app(n)); } relevancy_ehs * ehs = get_watches(n, val); while (ehs != nullptr) { @@ -644,6 +691,11 @@ namespace smt { static_cast(rp).propagate_relevant_or(m_parent); } + void implies_relevancy_eh::operator()(relevancy_propagator& rp) { + if (rp.is_relevant(m_parent)) + static_cast(rp).propagate_relevant_implies(m_parent); + } + void ite_relevancy_eh::operator()(relevancy_propagator & rp) { if (rp.is_relevant(m_parent)) { static_cast(rp).propagate_relevant_ite(m_parent); diff --git a/src/smt/smt_relevancy.h b/src/smt/smt_relevancy.h index 8dea2842f..4827fffcb 100644 --- a/src/smt/smt_relevancy.h +++ b/src/smt/smt_relevancy.h @@ -188,6 +188,7 @@ namespace smt { void add_dependency(expr * src, expr * target); relevancy_eh * mk_or_relevancy_eh(app * n); + relevancy_eh* mk_implies_relevancy_eh(app* n); relevancy_eh * mk_and_relevancy_eh(app * n); relevancy_eh * mk_ite_relevancy_eh(app * n); relevancy_eh * mk_term_ite_relevancy_eh(app * c, app * t, app * e);