From 2c188a525ec766deea296aea27bd0edece9a10f5 Mon Sep 17 00:00:00 2001 From: Ilana Shapiro Date: Tue, 29 Jul 2025 16:45:38 -0700 Subject: [PATCH] set up worker thread batch manager for multithreaded batch cubes paradigm, need to debug as I am getting segfault still --- src/smt/smt_context.h | 6 ++ src/smt/smt_parallel.cpp | 146 +++++++++++++++++++++++++++++++++------ 2 files changed, 132 insertions(+), 20 deletions(-) diff --git a/src/smt/smt_context.h b/src/smt/smt_context.h index 4c46b2bbc..5f5ae9af4 100644 --- a/src/smt/smt_context.h +++ b/src/smt/smt_context.h @@ -51,6 +51,7 @@ Revision History: #include "solver/progress_callback.h" #include "solver/assertions/asserted_formulas.h" #include "smt/priority_queue.h" +#include "util/dlist.h" #include // there is a significant space overhead with allocating 1000+ contexts in @@ -191,6 +192,11 @@ namespace smt { svector m_bdata; //!< mapping bool_var -> data svector m_activity; updatable_priority_queue::priority_queue m_pq_scores; + struct lit_node : dll_base { + literal lit; + lit_node(literal l) : lit(l) { init(this); } + }; + lit_node* m_dll_lits; svector> m_lit_scores; clause_vector m_aux_clauses; clause_vector m_lemmas; diff --git a/src/smt/smt_parallel.cpp b/src/smt/smt_parallel.cpp index 04dfe7310..ab795a927 100644 --- a/src/smt/smt_parallel.cpp +++ b/src/smt/smt_parallel.cpp @@ -102,8 +102,8 @@ namespace smt { } }; - auto cube_batch_pq = [&](context& ctx, expr_ref_vector& lasms, expr_ref& c) { - unsigned k = 4; // Number of top literals you want + auto cube_pq = [&](context& ctx, expr_ref_vector& lasms, expr_ref& c) { + unsigned k = 3; // Number of top literals you want ast_manager& m = ctx.get_manager(); // Get the entire fixed-size priority queue (it's not that big) @@ -132,7 +132,7 @@ namespace smt { lasms.push_back(c); }; - auto cube_batch = [&](context& ctx, expr_ref_vector& lasms, expr_ref& c) { + auto cube_score = [&](context& ctx, expr_ref_vector& lasms, expr_ref& c) { std::vector> candidates; unsigned k = 4; // Get top-k scoring literals ast_manager& m = ctx.get_manager(); @@ -201,7 +201,7 @@ namespace smt { for (unsigned j = unit_lim[i]; j < sz; ++j) { expr_ref src(ctx.m), dst(pctx.m); dst = tr(unit_trail.get(j)); - pctx.assert_expr(dst); + pctx.assert_expr(dst); // Assert that the conjunction of the assumptions in this unsat core is not satisfiable — prune it from future search } unit_lim[i] = pctx.assigned_literals().size(); } @@ -211,43 +211,47 @@ namespace smt { std::mutex mux; // Lambda defining the work each SMT thread performs - auto worker_thread = [&](int i) { + auto worker_thread = [&](int i, expr_ref_vector cube_batch) { try { + std::cout << "Starting thread " << i <<"\n"; // Get thread-specific context and AST manager context& pctx = *pctxs[i]; ast_manager& pm = *pms[i]; // Initialize local assumptions and cube expr_ref_vector lasms(pasms[i]); - expr_ref c(pm); + expr_ref c(mk_or(cube_batch), pm); + lasms.push_back(c); // <-- add cube to assumptions // Set the max conflict limit for this thread pctx.get_fparams().m_max_conflicts = std::min(thread_max_conflicts, max_conflicts); - // Periodically generate cubes based on frequency - if (num_rounds > 0 && (num_rounds % pctx.get_fparams().m_threads_cube_frequency) == 0) - cube_batch(pctx, lasms, c); - // Optional verbose logging IF_VERBOSE(1, verbose_stream() << "(smt.thread " << i; if (num_rounds > 0) verbose_stream() << " :round " << num_rounds; if (c) verbose_stream() << " :cube " << mk_bounded_pp(c, pm, 3); verbose_stream() << ")\n";); - // Check satisfiability of assumptions + auto intersects = [&](const expr_ref_vector &a, const expr_ref_vector &b) { + for (expr *e : a) { + if (b.contains(e)) return true; + } + return false; + }; + lbool r = pctx.check(lasms.size(), lasms.data()); // Handle results based on outcome and conflict count if (r == l_undef && pctx.m_num_conflicts >= max_conflicts) ; // no-op, allow loop to continue else if (r == l_undef && pctx.m_num_conflicts >= thread_max_conflicts) - return; // quit thread early - - // If cube was unsat and it's in the core, learn from it - else if (r == l_false && pctx.unsat_core().contains(c)) { + return r; // quit thread early + // If cube was unsat and it's in the core, learn from it. i.e. a thread can be UNSAT because the cube c contradicted F. In this case learn the negation of the cube ¬c + // TAKE THE INTERSECTION INSTEAD OF CHECKING MEMBERSHIP, SEE WHITEBOARD NOTES + else if (r == l_false && intersects(cube_batch, pctx.unsat_core())) { // pctx.unsat_core().contains(c)) { THIS IS THE VERSION FOR SINGLE LITERAL CUBES IF_VERBOSE(1, verbose_stream() << "(smt.thread " << i << " :learn " << mk_bounded_pp(c, pm, 3) << ")"); pctx.assert_expr(mk_not(mk_and(pctx.unsat_core()))); - return; + return r; } // Begin thread-safe update of shared result state @@ -264,7 +268,7 @@ namespace smt { finished_id = i; result = r; } - else if (!first) return; // nothing new to contribute + else if (!first) return r; // nothing new to contribute } // Cancel limits on other threads now that a result is known @@ -272,6 +276,7 @@ namespace smt { if (m != &pm) m->limit().cancel(); } + return r; } catch (z3_error & err) { if (finished_id == UINT_MAX) { error_code = err.error_code(); @@ -291,16 +296,117 @@ namespace smt { done = true; } } + return l_undef; // Return undef if an exception occurred }; + struct BatchManager { + std::mutex mtx; + std::vector batches; + size_t batch_idx = 0; + size_t batch_size = 1; // num batches + + BatchManager(size_t batch_size) : batch_size(batch_size) {} + + // translate the next SINGLE batch of batch_size cubes to the thread + expr_ref_vector get_next_batch( + ast_manager &main_ctx_m, + ast_manager &thread_m + ) { + std::lock_guard lock(mtx); + expr_ref_vector cube_batch(thread_m); // ensure bound to thread manager + + for (expr* cube : batches[batch_idx]) { + cube_batch.push_back( + expr_ref(translate(cube, main_ctx_m, thread_m), thread_m) + ); + } + + ++batch_idx; + return cube_batch; + } + + expr_ref_vector cube_batch_pq(context& ctx) { + unsigned k = 3; // generates 2^k cubes in the batch + ast_manager& m = ctx.get_manager(); + + auto candidates = ctx.m_pq_scores.get_heap(); + std::sort(candidates.begin(), candidates.end(), + [](const auto& a, const auto& b) { return a.priority > b.priority; }); + + expr_ref_vector top_lits(m); + for (const auto& node : candidates) { + if (ctx.get_assignment(node.key) != l_undef) continue; + + expr* e = ctx.bool_var2expr(node.key); + if (!e) continue; + + top_lits.push_back(expr_ref(e, m)); + if (top_lits.size() >= k) break; + } + + std::cout << "Top lits:\n"; + for (size_t j = 0; j < top_lits.size(); ++j) { + std::cout << " [" << j << "] " << top_lits[j].get() << "\n"; + } + + unsigned num_lits = top_lits.size(); + unsigned num_cubes = 1 << num_lits; // 2^num_lits combinations + + expr_ref_vector cube_batch(m); + + for (unsigned mask = 0; mask < num_cubes; ++mask) { + expr_ref_vector cube_conj(m); + for (unsigned i = 0; i < num_lits; ++i) { + expr_ref lit(top_lits[i].get(), m); + if ((mask >> i) & 1) + cube_conj.push_back(mk_not(lit)); + else + cube_conj.push_back(lit); + } + cube_batch.push_back(mk_and(cube_conj)); + } + + std::cout << "Cubes out:\n"; + for (size_t j = 0; j < cube_batch.size(); ++j) { + std::cout << " [" << j << "] " << cube_batch[j].get() << "\n"; + } + + return cube_batch; + }; + + std::vector gen_new_batches(context& main_ctx) { + std::lock_guard lock(mtx); + std::vector cube_batches; + + size_t num_batches = 0; + while (num_batches < batch_size) { + expr_ref_vector cube_batch = cube_batch_pq(main_ctx); + cube_batches.push_back(cube_batch); + num_batches += cube_batch.size(); + } + return cube_batches; + } + }; + + BatchManager batch_manager(1); + batch_manager.batches = batch_manager.gen_new_batches(ctx); + // Thread scheduling loop while (true) { - vector threads(num_threads); + std::vector threads(num_threads); // Launch threads for (unsigned i = 0; i < num_threads; ++i) { // [&, i] is the lambda's capture clause: capture all variables by reference (&) except i, which is captured by value. - threads[i] = std::thread([&, i]() { worker_thread(i); }); + threads[i] = std::thread([&, i]() { + while (!done) { + auto next_batch = batch_manager.get_next_batch(ctx.m, *pms[i]); + if (next_batch.empty()) break; // No more work + + lbool r = worker_thread(i, next_batch); + } + }); + } // Wait for all threads to finish @@ -315,7 +421,7 @@ namespace smt { collect_units(); ++num_rounds; max_conflicts = (max_conflicts < thread_max_conflicts) ? 0 : (max_conflicts - thread_max_conflicts); - thread_max_conflicts *= 2; + thread_max_conflicts *= 2; } // Gather statistics from all solver contexts