mirror of
https://github.com/Z3Prover/z3
synced 2025-10-24 00:14:35 +00:00
set up worker thread batch manager for multithreaded batch cubes paradigm, need to debug as I am getting segfault still
This commit is contained in:
parent
36fbee3a2d
commit
2c188a525e
2 changed files with 132 additions and 20 deletions
|
@ -51,6 +51,7 @@ Revision History:
|
|||
#include "solver/progress_callback.h"
|
||||
#include "solver/assertions/asserted_formulas.h"
|
||||
#include "smt/priority_queue.h"
|
||||
#include "util/dlist.h"
|
||||
#include <tuple>
|
||||
|
||||
// there is a significant space overhead with allocating 1000+ contexts in
|
||||
|
@ -191,6 +192,11 @@ namespace smt {
|
|||
svector<bool_var_data> m_bdata; //!< mapping bool_var -> data
|
||||
svector<double> m_activity;
|
||||
updatable_priority_queue::priority_queue<bool_var, double> m_pq_scores;
|
||||
struct lit_node : dll_base<lit_node> {
|
||||
literal lit;
|
||||
lit_node(literal l) : lit(l) { init(this); }
|
||||
};
|
||||
lit_node* m_dll_lits;
|
||||
svector<std::array<double, 2>> m_lit_scores;
|
||||
clause_vector m_aux_clauses;
|
||||
clause_vector m_lemmas;
|
||||
|
|
|
@ -102,8 +102,8 @@ namespace smt {
|
|||
}
|
||||
};
|
||||
|
||||
auto cube_batch_pq = [&](context& ctx, expr_ref_vector& lasms, expr_ref& c) {
|
||||
unsigned k = 4; // Number of top literals you want
|
||||
auto cube_pq = [&](context& ctx, expr_ref_vector& lasms, expr_ref& c) {
|
||||
unsigned k = 3; // Number of top literals you want
|
||||
ast_manager& m = ctx.get_manager();
|
||||
|
||||
// Get the entire fixed-size priority queue (it's not that big)
|
||||
|
@ -132,7 +132,7 @@ namespace smt {
|
|||
lasms.push_back(c);
|
||||
};
|
||||
|
||||
auto cube_batch = [&](context& ctx, expr_ref_vector& lasms, expr_ref& c) {
|
||||
auto cube_score = [&](context& ctx, expr_ref_vector& lasms, expr_ref& c) {
|
||||
std::vector<std::pair<expr_ref, double>> candidates;
|
||||
unsigned k = 4; // Get top-k scoring literals
|
||||
ast_manager& m = ctx.get_manager();
|
||||
|
@ -201,7 +201,7 @@ namespace smt {
|
|||
for (unsigned j = unit_lim[i]; j < sz; ++j) {
|
||||
expr_ref src(ctx.m), dst(pctx.m);
|
||||
dst = tr(unit_trail.get(j));
|
||||
pctx.assert_expr(dst);
|
||||
pctx.assert_expr(dst); // Assert that the conjunction of the assumptions in this unsat core is not satisfiable — prune it from future search
|
||||
}
|
||||
unit_lim[i] = pctx.assigned_literals().size();
|
||||
}
|
||||
|
@ -211,43 +211,47 @@ namespace smt {
|
|||
std::mutex mux;
|
||||
|
||||
// Lambda defining the work each SMT thread performs
|
||||
auto worker_thread = [&](int i) {
|
||||
auto worker_thread = [&](int i, expr_ref_vector cube_batch) {
|
||||
try {
|
||||
std::cout << "Starting thread " << i <<"\n";
|
||||
// Get thread-specific context and AST manager
|
||||
context& pctx = *pctxs[i];
|
||||
ast_manager& pm = *pms[i];
|
||||
|
||||
// Initialize local assumptions and cube
|
||||
expr_ref_vector lasms(pasms[i]);
|
||||
expr_ref c(pm);
|
||||
expr_ref c(mk_or(cube_batch), pm);
|
||||
lasms.push_back(c); // <-- add cube to assumptions
|
||||
|
||||
// Set the max conflict limit for this thread
|
||||
pctx.get_fparams().m_max_conflicts = std::min(thread_max_conflicts, max_conflicts);
|
||||
|
||||
// Periodically generate cubes based on frequency
|
||||
if (num_rounds > 0 && (num_rounds % pctx.get_fparams().m_threads_cube_frequency) == 0)
|
||||
cube_batch(pctx, lasms, c);
|
||||
|
||||
// Optional verbose logging
|
||||
IF_VERBOSE(1, verbose_stream() << "(smt.thread " << i;
|
||||
if (num_rounds > 0) verbose_stream() << " :round " << num_rounds;
|
||||
if (c) verbose_stream() << " :cube " << mk_bounded_pp(c, pm, 3);
|
||||
verbose_stream() << ")\n";);
|
||||
|
||||
// Check satisfiability of assumptions
|
||||
auto intersects = [&](const expr_ref_vector &a, const expr_ref_vector &b) {
|
||||
for (expr *e : a) {
|
||||
if (b.contains(e)) return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
lbool r = pctx.check(lasms.size(), lasms.data());
|
||||
|
||||
// Handle results based on outcome and conflict count
|
||||
if (r == l_undef && pctx.m_num_conflicts >= max_conflicts)
|
||||
; // no-op, allow loop to continue
|
||||
else if (r == l_undef && pctx.m_num_conflicts >= thread_max_conflicts)
|
||||
return; // quit thread early
|
||||
|
||||
// If cube was unsat and it's in the core, learn from it
|
||||
else if (r == l_false && pctx.unsat_core().contains(c)) {
|
||||
return r; // quit thread early
|
||||
// If cube was unsat and it's in the core, learn from it. i.e. a thread can be UNSAT because the cube c contradicted F. In this case learn the negation of the cube ¬c
|
||||
// TAKE THE INTERSECTION INSTEAD OF CHECKING MEMBERSHIP, SEE WHITEBOARD NOTES
|
||||
else if (r == l_false && intersects(cube_batch, pctx.unsat_core())) { // pctx.unsat_core().contains(c)) { THIS IS THE VERSION FOR SINGLE LITERAL CUBES
|
||||
IF_VERBOSE(1, verbose_stream() << "(smt.thread " << i << " :learn " << mk_bounded_pp(c, pm, 3) << ")");
|
||||
pctx.assert_expr(mk_not(mk_and(pctx.unsat_core())));
|
||||
return;
|
||||
return r;
|
||||
}
|
||||
|
||||
// Begin thread-safe update of shared result state
|
||||
|
@ -264,7 +268,7 @@ namespace smt {
|
|||
finished_id = i;
|
||||
result = r;
|
||||
}
|
||||
else if (!first) return; // nothing new to contribute
|
||||
else if (!first) return r; // nothing new to contribute
|
||||
}
|
||||
|
||||
// Cancel limits on other threads now that a result is known
|
||||
|
@ -272,6 +276,7 @@ namespace smt {
|
|||
if (m != &pm) m->limit().cancel();
|
||||
}
|
||||
|
||||
return r;
|
||||
} catch (z3_error & err) {
|
||||
if (finished_id == UINT_MAX) {
|
||||
error_code = err.error_code();
|
||||
|
@ -291,16 +296,117 @@ namespace smt {
|
|||
done = true;
|
||||
}
|
||||
}
|
||||
return l_undef; // Return undef if an exception occurred
|
||||
};
|
||||
|
||||
struct BatchManager {
|
||||
std::mutex mtx;
|
||||
std::vector<expr_ref_vector> batches;
|
||||
size_t batch_idx = 0;
|
||||
size_t batch_size = 1; // num batches
|
||||
|
||||
BatchManager(size_t batch_size) : batch_size(batch_size) {}
|
||||
|
||||
// translate the next SINGLE batch of batch_size cubes to the thread
|
||||
expr_ref_vector get_next_batch(
|
||||
ast_manager &main_ctx_m,
|
||||
ast_manager &thread_m
|
||||
) {
|
||||
std::lock_guard<std::mutex> lock(mtx);
|
||||
expr_ref_vector cube_batch(thread_m); // ensure bound to thread manager
|
||||
|
||||
for (expr* cube : batches[batch_idx]) {
|
||||
cube_batch.push_back(
|
||||
expr_ref(translate(cube, main_ctx_m, thread_m), thread_m)
|
||||
);
|
||||
}
|
||||
|
||||
++batch_idx;
|
||||
return cube_batch;
|
||||
}
|
||||
|
||||
expr_ref_vector cube_batch_pq(context& ctx) {
|
||||
unsigned k = 3; // generates 2^k cubes in the batch
|
||||
ast_manager& m = ctx.get_manager();
|
||||
|
||||
auto candidates = ctx.m_pq_scores.get_heap();
|
||||
std::sort(candidates.begin(), candidates.end(),
|
||||
[](const auto& a, const auto& b) { return a.priority > b.priority; });
|
||||
|
||||
expr_ref_vector top_lits(m);
|
||||
for (const auto& node : candidates) {
|
||||
if (ctx.get_assignment(node.key) != l_undef) continue;
|
||||
|
||||
expr* e = ctx.bool_var2expr(node.key);
|
||||
if (!e) continue;
|
||||
|
||||
top_lits.push_back(expr_ref(e, m));
|
||||
if (top_lits.size() >= k) break;
|
||||
}
|
||||
|
||||
std::cout << "Top lits:\n";
|
||||
for (size_t j = 0; j < top_lits.size(); ++j) {
|
||||
std::cout << " [" << j << "] " << top_lits[j].get() << "\n";
|
||||
}
|
||||
|
||||
unsigned num_lits = top_lits.size();
|
||||
unsigned num_cubes = 1 << num_lits; // 2^num_lits combinations
|
||||
|
||||
expr_ref_vector cube_batch(m);
|
||||
|
||||
for (unsigned mask = 0; mask < num_cubes; ++mask) {
|
||||
expr_ref_vector cube_conj(m);
|
||||
for (unsigned i = 0; i < num_lits; ++i) {
|
||||
expr_ref lit(top_lits[i].get(), m);
|
||||
if ((mask >> i) & 1)
|
||||
cube_conj.push_back(mk_not(lit));
|
||||
else
|
||||
cube_conj.push_back(lit);
|
||||
}
|
||||
cube_batch.push_back(mk_and(cube_conj));
|
||||
}
|
||||
|
||||
std::cout << "Cubes out:\n";
|
||||
for (size_t j = 0; j < cube_batch.size(); ++j) {
|
||||
std::cout << " [" << j << "] " << cube_batch[j].get() << "\n";
|
||||
}
|
||||
|
||||
return cube_batch;
|
||||
};
|
||||
|
||||
std::vector<expr_ref_vector> gen_new_batches(context& main_ctx) {
|
||||
std::lock_guard<std::mutex> lock(mtx);
|
||||
std::vector<expr_ref_vector> cube_batches;
|
||||
|
||||
size_t num_batches = 0;
|
||||
while (num_batches < batch_size) {
|
||||
expr_ref_vector cube_batch = cube_batch_pq(main_ctx);
|
||||
cube_batches.push_back(cube_batch);
|
||||
num_batches += cube_batch.size();
|
||||
}
|
||||
return cube_batches;
|
||||
}
|
||||
};
|
||||
|
||||
BatchManager batch_manager(1);
|
||||
batch_manager.batches = batch_manager.gen_new_batches(ctx);
|
||||
|
||||
// Thread scheduling loop
|
||||
while (true) {
|
||||
vector<std::thread> threads(num_threads);
|
||||
std::vector<std::thread> threads(num_threads);
|
||||
|
||||
// Launch threads
|
||||
for (unsigned i = 0; i < num_threads; ++i) {
|
||||
// [&, i] is the lambda's capture clause: capture all variables by reference (&) except i, which is captured by value.
|
||||
threads[i] = std::thread([&, i]() { worker_thread(i); });
|
||||
threads[i] = std::thread([&, i]() {
|
||||
while (!done) {
|
||||
auto next_batch = batch_manager.get_next_batch(ctx.m, *pms[i]);
|
||||
if (next_batch.empty()) break; // No more work
|
||||
|
||||
lbool r = worker_thread(i, next_batch);
|
||||
}
|
||||
});
|
||||
|
||||
}
|
||||
|
||||
// Wait for all threads to finish
|
||||
|
@ -315,7 +421,7 @@ namespace smt {
|
|||
collect_units();
|
||||
++num_rounds;
|
||||
max_conflicts = (max_conflicts < thread_max_conflicts) ? 0 : (max_conflicts - thread_max_conflicts);
|
||||
thread_max_conflicts *= 2;
|
||||
thread_max_conflicts *= 2;
|
||||
}
|
||||
|
||||
// Gather statistics from all solver contexts
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue