mirror of
https://github.com/Z3Prover/z3
synced 2025-08-26 13:06:05 +00:00
process cubes as lists of individual lits
This commit is contained in:
parent
33c184f60b
commit
aac8787ac3
1 changed files with 96 additions and 63 deletions
|
@ -36,6 +36,7 @@ namespace smt {
|
|||
#else
|
||||
|
||||
#include <thread>
|
||||
#include <cassert>
|
||||
|
||||
namespace smt {
|
||||
|
||||
|
@ -138,7 +139,7 @@ namespace smt {
|
|||
};
|
||||
|
||||
auto cube_score = [&](context& ctx, expr_ref_vector& lasms, expr_ref& c) {
|
||||
std::vector<std::pair<expr_ref, double>> candidates;
|
||||
vector<std::pair<expr_ref, double>> candidates;
|
||||
unsigned k = 4; // Get top-k scoring literals
|
||||
ast_manager& m = ctx.get_manager();
|
||||
|
||||
|
@ -153,7 +154,7 @@ namespace smt {
|
|||
double score = ctx.get_score(lit);
|
||||
if (score == 0.0) continue;
|
||||
|
||||
candidates.emplace_back(expr_ref(e, m), score);
|
||||
candidates.push_back(std::make_pair(expr_ref(e, m), score));
|
||||
}
|
||||
|
||||
// Sort all candidate literals descending by score
|
||||
|
@ -216,7 +217,7 @@ namespace smt {
|
|||
std::mutex mux;
|
||||
|
||||
// Lambda defining the work each SMT thread performs
|
||||
auto worker_thread = [&](int i, expr_ref_vector cube_batch) {
|
||||
auto worker_thread = [&](int i, vector<expr_ref_vector>& cube_batch) {
|
||||
try {
|
||||
// Get thread-specific context and AST manager
|
||||
context& pctx = *pctxs[i];
|
||||
|
@ -224,20 +225,45 @@ namespace smt {
|
|||
|
||||
// Initialize local assumptions and cube
|
||||
expr_ref_vector lasms(pasms[i]);
|
||||
expr_ref cube_batch_disjunction = mk_or(cube_batch);
|
||||
// std::cout << "Thread " << i << " initial cube: " << mk_pp(cube_batch_disjunction, pm) << "\n";
|
||||
lasms.push_back(cube_batch_disjunction);
|
||||
|
||||
// Set the max conflict limit for this thread
|
||||
pctx.get_fparams().m_max_conflicts = std::min(thread_max_conflicts, max_conflicts);
|
||||
vector<lbool> results;
|
||||
for (expr_ref_vector& cube : cube_batch) {
|
||||
expr_ref_vector lasms_copy(lasms);
|
||||
|
||||
// Optional verbose logging
|
||||
IF_VERBOSE(1, verbose_stream() << "(smt.thread " << i;
|
||||
if (num_rounds > 0) verbose_stream() << " :round " << num_rounds;
|
||||
if (cube_batch_disjunction) verbose_stream() << " :cube " << mk_bounded_pp(cube_batch_disjunction, pm, 3);
|
||||
verbose_stream() << ")\n";);
|
||||
|
||||
auto cube_intersects_core = [&](expr* cube, const expr_ref_vector &core) {
|
||||
if (&cube.get_manager() != &pm) {
|
||||
std::cerr << "Manager mismatch on cube: " << mk_bounded_pp(mk_and(cube), pm, 3) << "\n";
|
||||
UNREACHABLE(); // or throw
|
||||
}
|
||||
|
||||
for (expr* cube_lit : cube) {
|
||||
lasms_copy.push_back(expr_ref(cube_lit, pm));
|
||||
}
|
||||
|
||||
// Set the max conflict limit for this thread
|
||||
pctx.get_fparams().m_max_conflicts = std::min(thread_max_conflicts, max_conflicts);
|
||||
|
||||
// Optional verbose logging
|
||||
IF_VERBOSE(1, verbose_stream() << "(smt.thread " << i;
|
||||
if (num_rounds > 0) verbose_stream() << " :round " << num_rounds;
|
||||
verbose_stream() << " :cube " << mk_bounded_pp(mk_and(cube), pm, 3);
|
||||
verbose_stream() << ")\n";);
|
||||
|
||||
lbool r = pctx.check(lasms_copy.size(), lasms_copy.data());
|
||||
std::cout << "Thread " << i << " finished cube " << mk_bounded_pp(mk_and(cube), pm, 3) << " with result: " << r << "\n";
|
||||
results.push_back(r);
|
||||
}
|
||||
|
||||
lbool r = l_false;
|
||||
for (lbool res : results) {
|
||||
if (res == l_true) {
|
||||
r = l_true;
|
||||
} else if (res == l_undef) {
|
||||
if (r == l_false)
|
||||
r = l_undef;
|
||||
}
|
||||
}
|
||||
|
||||
auto cube_intersects_core = [&](expr* cube, const expr_ref_vector &core) {
|
||||
expr_ref_vector cube_lits(pctx.m);
|
||||
flatten_and(cube, cube_lits);
|
||||
for (expr* lit : cube_lits)
|
||||
|
@ -246,26 +272,21 @@ namespace smt {
|
|||
return false;
|
||||
};
|
||||
|
||||
lbool r = pctx.check(lasms.size(), lasms.data());
|
||||
|
||||
// Handle results based on outcome and conflict count
|
||||
if (r == l_undef && pctx.m_num_conflicts >= max_conflicts)
|
||||
; // no-op, allow loop to continue
|
||||
else if (r == l_undef && pctx.m_num_conflicts >= thread_max_conflicts)
|
||||
return; // quit thread early
|
||||
// If cube was unsat and it's in the core, learn from it. i.e. a thread can be UNSAT because the cube c contradicted F. In this case learn the negation of the cube ¬c
|
||||
else if (r == l_false) {
|
||||
IF_VERBOSE(1, verbose_stream() << "(smt.thread " << i << " :learn cube batch " << mk_bounded_pp(cube_batch_disjunction, pm, 3) << ")" << " unsat_core: " << pctx.unsat_core() << "\n");
|
||||
bool learned_cube = false;
|
||||
for (expr* cube : cube_batch) { // iterate over each cube in the batch
|
||||
if (cube_intersects_core(cube, pctx.unsat_core())) {
|
||||
IF_VERBOSE(1, verbose_stream() << "(pruning cube: " << mk_bounded_pp(cube, pm, 3) << " given unsat core: " << pctx.unsat_core() << ")");
|
||||
pctx.assert_expr(mk_not(mk_and(pctx.unsat_core())));
|
||||
learned_cube = true;
|
||||
}
|
||||
}
|
||||
if (learned_cube) return;
|
||||
}
|
||||
// else if (r == l_false) {
|
||||
// // IF_VERBOSE(1, verbose_stream() << "(smt.thread " << i << " :learn cube batch " << mk_bounded_pp(cube, pm, 3) << ")" << " unsat_core: " << pctx.unsat_core() << ")");
|
||||
// for (expr* cube : cube_batch) { // iterate over each cube in the batch
|
||||
// if (cube_intersects_core(cube, pctx.unsat_core())) {
|
||||
// // IF_VERBOSE(1, verbose_stream() << "(pruning cube: " << mk_bounded_pp(cube, pm, 3) << " given unsat core: " << pctx.unsat_core() << ")");
|
||||
// pctx.assert_expr(mk_not(mk_and(pctx.unsat_core())));
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// Begin thread-safe update of shared result state
|
||||
bool first = false;
|
||||
|
@ -281,7 +302,7 @@ namespace smt {
|
|||
finished_id = i;
|
||||
result = r;
|
||||
}
|
||||
else if (!first) return; // nothing new to contribute
|
||||
else if (!first) return;
|
||||
}
|
||||
|
||||
// Cancel limits on other threads now that a result is known
|
||||
|
@ -311,33 +332,39 @@ namespace smt {
|
|||
|
||||
struct BatchManager {
|
||||
std::mutex mtx;
|
||||
std::vector<expr_ref_vector> batches;
|
||||
vector<vector<expr_ref_vector>> batches;
|
||||
unsigned batch_idx = 0;
|
||||
unsigned batch_size = 1;
|
||||
|
||||
BatchManager(unsigned batch_size) : batch_size(batch_size) {}
|
||||
|
||||
// translate the next SINGLE batch of batch_size cubes to the thread
|
||||
expr_ref_vector get_next_batch(
|
||||
vector<expr_ref_vector> get_next_batch(
|
||||
ast_manager &main_ctx_m,
|
||||
ast_manager &thread_m
|
||||
) {
|
||||
std::lock_guard<std::mutex> lock(mtx);
|
||||
expr_ref_vector cube_batch(thread_m); // ensure bound to thread manager
|
||||
vector<expr_ref_vector> cube_batch; // ensure bound to thread manager
|
||||
if (batch_idx >= batches.size()) return cube_batch;
|
||||
|
||||
for (expr* cube : batches[batch_idx]) {
|
||||
cube_batch.push_back(
|
||||
expr_ref(translate(cube, main_ctx_m, thread_m), thread_m)
|
||||
);
|
||||
vector<expr_ref_vector> next_batch = batches[batch_idx];
|
||||
|
||||
for (const expr_ref_vector& cube : next_batch) {
|
||||
expr_ref_vector translated_cube_lits(thread_m);
|
||||
for (expr* lit : cube) {
|
||||
// Translate each literal to the thread's manager
|
||||
translated_cube_lits.push_back(translate(lit, main_ctx_m, thread_m));
|
||||
}
|
||||
cube_batch.push_back(translated_cube_lits);
|
||||
}
|
||||
|
||||
++batch_idx;
|
||||
// std::cout << "Thread batch " << batch_idx - 1 << " size: " << cube_batch.size() << "\n";
|
||||
|
||||
return cube_batch;
|
||||
}
|
||||
|
||||
expr_ref_vector cube_batch_pq(context& ctx) {
|
||||
// returns a list (vector) of cubes, where each cube is an expr_ref_vector of literals
|
||||
vector<expr_ref_vector> cube_batch_pq(context& ctx) {
|
||||
unsigned k = 1; // generates 2^k cubes in the batch
|
||||
ast_manager& m = ctx.get_manager();
|
||||
|
||||
|
@ -364,62 +391,69 @@ namespace smt {
|
|||
unsigned num_lits = top_lits.size();
|
||||
unsigned num_cubes = 1 << num_lits; // 2^num_lits combinations
|
||||
|
||||
expr_ref_vector cube_batch(m);
|
||||
vector<expr_ref_vector> cube_batch;
|
||||
|
||||
for (unsigned mask = 0; mask < num_cubes; ++mask) {
|
||||
expr_ref_vector cube_conj(m);
|
||||
expr_ref_vector cube_lits(m);
|
||||
for (unsigned i = 0; i < num_lits; ++i) {
|
||||
expr_ref lit(top_lits[i].get(), m);
|
||||
if ((mask >> i) & 1)
|
||||
cube_conj.push_back(mk_not(lit));
|
||||
cube_lits.push_back(mk_not(lit));
|
||||
else
|
||||
cube_conj.push_back(lit);
|
||||
cube_lits.push_back(lit);
|
||||
}
|
||||
cube_batch.push_back(mk_and(cube_conj));
|
||||
cube_batch.push_back(cube_lits);
|
||||
}
|
||||
|
||||
// std::cout << "Cubes out:\n";
|
||||
// for (size_t j = 0; j < cube_batch.size(); ++j) {
|
||||
// std::cout << " [" << j << "] " << mk_pp(cube_batch[j].get(), m) << "\n";
|
||||
// }
|
||||
std::cout << "Cubes out:\n";
|
||||
for (size_t j = 0; j < cube_batch.size(); ++j) {
|
||||
std::cout << " [" << j << "]\n";
|
||||
for (size_t k = 0; k < cube_batch[j].size(); ++k) {
|
||||
std::cout << " [" << k << "] " << mk_pp(cube_batch[j][k].get(), m) << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
return cube_batch;
|
||||
};
|
||||
|
||||
std::vector<expr_ref_vector> gen_new_batches(context& main_ctx) {
|
||||
std::lock_guard<std::mutex> lock(mtx);
|
||||
std::vector<expr_ref_vector> cube_batches;
|
||||
|
||||
// returns a vector of new cubes batches. each cube batch is a vector of expr_ref_vector cubes
|
||||
vector<vector<expr_ref_vector>> gen_new_batches(context& main_ctx) {
|
||||
vector<vector<expr_ref_vector>> cube_batches;
|
||||
|
||||
// Get all cubes in the main context's manager
|
||||
expr_ref_vector all_cubes = cube_batch_pq(main_ctx);
|
||||
vector<expr_ref_vector> all_cubes = cube_batch_pq(main_ctx);
|
||||
|
||||
ast_manager &m = main_ctx.get_manager();
|
||||
|
||||
// Partition into batches
|
||||
for (unsigned start = 0; start < all_cubes.size(); start += batch_size) {
|
||||
expr_ref_vector batch(m);
|
||||
vector<expr_ref_vector> batch;
|
||||
|
||||
unsigned end = std::min(start + batch_size, all_cubes.size());
|
||||
for (unsigned j = start; j < end; ++j) {
|
||||
batch.push_back(all_cubes[j].get());
|
||||
batch.push_back(all_cubes[j]);
|
||||
}
|
||||
|
||||
cube_batches.push_back(std::move(batch));
|
||||
cube_batches.push_back(batch);
|
||||
}
|
||||
batch_idx = 0; // Reset index for next round
|
||||
return cube_batches;
|
||||
}
|
||||
|
||||
void check_for_new_batches(context& main_ctx) {
|
||||
std::lock_guard<std::mutex> lock(mtx);
|
||||
if (batch_idx >= batches.size()) {
|
||||
batches = gen_new_batches(main_ctx);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
BatchManager batch_manager(2);
|
||||
BatchManager batch_manager(1);
|
||||
|
||||
// Thread scheduling loop
|
||||
while (true) {
|
||||
if (batch_manager.batch_idx >= batch_manager.batches.size()) {
|
||||
batch_manager.batches = batch_manager.gen_new_batches(ctx);
|
||||
}
|
||||
|
||||
std::vector<std::thread> threads(num_threads);
|
||||
vector<std::thread> threads(num_threads);
|
||||
batch_manager.check_for_new_batches(ctx);
|
||||
|
||||
// Launch threads
|
||||
for (unsigned i = 0; i < num_threads; ++i) {
|
||||
|
@ -432,7 +466,6 @@ namespace smt {
|
|||
worker_thread(i, next_batch);
|
||||
}
|
||||
});
|
||||
|
||||
}
|
||||
|
||||
// Wait for all threads to finish
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue